This commit is contained in:
yumoqing 2025-05-30 03:42:59 +00:00
parent 6bff0bd0ec
commit 089bc1af1f
3 changed files with 63 additions and 2 deletions

View File

@ -62,3 +62,39 @@ class TransformersChatEngine:
for new_text in streamer:
yield new_text
if __name__ == '__main__':
import os
import sys
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Transformers Chat CLI")
parser.add_argument("--model", type=str, required=True, help="模型路径或 Hugging Face 名称")
parser.add_argument("--gpus", type=int, default=1, help="使用 GPU 数量")
parser.add_argument("--stream", action="store_true", help="是否流式输出")
return parser.parse_args()
def generate(engine, stream):
while True:
print('prompt("q" to exit):')
p = input()
if p == 'q':
break
if not p:
continue
if stream:
for token in engine.stream_generate(p):
print(token, end="", flush=True)
else:
print(engine.generate(p))
def main():
args = parse_args()
print(f'{args=}')
engine = TransformersChatEngine(
model_name=args.model,
gpus=args.gpus
)
generate(engine, args.stream)
main()

View File

@ -8,8 +8,8 @@ requires-python = ">=3.8"
license = {text = "MIT"}
dependencies = [
"torch",
"tramsformers",
"acelerate"
"transformers",
"accelerate"
]
[project.optional-dependencies]

25
test/chatllm Executable file
View File

@ -0,0 +1,25 @@
#!/share/vllm-0.8.5/bin/python
import os
import sys
import argparse
def get_args():
parser = argparse.ArgumentParser(description="Example script using argparse")
parser.add_argument('--gpus', '-g', type=str, required=False, default='0', help='Identify GPU id, default is 0, comma split')
parser.add_argument('modelpath', type=str, help='Path to model folder')
args = parser.parse_args()
return args
def main():
args = get_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
gpus = args.gpus.split(',')
cnt=len(gpus)
cmdline = f'/share/vllm-0.8.5/bin/python -m llmengine.chatllm --model {args.modelpath} --gpus {cnt}'
print(args, cmdline)
os.system(cmdline)
if __name__ == '__main__':
main()