yumoqing 2025-06-14 08:23:37 +00:00
parent f3ce845388
commit 50a14e7c5c
6 changed files with 62 additions and 28 deletions

View File

@@ -6,6 +6,15 @@ from time import time
from transformers import TextIteratorStreamer
from appPublic.log import debug
from appPublic.worker import awaitify
from appPublic.uniqueID import getID
model_pathMap = {
}
def llm_register(model_path, Klass):
model_pathMap[model_path] = Klass
def get_llm_class(model_path):
return model_pathMap.get(model_path)
class BaseChatLLM:
def get_session_key(self):
@@ -31,6 +40,7 @@ class BaseChatLLM:
all_txt = ''
t1 = time()
i = 0
id = f'chatllm-{getID()}'
for txt in streamer:
if txt == '':
continue
@@ -39,22 +49,44 @@ class BaseChatLLM:
i += 1
all_txt += txt
yield {
'done': False,
'text': txt
"id":id,
"object":"chat.completion.chunk",
"created":time.time(),
"model":self.model_id,
"choices":[
{
"index":0,
"delta":{
"content":txt
},
"logprobs":null,
"finish_reason":null
}
]
}
t3 = time()
t = all_txt
unk = self.tokenizer(t, return_tensors="pt")
output_tokens = len(unk["input_ids"][0])
d = {
'done': True,
'text': all_txt,
'response_time': t2 - t1,
'finish_time': t3 - t1,
'output_token': output_tokens
yield {
"id":id,
"object":"chat.completion.chunk",
"created":time.time(),
"model":self.model_id,
"response_time": t2 - t1,
"finish_time": t3 - t1,
"output_token": output_tokens,
"choices":[
{
"index":0,
"delta":{
"content":""
},
"logprobs":null,
"finish_reason":"stop"
}
]
}
# debug(f'{all_txt=}, {d=}')
yield d
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
messages = self._get_session_messages(session)
@@ -63,7 +95,7 @@ class BaseChatLLM:
messages.append(self._build_user_message(prompt, image_path=image_path))
# debug(f'{messages=}')
for d in self._gen(messages):
if d['done']:
if d['choices'][0]['finish_reason'] == 'stop':
messages.append(self._build_assistant_message(d['text']))
yield d
self._set_session_messages(session, messages)
@@ -79,7 +111,7 @@ class BaseChatLLM:
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
if d['done']:
if d['choices'][0]['finish_reason'] == 'stop':
return d
def stream_generate(self, session, prompt,
image_path=None,
@@ -87,7 +119,7 @@ class BaseChatLLM:
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data {json.dumps(d)}\n'
s = f'data: {json.dumps(d)}\n\n'
yield s
async def async_generate(self, session, prompt,
@@ -97,7 +129,7 @@ class BaseChatLLM:
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
if d['done']:
if d['choices'][0]['finish_reason'] == 'stop':
return d
async def async_stream_generate(self, session, prompt,
@@ -106,8 +138,9 @@ class BaseChatLLM:
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data {json.dumps(d)}\n'
s = f'data: {json.dumps(d)}\n\n'
yield s
yield 'data: [DONE]\n\n'
def build_kwargs(self, inputs, streamer):
generate_kwargs = dict(
@@ -135,9 +168,8 @@ class BaseChatLLM:
kwargs=kwargs)
thread.start()
for d in self.output_generator(streamer):
if d['done']:
if d['choices'][0]['finish_reason'] == 'stop':
d['input_tokens'] = input_len
# debug(f'{d=}\n')
yield d
class T2TChatLLM(BaseChatLLM):
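
For orientation, here is a minimal usage sketch (not part of this commit) of how the new llm_register/get_llm_class registry and the OpenAI-style chunks above are meant to fit together; the model path is the Qwen path registered further down in this commit, and the empty session dict is an assumption:

from llmengine.base_chat_llm import get_llm_class
import llmengine.qwen3  # importing the module runs llm_register(...) at module level

Klass = get_llm_class("/share/models/Qwen/Qwen3-32B")
llm = Klass("/share/models/Qwen/Qwen3-32B")
session = {}
for line in llm.stream_generate(session, "Hello"):
    # each line is 'data: {"id": ..., "object": "chat.completion.chunk", ...}'
    print(line, end="")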

View File

@@ -9,7 +9,7 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM
from llmengine.base_chat_llm import MMChatLLM, llm_register
class Gemma3LLM(MMChatLLM):
def __init__(self, model_id):
@@ -21,6 +21,8 @@ class Gemma3LLM(MMChatLLM):
self.messages = []
self.model_id = model_id
llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM)
if __name__ == '__main__':
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
session = {}

View File

@@ -4,7 +4,7 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM
from llmengine.base_chat_llm import MMChatLLM, llm_register
model_id = "google/medgemma-4b-it"
@@ -29,6 +29,8 @@ class MedgemmaLLM(MMChatLLM):
).to(self.model.device, dtype=torch.bfloat16)
return inputs
llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM)
if __name__ == '__main__':
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
session = {}

View File

@@ -7,7 +7,7 @@ from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
class Qwen3LLM(T2TChatLLM):
def __init__(self, model_id):
@@ -39,6 +39,8 @@ class Qwen3LLM(T2TChatLLM):
)
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM)
if __name__ == '__main__':
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
session = {}

View File

@@ -2,9 +2,11 @@ from traceback import format_exc
import os
import sys
import argparse
from llmengine.base_chat_llm import get_llm_class
from llmengine.gemma3_it import Gemma3LLM
from llmengine.qwen3 import Qwen3LLM
from llmengine.medgemma3_it import MedgemmaLLM
from llmengine.devstral import DevstralLLM
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug
@@ -13,13 +15,6 @@ from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver
from aiohttp_session import get_session
model_pathMap = {
"/share/models/Qwen/Qwen3-32B": Qwen3LLM,
"/share/models/google/gemma-3-4b-it": Gemma3LLM,
"/share/models/google/medgemma-4b-it": MedgemmaLLM
}
def register(model_path, Klass):
model_pathMap[model_path] = Klass
def init():
rf = RegisterFunction()
@@ -51,7 +46,7 @@ def main():
parser.add_argument('-p', '--port')
parser.add_argument('model_path')
args = parser.parse_args()
Klass = model_pathMap.get(args.model_path)
Klass = get_llm_class(args.model_path)
if Klass is None:
e = Exception(f'{args.model_path} is not mapped to a model class')
exception(f'{e}, {format_exc()}')
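
A client-side sketch (not part of this commit) of how a caller could consume the stream emitted by stream_generate/async_stream_generate, accumulating delta content until the final chunk or the terminating 'data: [DONE]' line; collect_stream is a hypothetical helper name:

import json

def collect_stream(lines):
    # hypothetical helper: 'lines' is any iterable of the SSE strings yielded above
    text = ''
    for line in lines:
        line = line.strip()
        if not line.startswith('data: '):
            continue
        payload = line[len('data: '):]
        if payload == '[DONE]':
            break
        chunk = json.loads(payload)
        choice = chunk['choices'][0]
        text += choice['delta'].get('content', '')
        if choice['finish_reason'] == 'stop':
            break
    return text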

View File

@ -9,6 +9,7 @@ license = {text = "MIT"}
dependencies = [
"torch",
"transformers",
"mistral-common",
"accelerate"
]