bugfix
This commit is contained in:
parent
6430e59081
commit
4bd808415e
@ -96,6 +96,7 @@ class BaseChatLLM:
|
||||
return generate_kwargs
|
||||
|
||||
def _messages2inputs(self, messages):
|
||||
debug(f'{messages=}')
|
||||
return self.processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
|
@ -4,10 +4,10 @@
|
||||
from appPublic.worker import awaitify
|
||||
from appPublic.log import debug
|
||||
from ahserver.serverenv import get_serverenv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from PIL import Image
|
||||
import torch
|
||||
from llmengine.base_chat_llm import BaseChatLLM, llm_register
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
class Qwen3LLM(BaseChatLLM):
|
||||
def __init__(self, model_id):
|
||||
@ -17,9 +17,6 @@ class Qwen3LLM(BaseChatLLM):
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
if torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
self.model = self.model.to(device)
|
||||
self.model_id = model_id
|
||||
|
||||
def build_kwargs(self, inputs, streamer):
|
||||
@ -33,7 +30,7 @@ class Qwen3LLM(BaseChatLLM):
|
||||
return generate_kwargs
|
||||
|
||||
def _messages2inputs(self, messages):
|
||||
debug(f'{messages=}')
|
||||
debug(f'-----------{messages=}-----------')
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
@ -43,26 +40,3 @@ class Qwen3LLM(BaseChatLLM):
|
||||
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
||||
|
||||
llm_register("Qwen/Qwen3", Qwen3LLM)
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
model_path = sys.argv[1]
|
||||
q3 = Qwen3LLM(model_path)
|
||||
session = {}
|
||||
while True:
|
||||
print('input prompt')
|
||||
p = input()
|
||||
if p:
|
||||
if p == 'q':
|
||||
break;
|
||||
for d in q3.stream_generate(session, p):
|
||||
print(d)
|
||||
"""
|
||||
if not d['done']:
|
||||
print(d['text'], end='', flush=True)
|
||||
else:
|
||||
x = {k:v for k,v in d.items() if k != 'text'}
|
||||
print(f'\n{x}\n')
|
||||
"""
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user