bugfix
This commit is contained in:
parent
6bff0bd0ec
commit
089bc1af1f
@ -62,3 +62,39 @@ class TransformersChatEngine:
|
||||
for new_text in streamer:
|
||||
yield new_text
|
||||
|
||||
|
||||
if __name__ == '__main__':
    import argparse

    def parse_args():
        """Parse command-line arguments for the Transformers chat CLI demo."""
        parser = argparse.ArgumentParser(description="Transformers Chat CLI")
        parser.add_argument("--model", type=str, required=True, help="模型路径或 Hugging Face 名称")
        parser.add_argument("--gpus", type=int, default=1, help="使用 GPU 数量")
        parser.add_argument("--stream", action="store_true", help="是否流式输出")
        return parser.parse_args()

    def generate(engine, stream):
        """Interactive prompt loop: read a prompt, print the model's reply.

        Type 'q' (or send EOF / Ctrl-D) to exit; empty input is ignored.

        Args:
            engine: object providing generate(prompt) and stream_generate(prompt).
            stream: if true, print tokens incrementally as they arrive.
        """
        while True:
            print('prompt("q" to exit):')
            try:
                p = input()
            except EOFError:
                # Closed stdin (Ctrl-D, piped input exhausted): exit cleanly
                # instead of crashing with a traceback.
                break
            if p == 'q':
                break
            if not p:
                continue
            if stream:
                for token in engine.stream_generate(p):
                    print(token, end="", flush=True)
                # Tokens were printed with end=""; terminate the line so the
                # next prompt starts on a fresh line.
                print()
            else:
                print(engine.generate(p))

    def main():
        """Entry point: build the engine from CLI args and start the REPL."""
        args = parse_args()
        print(f'{args=}')
        engine = TransformersChatEngine(
            model_name=args.model,
            gpus=args.gpus
        )
        generate(engine, args.stream)

    main()
|
||||
|
@ -8,8 +8,8 @@ requires-python = ">=3.8"
|
||||
license = {text = "MIT"}
|
||||
dependencies = [
|
||||
"torch",
|
||||
"tramsformers",
|
||||
"acelerate"
|
||||
"transformers",
|
||||
"accelerate"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
25
test/chatllm
Executable file
25
test/chatllm
Executable file
@ -0,0 +1,25 @@
|
||||
#!/share/vllm-0.8.5/bin/python
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
def get_args():
    """Build the CLI parser and return the parsed options.

    Options: --gpus/-g (comma-separated GPU id string, default '0') and a
    positional modelpath pointing at the model folder.
    """
    ap = argparse.ArgumentParser(description="Example script using argparse")
    # Comma-separated device ids, e.g. "0,1"; defaults to the first GPU.
    ap.add_argument('--gpus', '-g', type=str, required=False, default='0', help='Identify GPU id, default is 0, comma split')
    ap.add_argument('modelpath', type=str, help='Path to model folder')
    return ap.parse_args()
|
||||
|
||||
def main():
    """Parse args, restrict visible GPUs, and launch the chat engine CLI.

    Side effects: mutates os.environ (inherited by the child process) and
    blocks in os.system until the launched chat process exits.
    """
    import shlex  # local import: only needed to shell-quote the model path

    args = get_args()
    # Child processes inherit these; limit CUDA to the requested devices.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    gpus = args.gpus.split(',')
    cnt = len(gpus)
    # Quote the user-supplied path so spaces or shell metacharacters cannot
    # break (or inject into) the command string passed to os.system.
    cmdline = f'/share/vllm-0.8.5/bin/python -m llmengine.chatllm --model {shlex.quote(args.modelpath)} --gpus {cnt}'
    print(args, cmdline)
    os.system(cmdline)
|
||||
|
||||
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
|
||||
|
Loading…
Reference in New Issue
Block a user