bugfix
This commit is contained in:
parent 9bffe4b983
commit 6430e59081
@@ -3,6 +3,7 @@ import asyncio
import json
import torch
from time import time
from aiostream import stream
from transformers import TextIteratorStreamer
from appPublic.log import debug
from appPublic.worker import awaitify
@@ -26,18 +27,6 @@ class BaseChatLLM:
            device = torch.device("mps")
            self.model = self.model.to(device)

    def get_session_key(self):
        return self.model_id + ':messages'

    def _get_session_messages(self, session):
        key = self.get_session_key()
        messages = session.get(key) or []
        return messages

    def _set_session_messages(self, session, messages):
        key = self.get_session_key()
        session[key] = messages

    def get_streamer(self):
        return TextIteratorStreamer(
            tokenizer=self.tokenizer,
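For context, get_streamer() wraps the model's tokenizer in a transformers TextIteratorStreamer. Below is a minimal, self-contained sketch of the usual pattern for driving such a streamer (generation in a worker thread, decoded text consumed on the caller side); the model id and generation settings are placeholders, not part of this commit.

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# placeholder model id, for illustration only
tok = AutoTokenizer.from_pretrained("some-org/some-model")
model = AutoModelForCausalLM.from_pretrained("some-org/some-model")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("hello", return_tensors="pt")

# model.generate runs in a background thread and feeds the streamer
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32),
)
thread.start()
for piece in streamer:  # yields decoded text fragments as they are generated
    print(piece, end="", flush=True)
thread.join()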
@@ -60,15 +49,15 @@ class BaseChatLLM:
            yield {
                "id":id,
                "object":"chat.completion.chunk",
                "created":time(),
                "created": t1,
                "model":self.model_id,
                "choices":[
                    {
                        "index":0,
                        "delta":{
                            "role": "assistant",
                            "content":txt
                        },
                        "logprobs":None,
                        "finish_reason":None
                    }
                ]
@@ -80,7 +69,7 @@ class BaseChatLLM:
            yield {
                "id":id,
                "object":"chat.completion.chunk",
                "created":time(),
                "created": t1,
                "model":self.model_id,
                "response_time": t2 - t1,
                "finish_time": t3 - t1,
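The two yield blocks above emit OpenAI-style chat.completion.chunk dicts: per-token deltas first, then a final chunk carrying finish_reason "stop" and timing fields. As a minimal sketch (not code from this commit), a consumer could reassemble the reply like this:

def collect_reply(chunks):
    # chunks: iterable of chat.completion.chunk dicts shaped like the ones above
    parts = []
    for chunk in chunks:
        choice = chunk["choices"][0]
        if choice.get("finish_reason") == "stop":
            break
        parts.append(choice["delta"].get("content", ""))
    return "".join(parts)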
@@ -91,69 +80,11 @@ class BaseChatLLM:
                        "delta":{
                            "content":""
                        },
                        "logprobs":None,
                        "finish_reason":"stop"
                    }
                ]
            }

    def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
        messages = self._get_session_messages(session)
        if sys_prompt:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        # debug(f'{messages=}')
        all_txt = ''
        for d in self._gen(messages):
            if d['choices'][0]['finish_reason'] == 'stop':
                messages.append(self._build_assistant_message(all_txt))
            else:
                all_txt += d['choices'][0]['delta']['content']
            yield d
        self._set_session_messages(session, messages)

    async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            await asyncio.sleep(0)
            yield d

    def generate(self, session, prompt,
                 image_path=None,
                 video_path=None,
                 audio_path=None,
                 sys_prompt=None):
        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            if d['choices'][0]['finish_reason'] == 'stop':
                return d
    def stream_generate(self, session, prompt,
                        image_path=None,
                        video_path=None,
                        audio_path=None,
                        sys_prompt=None):
        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            s = f'data: {json.dumps(d)}\n'
            yield s

    async def async_generate(self, session, prompt,
                             image_path=None,
                             video_path=None,
                             audio_path=None,
                             sys_prompt=None):
        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            await asyncio.sleep(0)
            if d['choices'][0]['finish_reason'] == 'stop':
                return d

    async def async_stream_generate(self, session, prompt,
                                    image_path=None,
                                    video_path=None,
                                    audio_path=None,
                                    sys_prompt=None):
        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            s = f'data: {json.dumps(d)}\n'
            yield s
        yield 'data: [DONE]'

    def build_kwargs(self, inputs, streamer):
        generate_kwargs = dict(
            **inputs,
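stream_generate and async_stream_generate wrap each chunk as a "data: {json}" line and terminate with "data: [DONE]". A minimal client-side parser for that format might look like the sketch below (an assumption, not code from this repository):

import json

def parse_sse_lines(lines):
    # lines: iterable of 'data: {...}' strings as yielded by stream_generate
    for line in lines:
        line = line.strip()
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        yield json.loads(payload)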
@@ -184,63 +115,70 @@ class BaseChatLLM:
            d['input_tokens'] = input_len
            yield d

class T2TChatLLM(BaseChatLLM):
    def _build_assistant_message(self, prompt):
        return {
            "role":"assistant",
            "content":prompt
        }
    async def async_gen(self, messages):
        async for d in stream.iterate(self._gen(messages)):
            yield d

    def _build_sys_message(self, prompt):
        return {
            "role":"system",
            "content": prompt
        }
    async def chat_completion_stream(self, messages):
        async for d in self.async_gen(messages):
            if d['choices'][0]['finish_reason']:
                d['usage'] = {
                    'prompt_tokens': d['input_tokens'],
                    'completion_tokens': d['output_tokens'],
                    'total_tokens': d['input_tokens'] + d['output_tokens']
                }
            s = f'data: {json.dumps(d)}\n'
            yield s
        yield 'data: [DONE]\n'

    def _build_user_message(self, prompt, **kw):
        return {
            "role":"user",
            "content": prompt
        }
    def reference(self, messages):
        t1 = time()
        inputs = self._messages2inputs(messages)
        input_len = inputs["input_ids"].shape[-1]
        streamer = self.get_streamer()
        kwargs = self.build_kwargs(inputs, streamer)
        thread = threading.Thread(target=self.model.generate,
                                  kwargs=kwargs)
        thread.start()
        txt = ''
        i = 0
        for d in self.output_generator(streamer):
            if i == 0:
                i = 1
                t1 = time()
            if d['choices'][0]['finish_reason'] != 'stop':
                txt += d['choices'][0]['delta']['content']
            else:
                i_tokens = d['input_tokens']
                o_tokens = d['output_tokens']

class MMChatLLM(BaseChatLLM):
    """ multiple modal chat LLM """
    def _build_assistant_message(self, prompt):
        t2 = time()
        return {
            "role":"assistant",
            "content":[{"type": "text", "text": prompt}]
        }

    def _build_sys_message(self, prompt):
        return {
            "role":"system",
            "content":[{"type": "text", "text": prompt}]
        }

    def _build_user_message(self, prompt, image_path=None,
                            video_path=None, audio_path=None):
        contents = [
            {
                "type":"text", "text": prompt
                'id': f'chatcmpl-{getID()}',
                "object":"chat.completion",
                "created":t1,
                "model":self.model_id,
                "response_time": t2 - t1,
                "finish_time": t3 - t1,
                "output_token": output_tokens,
                "choices":[
                    {
                        "index":0,
                        "message":{
                            "role": "assistant",
                            "content": txt
                        },
                        "finish_reason":"stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": i_tokens,
                    "completion_tokens": o_tokens,
                    "total_tokens": i_tokens + o_tokens
                }
            ]
        if image_path:
            contents.append({
                "type": "image",
                "image": image_path
            })
        if video_path:
            contents.append({
                "type": "video",
                "video":video_path
            })
        if audio_path:
            contents.append({
                "type": "audio",
                "audio": audio_path
            })
        return {
            "role": "user",
            "content": contents
        }

    async def chat_completion(self, messages):
        f = awaitify(self.reference)
        return await f(messages)
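The new entry points are chat_completion_stream (SSE-style strings) and chat_completion (a single chat.completion dict built in reference()). A minimal usage sketch, assuming a concrete engine such as Qwen3LLM and a placeholder model id:

import asyncio

async def demo(engine):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ]
    # streaming: 'data: {...}' strings terminated by 'data: [DONE]'
    async for line in engine.chat_completion_stream(messages):
        print(line, end="")
    # non-streaming: one chat.completion dict
    result = await engine.chat_completion(messages)
    print(result["choices"][0]["message"]["content"])

# asyncio.run(demo(Qwen3LLM("Qwen/Qwen3-0.6B")))   # model id is a placeholder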
@@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register

class Gemma3LLM(MMChatLLM):
class Gemma3LLM(BaseChatLLM):
    def __init__(self, model_id):
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto"
@@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register

model_id = "google/medgemma-4b-it"

class MedgemmaLLM(MMChatLLM):
class MedgemmaLLM(BaseChatLLM):
    def __init__(self, model_id):
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
@@ -7,9 +7,9 @@ from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
from llmengine.base_chat_llm import BaseChatLLM, llm_register

class Qwen3LLM(T2TChatLLM):
class Qwen3LLM(BaseChatLLM):
    def __init__(self, model_id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
@@ -21,23 +21,17 @@ def init():
    rf.register('chat_completions', chat_completions)

async def chat_completions(request, params_kw, *params, **kw):
    se = ServerEnv()
    engine = se.engine
    async def gor():
        se = ServerEnv()
        engine = se.engine
        session = await get_session(request)
        kwargs = {
        }
        if params_kw.image_path:
            kwargs['image_path'] = fs.reapPath(params_kw.image_path)
        if params_kw.video_path:
            kwargs['video_path'] = fs.reapPath(params_kw.video_path)
        if params_kw.audio_path:
            kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
        async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
        async for d in engine.chat_completion_stream(params_kw.messages):
            debug(f'{d=}')
            yield d

    return await stream_response(request, gor)
    if params_kw.stream:
        return await stream_response(request, gor)
    else:
        return await engine.chat_completion(params_kw.messages)

def main():
    parser = argparse.ArgumentParser(prog="Sage")
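The handler above switches on params_kw.stream and forwards params_kw.messages to the engine. A rough client-side call might look like the sketch below; the URL, port, and path are placeholders, since the actual route registered for 'chat_completions' is not shown in this diff:

import json
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",   # placeholder endpoint
    json={
        "stream": True,
        "messages": [{"role": "user", "content": "hi"}],
    },
    stream=True,
)
for raw in resp.iter_lines(decode_unicode=True):
    if raw and raw.startswith("data: ") and raw != "data: [DONE]":
        print(json.loads(raw[len("data: "):]))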