diff --git a/llmengine/chatllm.py b/llmengine/chatllm.py
index cc1c7e1..769cb51 100644
--- a/llmengine/chatllm.py
+++ b/llmengine/chatllm.py
@@ -62,3 +62,40 @@ class TransformersChatEngine:
         for new_text in streamer:
             yield new_text
 
+
+if __name__ == '__main__':
+    import argparse
+
+    def parse_args():
+        parser = argparse.ArgumentParser(description="Transformers Chat CLI")
+        parser.add_argument("--model", type=str, required=True, help="Model path or Hugging Face model name")
+        parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use")
+        parser.add_argument("--stream", action="store_true", help="Enable streaming output")
+        return parser.parse_args()
+
+    def generate(engine, stream):
+        # Simple REPL: read a prompt, print the reply, loop until "q".
+        while True:
+            print('prompt("q" to exit):')
+            p = input()
+            if p == 'q':
+                break
+            if not p:
+                continue
+            if stream:
+                for token in engine.stream_generate(p):
+                    print(token, end="", flush=True)
+                print()  # terminate the streamed line
+            else:
+                print(engine.generate(p))
+
+    def main():
+        args = parse_args()
+        print(f'{args=}')
+        engine = TransformersChatEngine(
+            model_name=args.model,
+            gpus=args.gpus
+        )
+        generate(engine, args.stream)
+
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index de8b36a..3a428f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,8 +8,8 @@ requires-python = ">=3.8"
 license = {text = "MIT"}
 dependencies = [
     "torch",
-    "tramsformers",
-    "acelerate"
+    "transformers",
+    "accelerate"
 ]
 
 [project.optional-dependencies]
diff --git a/test/chatllm b/test/chatllm
new file mode 100755
index 0000000..00a1526
--- /dev/null
+++ b/test/chatllm
@@ -0,0 +1,25 @@
+#!/share/vllm-0.8.5/bin/python
+import os
+import argparse
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Launch the llmengine chat CLI on selected GPUs")
+    parser.add_argument('--gpus', '-g', type=str, required=False, default='0', help='GPU ids to use, comma-separated (default: 0)')
+    parser.add_argument('modelpath', type=str, help='Path to model folder')
+    args = parser.parse_args()
+    return args
+
+def main():
+    args = get_args()
+    # Pin the visible GPUs, then shell out to the module CLI with a matching GPU count.
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
+    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+    gpus = args.gpus.split(',')
+    cnt = len(gpus)
+    cmdline = f'/share/vllm-0.8.5/bin/python -m llmengine.chatllm --model {args.modelpath} --gpus {cnt}'
+    print(args, cmdline)
+    os.system(cmdline)
+
+if __name__ == '__main__':
+    main()
+
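A possible way to exercise the new CLI, assuming the interpreter path from the diff above and a placeholder model path:

    # via the wrapper, pinning GPUs 0 and 1 (it derives --gpus 2 from the id list)
    ./test/chatllm -g 0,1 /path/to/model

    # or directly, with token-by-token streaming (the wrapper does not forward --stream)
    /share/vllm-0.8.5/bin/python -m llmengine.chatllm --model /path/to/model --gpus 2 --stream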