This commit is contained in:
yumoqing 2025-06-14 08:23:37 +00:00
parent f3ce845388
commit 50a14e7c5c
6 changed files with 62 additions and 28 deletions

View File

@ -6,6 +6,15 @@ from time import time
from transformers import TextIteratorStreamer
from appPublic.log import debug
from appPublic.worker import awaitify
from appPublic.uniqueID import getID
model_pathMap = {
}
def llm_register(model_path, Klass):
model_pathMap[model_path] = Klass
def get_llm_class(model_path):
return model_pathMap.get(model_path)
class BaseChatLLM:
def get_session_key(self):
@ -31,6 +40,7 @@ class BaseChatLLM:
all_txt = ''
t1 = time()
i = 0
id = f'chatllm-{getID()}'
for txt in streamer:
if txt == '':
continue
@ -39,22 +49,44 @@ class BaseChatLLM:
i += 1
all_txt += txt
yield {
"id": id,
"object": "chat.completion.chunk",
"created": time(),
"model": self.model_id,
"choices": [
{
"index": 0,
"delta": {
"content": txt
},
"logprobs": None,
"finish_reason": None
}
]
}
t3 = time()
t = all_txt
unk = self.tokenizer(t, return_tensors="pt")
output_tokens = len(unk["input_ids"][0])
yield {
"id": id,
"object": "chat.completion.chunk",
"created": time(),
"model": self.model_id,
"text": all_txt,
"response_time": t2 - t1,
"finish_time": t3 - t1,
"output_token": output_tokens,
"choices": [
{
"index": 0,
"delta": {
"content": ""
},
"logprobs": None,
"finish_reason": "stop"
}
]
}
# debug(f'{all_txt=}, {d=}')
yield d
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
messages = self._get_session_messages(session)
@ -63,7 +95,7 @@ class BaseChatLLM:
messages.append(self._build_user_message(prompt, image_path=image_path))
# debug(f'{messages=}')
for d in self._gen(messages):
if d['choices'][0]['finish_reason'] == 'stop':
messages.append(self._build_assistant_message(d['text']))
yield d
self._set_session_messages(session, messages)
@ -79,7 +111,7 @@ class BaseChatLLM:
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
if d['choices'][0]['finish_reason'] == 'stop':
return d
def stream_generate(self, session, prompt,
image_path=None,
@ -87,7 +119,7 @@ class BaseChatLLM:
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
async def async_generate(self, session, prompt,
@ -97,7 +129,7 @@ class BaseChatLLM:
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
if d['choices'][0]['finish_reason'] == 'stop':
return d
async def async_stream_generate(self, session, prompt,
@ -106,8 +138,9 @@ class BaseChatLLM:
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
yield 'data: [done]'
def build_kwargs(self, inputs, streamer):
generate_kwargs = dict(
@ -135,9 +168,8 @@ class BaseChatLLM:
kwargs=kwargs)
thread.start()
for d in self.output_generator(streamer):
if d['choices'][0]['finish_reason'] == 'stop':
d['input_tokens'] = input_len
# debug(f'{d=}\n')
yield d
class T2TChatLLM(BaseChatLLM):
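
The streamed payloads now follow the OpenAI chat.completion.chunk shape, emitted as "data: {...}" lines and closed by the "data: [done]" sentinel from async_stream_generate. A minimal consumer sketch under those assumptions; the collect_reply helper and its sse_lines argument are hypothetical and not part of this commit:

import json

def collect_reply(sse_lines):
    # Accumulate the assistant text from "data: {...}" chunk lines.
    reply = ''
    for line in sse_lines:
        payload = line.strip()
        if not payload.startswith('data:'):
            continue
        payload = payload[len('data:'):].strip()
        if payload == '[done]':  # end-of-stream marker
            break
        chunk = json.loads(payload)
        choice = chunk['choices'][0]
        reply += choice['delta'].get('content', '')
        if choice['finish_reason'] == 'stop':  # final chunk also carries timing stats
            break
    return reply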

View File

@ -9,7 +9,7 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
class Gemma3LLM(MMChatLLM):
def __init__(self, model_id):
@ -21,6 +21,8 @@ class Gemma3LLM(MMChatLLM):
self.messages = []
self.model_id = model_id
llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM)
if __name__ == '__main__':
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
session = {}

View File

@ -4,7 +4,7 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
model_id = "google/medgemma-4b-it"
@ -29,6 +29,8 @@ class MedgemmaLLM(MMChatLLM):
).to(self.model.device, dtype=torch.bfloat16)
return inputs
llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM)
if __name__ == '__main__':
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
session = {}

View File

@ -7,7 +7,7 @@ from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
class Qwen3LLM(T2TChatLLM):
def __init__(self, model_id):
@ -39,6 +39,8 @@ class Qwen3LLM(T2TChatLLM):
)
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM)
if __name__ == '__main__':
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
session = {}
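
Each backend module now registers its model path at import time instead of relying on the hard-coded map that previously lived in the server module. A minimal sketch of how an additional backend would plug in; the MyLLM class and its model path are hypothetical examples, not part of this commit:

from llmengine.base_chat_llm import T2TChatLLM, llm_register

class MyLLM(T2TChatLLM):
    # hypothetical backend; a real subclass would load its tokenizer/model here
    def __init__(self, model_id):
        self.model_id = model_id

# importing this module is enough to make the path resolvable via get_llm_class()
llm_register("/share/models/example/my-llm", MyLLM)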

View File

@ -2,9 +2,11 @@ from traceback import format_exc
import os
import sys
import argparse
from llmengine.base_chat_llm import get_llm_class
from llmengine.gemma3_it import Gemma3LLM
from llmengine.qwen3 import Qwen3LLM
from llmengine.medgemma3_it import MedgemmaLLM
from llmengine.devstral import DevstralLLM
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug
@ -13,13 +15,6 @@ from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver
from aiohttp_session import get_session
model_pathMap = {
"/share/models/Qwen/Qwen3-32B": Qwen3LLM,
"/share/models/google/gemma-3-4b-it": Gemma3LLM,
"/share/models/google/medgemma-4b-it": MedgemmaLLM
}
def register(model_path, Klass):
model_pathMap[model_path] = Klass
def init():
rf = RegisterFunction()
@ -51,7 +46,7 @@ def main():
parser.add_argument('-p', '--port')
parser.add_argument('model_path')
args = parser.parse_args()
Klass = get_llm_class(args.model_path)
if Klass is None:
e = Exception(f'{model_path} has not mapping to a model class')
exception(f'{e}, {format_exc()}')
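
With registration moved into the backend modules, the server only has to resolve the class for the path it was started with. A sketch of that lookup, mirroring the change to main() above; error handling is simplified and the path is just one of the registered examples:

from llmengine.base_chat_llm import get_llm_class
from llmengine.qwen3 import Qwen3LLM  # importing the module performs the registration

model_path = "/share/models/Qwen/Qwen3-32B"
Klass = get_llm_class(model_path)
if Klass is None:
    raise ValueError(f'{model_path} is not mapped to a registered model class')
llm = Klass(model_path)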

View File

@ -9,6 +9,7 @@ license = {text = "MIT"}
dependencies = [
"torch",
"transformers",
"mistral-common",
"accelerate" "accelerate"
] ]