This commit is contained in:
ymq1 2025-06-25 15:47:41 +08:00
parent 9bffe4b983
commit 6430e59081
5 changed files with 83 additions and 151 deletions

View File

@ -3,6 +3,7 @@ import asyncio
import json
import torch
from time import time
from aiostream import stream
from transformers import TextIteratorStreamer
from appPublic.log import debug
from appPublic.worker import awaitify
@ -26,18 +27,6 @@ class BaseChatLLM:
device = torch.device("mps")
self.model = self.model.to(device)
def get_session_key(self):
return self.model_id + ':messages'
def _get_session_messages(self, session):
key = self.get_session_key()
messages = session.get(key) or []
return messages
def _set_session_messages(self, session, messages):
key = self.get_session_key()
session[key] = messages
def get_streamer(self):
return TextIteratorStreamer(
tokenizer=self.tokenizer,
@ -60,15 +49,15 @@ class BaseChatLLM:
yield {
"id":id,
"object":"chat.completion.chunk",
"created":time(),
"created": t1,
"model":self.model_id,
"choices":[
{
"index":0,
"delta":{
"role": "assistant",
"content":txt
},
"logprobs":None,
"finish_reason":None
}
]
@ -80,7 +69,7 @@ class BaseChatLLM:
yield {
"id":id,
"object":"chat.completion.chunk",
"created":time(),
"created": t1,
"model":self.model_id,
"response_time": t2 - t1,
"finish_time": t3 - t1,
@ -91,69 +80,11 @@ class BaseChatLLM:
"delta":{
"content":""
},
"logprobs":None,
"finish_reason":"stop"
}
]
}
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
messages = self._get_session_messages(session)
if sys_prompt:
messages.append(self._build_sys_message(sys_prompt))
messages.append(self._build_user_message(prompt, image_path=image_path))
# debug(f'{messages=}')
all_txt = ''
for d in self._gen(messages):
if d['choices'][0]['finish_reason'] == 'stop':
messages.append(self._build_assistant_message(all_txt))
else:
all_txt += d['choices'][0]['delta']['content']
yield d
self._set_session_messages(session, messages)
async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
yield d
def generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
if d['choices'][0]['finish_reason'] == 'stop':
return d
def stream_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
async def async_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
if d['choices'][0]['finish_reason'] == 'stop':
return d
async def async_stream_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
yield 'data: [DONE]'
def build_kwargs(self, inputs, streamer):
generate_kwargs = dict(
**inputs,
@ -184,63 +115,70 @@ class BaseChatLLM:
d['input_tokens'] = input_len
yield d
class T2TChatLLM(BaseChatLLM):
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":prompt
}
async def async_gen(self, messages):
async for d in stream.iterate(self._gen(messages)):
yield d
def _build_sys_message(self, prompt):
return {
"role":"system",
"content": prompt
}
async def chat_completion_stream(self, messages):
async for d in self.async_gen(messages):
if d['choices'][0]['finish_reason']:
d['usage'] = {
'prompt_tokens': d['input_tokens'],
'completion_tokens': d['output_tokens'],
'total_tokens': d['input_tokens'] + d['output_tokens']
}
s = f'data: {json.dumps(d)}\n'
yield s
yield 'data: [DONE]\n'
def _build_user_message(self, prompt, **kw):
return {
"role":"user",
"content": prompt
}
def reference(self, messages):
t1 = time()
inputs = self._messages2inputs(messages)
input_len = inputs["input_ids"].shape[-1]
streamer = self.get_streamer()
kwargs = self.build_kwargs(inputs, streamer)
thread = threading.Thread(target=self.model.generate,
kwargs=kwargs)
thread.start()
txt = ''
i = 0
for d in self.output_generator(streamer):
if i == 0:
i = 1
t1 = time()
if d['choices'][0]['finish_reason'] != 'stop':
txt += d['choices'][0]['delta']['content']
else:
i_tokens = d['input_tokens']
o_tokens = d['output_tokens']
class MMChatLLM(BaseChatLLM):
""" multiple modal chat LLM """
def _build_assistant_message(self, prompt):
t2 = time()
return {
"role":"assistant",
"content":[{"type": "text", "text": prompt}]
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content":[{"type": "text", "text": prompt}]
}
def _build_user_message(self, prompt, image_path=None,
video_path=None, audio_path=None):
contents = [
{
"type":"text", "text": prompt
'id': f'chatcmpl-{getID()}',
"object":"chat.completion",
"created":t1,
"model":self.model_id,
"response_time": t2 - t1,
"finish_time": t3 - t1,
"output_token": output_tokens,
"choices":[
{
"index":0,
"message":{
"role": "assistant",
"content": txt
},
"finish_reason":"stop"
}
],
"usage": {
"prompt_tokens": i_tokens,
"completion_tokens": o_tokens,
"total_tokens": i_tokens + o_tokens
}
]
if image_path:
contents.append({
"type": "image",
"image": image_path
})
if video_path:
contents.append({
"type": "video",
"video":video_path
})
if audio_path:
contents.append({
"tyoe": "audio",
"audio": audio_path
})
return {
"role": "user",
"content": contents
}
async def chat_completion(self, messages):
f = awaitify(self.reference)
return await f(messages)

View File

@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
class Gemma3LLM(MMChatLLM):
class Gemma3LLM(BaseChatLLM):
def __init__(self, model_id):
self.model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"

View File

@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
model_id = "google/medgemma-4b-it"
class MedgemmaLLM(MMChatLLM):
class MedgemmaLLM(BaseChatLLM):
def __init__(self, model_id):
self.model = AutoModelForImageTextToText.from_pretrained(
model_id,

View File

@ -7,9 +7,9 @@ from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
class Qwen3LLM(T2TChatLLM):
class Qwen3LLM(BaseChatLLM):
def __init__(self, model_id):
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(

View File

@ -21,23 +21,17 @@ def init():
rf.register('chat_completions', chat_completions)
async def chat_completions(request, params_kw, *params, **kw):
se = ServerEnv()
engine = se.engine
async def gor():
se = ServerEnv()
engine = se.engine
session = await get_session(request)
kwargs = {
}
if params_kw.image_path:
kwargs['image_path'] = fs.reapPath(params_kw.image_path)
if params_kw.video_path:
kwargs['video_path'] = fs.reapPath(params_kw.video_path)
if params_kw.audio_path:
kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
async for d in engine.chat_completion_stream(params_kw.messages):
debug(f'{d=}')
yield d
return await stream_response(request, gor)
if params_kw.stream:
return await stream_response(request, gor)
else:
return await engine.chat_completion(params_kw.messages)
def main():
parser = argparse.ArgumentParser(prog="Sage")