bugfix
This commit is contained in:
parent
6bff0bd0ec
commit
089bc1af1f
@ -62,3 +62,39 @@ class TransformersChatEngine:
|
|||||||
for new_text in streamer:
|
for new_text in streamer:
|
||||||
yield new_text
|
yield new_text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Transformers Chat CLI")
|
||||||
|
parser.add_argument("--model", type=str, required=True, help="模型路径或 Hugging Face 名称")
|
||||||
|
parser.add_argument("--gpus", type=int, default=1, help="使用 GPU 数量")
|
||||||
|
parser.add_argument("--stream", action="store_true", help="是否流式输出")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def generate(engine, stream):
|
||||||
|
while True:
|
||||||
|
print('prompt("q" to exit):')
|
||||||
|
p = input()
|
||||||
|
if p == 'q':
|
||||||
|
break
|
||||||
|
if not p:
|
||||||
|
continue
|
||||||
|
if stream:
|
||||||
|
for token in engine.stream_generate(p):
|
||||||
|
print(token, end="", flush=True)
|
||||||
|
else:
|
||||||
|
print(engine.generate(p))
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
print(f'{args=}')
|
||||||
|
engine = TransformersChatEngine(
|
||||||
|
model_name=args.model,
|
||||||
|
gpus=args.gpus
|
||||||
|
)
|
||||||
|
generate(engine, args.stream)
|
||||||
|
|
||||||
|
main()
|
||||||
|
@ -8,8 +8,8 @@ requires-python = ">=3.8"
|
|||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch",
|
"torch",
|
||||||
"tramsformers",
|
"transformers",
|
||||||
"acelerate"
|
"accelerate"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
25
test/chatllm
Executable file
25
test/chatllm
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
#!/share/vllm-0.8.5/bin/python
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Example script using argparse")
|
||||||
|
parser.add_argument('--gpus', '-g', type=str, required=False, default='0', help='Identify GPU id, default is 0, comma split')
|
||||||
|
parser.add_argument('modelpath', type=str, help='Path to model folder')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
|
||||||
|
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
||||||
|
gpus = args.gpus.split(',')
|
||||||
|
cnt=len(gpus)
|
||||||
|
cmdline = f'/share/vllm-0.8.5/bin/python -m llmengine.chatllm --model {args.modelpath} --gpus {cnt}'
|
||||||
|
print(args, cmdline)
|
||||||
|
os.system(cmdline)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user