ymq1 2025-06-25 15:47:41 +08:00
parent 9bffe4b983
commit 6430e59081
5 changed files with 83 additions and 151 deletions

View File

@@ -3,6 +3,7 @@ import asyncio
import json
import torch
from time import time
from aiostream import stream
from transformers import TextIteratorStreamer
from appPublic.log import debug
from appPublic.worker import awaitify
@@ -26,18 +27,6 @@ class BaseChatLLM:
device = torch.device("mps")
self.model = self.model.to(device)
def get_session_key(self):
return self.model_id + ':messages'
def _get_session_messages(self, session):
key = self.get_session_key()
messages = session.get(key) or []
return messages
def _set_session_messages(self, session, messages):
key = self.get_session_key()
session[key] = messages
def get_streamer(self):
return TextIteratorStreamer(
tokenizer=self.tokenizer,
@@ -60,15 +49,15 @@ class BaseChatLLM:
yield {
"id":id,
"object":"chat.completion.chunk",
"created":time(), "created": t1,
"model":self.model_id, "model":self.model_id,
"choices":[ "choices":[
{ {
"index":0, "index":0,
"delta":{ "delta":{
"role": "assistant",
"content":txt "content":txt
}, },
"logprobs":None,
"finish_reason":None "finish_reason":None
} }
] ]
@@ -80,7 +69,7 @@ class BaseChatLLM:
yield {
"id":id,
"object":"chat.completion.chunk",
"created":time(), "created": t1,
"model":self.model_id, "model":self.model_id,
"response_time": t2 - t1, "response_time": t2 - t1,
"finish_time": t3 - t1, "finish_time": t3 - t1,
@@ -91,69 +80,11 @@ class BaseChatLLM:
"delta":{
"content":""
},
"logprobs":None,
"finish_reason":"stop" "finish_reason":"stop"
} }
] ]
} }
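For orientation, the chunks yielded above follow the OpenAI chat.completion.chunk shape, assuming t1, t2 and t3 are the request-start, first-token and finish timestamps, so response_time is first-token latency and finish_time is total generation time. A hedged sketch of what the final chunk could look like once serialized; the id and model values are placeholders, not output from this code:

example_final_chunk = {
    "id": "chatcmpl-abc123",           # hypothetical id
    "object": "chat.completion.chunk",
    "created": 1719300000.0,           # t1
    "model": "some/model-id",          # placeholder model_id
    "response_time": 0.42,             # t2 - t1
    "finish_time": 3.17,               # t3 - t1
    "choices": [
        {"index": 0, "delta": {"content": ""}, "finish_reason": "stop"}
    ],
}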
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
messages = self._get_session_messages(session)
if sys_prompt:
messages.append(self._build_sys_message(sys_prompt))
messages.append(self._build_user_message(prompt, image_path=image_path))
# debug(f'{messages=}')
all_txt = ''
for d in self._gen(messages):
if d['choices'][0]['finish_reason'] == 'stop':
messages.append(self._build_assistant_message(all_txt))
else:
all_txt += d['choices'][0]['delta']['content']
yield d
self._set_session_messages(session, messages)
async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
yield d
def generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
if d['choices'][0]['finish_reason'] == 'stop':
return d
def stream_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
async def async_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
await asyncio.sleep(0)
if d['choices'][0]['finish_reason'] == 'stop':
return d
async def async_stream_generate(self, session, prompt,
image_path=None,
video_path=None,
audio_path=None,
sys_prompt=None):
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
s = f'data: {json.dumps(d)}\n'
yield s
yield 'data: [DONE]'
def build_kwargs(self, inputs, streamer):
generate_kwargs = dict(
**inputs,
@@ -184,63 +115,70 @@ class BaseChatLLM:
d['input_tokens'] = input_len
yield d
class T2TChatLLM(BaseChatLLM):
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":prompt
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content": prompt
}
def _build_user_message(self, prompt, **kw):
return {
"role":"user",
"content": prompt
}
class MMChatLLM(BaseChatLLM):
""" multiple modal chat LLM """
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":[{"type": "text", "text": prompt}]
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content":[{"type": "text", "text": prompt}]
}
def _build_user_message(self, prompt, image_path=None,
video_path=None, audio_path=None):
contents = [
{
"type":"text", "text": prompt
}
]
if image_path:
contents.append({
"type": "image",
"image": image_path
})
if video_path:
contents.append({
"type": "video",
"video":video_path
})
if audio_path:
contents.append({
"tyoe": "audio",
"audio": audio_path
})
return {
"role": "user",
"content": contents
}
async def async_gen(self, messages):
async for d in stream.iterate(self._gen(messages)):
yield d
async def chat_completion_stream(self, messages):
async for d in self.async_gen(messages):
if d['choices'][0]['finish_reason']:
d['usage'] = {
'prompt_tokens': d['input_tokens'],
'completion_tokens': d['output_tokens'],
'total_tokens': d['input_tokens'] + d['output_tokens']
}
s = f'data: {json.dumps(d)}\n'
yield s
yield 'data: [DONE]\n'
def reference(self, messages):
t1 = time()
inputs = self._messages2inputs(messages)
input_len = inputs["input_ids"].shape[-1]
streamer = self.get_streamer()
kwargs = self.build_kwargs(inputs, streamer)
thread = threading.Thread(target=self.model.generate,
kwargs=kwargs)
thread.start()
txt = ''
i = 0
for d in self.output_generator(streamer):
if i == 0:
i = 1
t1 = time()
if d['choices'][0]['finish_reason'] != 'stop':
txt += d['choices'][0]['delta']['content']
else:
i_tokens = d['input_tokens']
o_tokens = d['output_tokens']
t2 = time()
return {
'id': f'chatcmpl-{getID()}',
"object":"chat.completion",
"created":t1,
"model":self.model_id,
"response_time": t2 - t1,
"finish_time": t3 - t1,
"output_token": output_tokens,
"choices":[
{
"index":0,
"message":{
"role": "assistant",
"content": txt
},
"finish_reason":"stop"
}
],
"usage": {
"prompt_tokens": i_tokens,
"completion_tokens": o_tokens,
"total_tokens": i_tokens + o_tokens
}
}
async def chat_completion(self, messages):
f = awaitify(self.reference)
return await f(messages)
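The public surface of BaseChatLLM is now the coroutine pair chat_completion and chat_completion_stream, both taking an OpenAI-style message list. A minimal usage sketch, assuming one of the model wrappers below; the module path and model id are placeholders:

import asyncio
from llmengine.qwen3 import Qwen3LLM  # assumed module path

async def demo():
    engine = Qwen3LLM("Qwen/Qwen3-8B")  # placeholder model id
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ]
    # Non-streaming path: a full chat.completion dict via awaitify(self.reference).
    result = await engine.chat_completion(messages)
    print(result["choices"][0]["message"]["content"])
    # Streaming path: 'data: ...' strings, ending with 'data: [DONE]'.
    async for line in engine.chat_completion_stream(messages):
        print(line, end="")

asyncio.run(demo())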

View File

@@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
class Gemma3LLM(MMChatLLM):
class Gemma3LLM(BaseChatLLM):
def __init__(self, model_id):
self.model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"

View File

@@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
model_id = "google/medgemma-4b-it"
class MedgemmaLLM(MMChatLLM):
class MedgemmaLLM(BaseChatLLM):
def __init__(self, model_id):
self.model = AutoModelForImageTextToText.from_pretrained(
model_id,

View File

@@ -7,9 +7,9 @@ from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register
class Qwen3LLM(T2TChatLLM):
class Qwen3LLM(BaseChatLLM):
def __init__(self, model_id):
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(
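With the T2TChatLLM and MMChatLLM message builders gone, every wrapper inherits BaseChatLLM directly and callers are expected to supply OpenAI-style messages themselves. A hedged sketch of the two content shapes the deleted builders used to produce, for text-only and multimodal models respectively; the image path is a placeholder:

text_messages = [
    {"role": "user", "content": "Summarize this commit."},
]
multimodal_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image", "image": "/path/to/example.jpg"},  # placeholder path
        ],
    },
]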

View File

@@ -21,23 +21,17 @@ def init():
rf.register('chat_completions', chat_completions)
async def chat_completions(request, params_kw, *params, **kw):
se = ServerEnv()
engine = se.engine
async def gor():
se = ServerEnv()
engine = se.engine
session = await get_session(request)
kwargs = {
}
if params_kw.image_path:
kwargs['image_path'] = fs.reapPath(params_kw.image_path)
if params_kw.video_path:
kwargs['video_path'] = fs.reapPath(params_kw.video_path)
if params_kw.audio_path:
kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
async for d in engine.chat_completion_stream(params_kw.messages):
debug(f'{d=}')
yield d
return await stream_response(request, gor)
if params_kw.stream:
return await stream_response(request, gor)
else:
return await engine.chat_completion(params_kw.messages)
def main():
parser = argparse.ArgumentParser(prog="Sage")
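The handler now branches on params_kw.stream, returning an SSE stream via stream_response or a single chat.completion object. A rough client sketch, assuming the route is exposed as an OpenAI-style chat-completions endpoint; the URL, port and payload handling are assumptions about ahserver's routing, not taken from this commit:

import json
import requests

URL = "http://localhost:8080/v1/chat/completions"  # placeholder endpoint
payload = {
    "stream": True,
    "messages": [{"role": "user", "content": "Hello"}],
}
with requests.post(URL, json=payload, stream=True) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or raw == "data: [DONE]":
            continue
        chunk = json.loads(raw.removeprefix("data: "))
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)

Note that chat_completion_stream terminates each data: line with a single newline; a strict SSE client that expects a blank line between events may need to account for that.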