This commit is contained in:
ymq1 2025-06-26 11:05:10 +08:00
parent 6430e59081
commit 4bd808415e
2 changed files with 3 additions and 28 deletions

View File

@ -96,6 +96,7 @@ class BaseChatLLM:
return generate_kwargs return generate_kwargs
def _messages2inputs(self, messages): def _messages2inputs(self, messages):
debug(f'{messages=}')
return self.processor.apply_chat_template( return self.processor.apply_chat_template(
messages, add_generation_prompt=True, messages, add_generation_prompt=True,
tokenize=True, tokenize=True,

View File

@ -4,10 +4,10 @@
from appPublic.worker import awaitify from appPublic.worker import awaitify
from appPublic.log import debug from appPublic.log import debug
from ahserver.serverenv import get_serverenv from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image from PIL import Image
import torch import torch
from llmengine.base_chat_llm import BaseChatLLM, llm_register from llmengine.base_chat_llm import BaseChatLLM, llm_register
from transformers import AutoModelForCausalLM, AutoTokenizer
class Qwen3LLM(BaseChatLLM): class Qwen3LLM(BaseChatLLM):
def __init__(self, model_id): def __init__(self, model_id):
@ -17,9 +17,6 @@ class Qwen3LLM(BaseChatLLM):
torch_dtype="auto", torch_dtype="auto",
device_map="auto" device_map="auto"
) )
if torch.backends.mps.is_available():
device = torch.device("mps")
self.model = self.model.to(device)
self.model_id = model_id self.model_id = model_id
def build_kwargs(self, inputs, streamer): def build_kwargs(self, inputs, streamer):
@ -33,7 +30,7 @@ class Qwen3LLM(BaseChatLLM):
return generate_kwargs return generate_kwargs
def _messages2inputs(self, messages): def _messages2inputs(self, messages):
debug(f'{messages=}') debug(f'-----------{messages=}-----------')
text = self.tokenizer.apply_chat_template( text = self.tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=False,
@ -43,26 +40,3 @@ class Qwen3LLM(BaseChatLLM):
return self.tokenizer([text], return_tensors="pt").to(self.model.device) return self.tokenizer([text], return_tensors="pt").to(self.model.device)
llm_register("Qwen/Qwen3", Qwen3LLM) llm_register("Qwen/Qwen3", Qwen3LLM)
if __name__ == '__main__':
import sys
model_path = sys.argv[1]
q3 = Qwen3LLM(model_path)
session = {}
while True:
print('input prompt')
p = input()
if p:
if p == 'q':
break;
for d in q3.stream_generate(session, p):
print(d)
"""
if not d['done']:
print(d['text'], end='', flush=True)
else:
x = {k:v for k,v in d.items() if k != 'text'}
print(f'\n{x}\n')
"""