diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py index cb03117..26b3f59 100644 --- a/llmengine/base_chat_llm.py +++ b/llmengine/base_chat_llm.py @@ -3,6 +3,7 @@ import asyncio import json import torch from time import time +from aiostream import stream from transformers import TextIteratorStreamer from appPublic.log import debug from appPublic.worker import awaitify @@ -26,18 +27,6 @@ class BaseChatLLM: device = torch.device("mps") self.model = self.model.to(device) - def get_session_key(self): - return self.model_id + ':messages' - - def _get_session_messages(self, session): - key = self.get_session_key() - messages = session.get(key) or [] - return messages - - def _set_session_messages(self, session, messages): - key = self.get_session_key() - session[key] = messages - def get_streamer(self): return TextIteratorStreamer( tokenizer=self.tokenizer, @@ -60,15 +49,15 @@ class BaseChatLLM: yield { "id":id, "object":"chat.completion.chunk", - "created":time(), + "created": t1, "model":self.model_id, "choices":[ { "index":0, "delta":{ + "role": "assistant", "content":txt }, - "logprobs":None, "finish_reason":None } ] @@ -80,7 +69,7 @@ class BaseChatLLM: yield { "id":id, "object":"chat.completion.chunk", - "created":time(), + "created": t1, "model":self.model_id, "response_time": t2 - t1, "finish_time": t3 - t1, @@ -91,69 +80,11 @@ class BaseChatLLM: "delta":{ "content":"" }, - "logprobs":None, "finish_reason":"stop" } ] } - def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt): - messages = self._get_session_messages(session) - if sys_prompt: - messages.append(self._build_sys_message(sys_prompt)) - messages.append(self._build_user_message(prompt, image_path=image_path)) - # debug(f'{messages=}') - all_txt = '' - for d in self._gen(messages): - if d['choices'][0]['finish_reason'] == 'stop': - messages.append(self._build_assistant_message(all_txt)) - else: - all_txt += d['choices'][0]['delta']['content'] - yield d - self._set_session_messages(session, messages) - - async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt): - for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - await asyncio.sleep(0) - yield d - - def generate(self, session, prompt, - image_path=None, - video_path=None, - audio_path=None, - sys_prompt=None): - for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - if d['choices'][0]['finish_reason'] == 'stop': - return d - def stream_generate(self, session, prompt, - image_path=None, - video_path=None, - audio_path=None, - sys_prompt=None): - for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - s = f'data: {json.dumps(d)}\n' - yield s - - async def async_generate(self, session, prompt, - image_path=None, - video_path=None, - audio_path=None, - sys_prompt=None): - async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - await asyncio.sleep(0) - if d['choices'][0]['finish_reason'] == 'stop': - return d - - async def async_stream_generate(self, session, prompt, - image_path=None, - video_path=None, - audio_path=None, - sys_prompt=None): - async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - s = f'data: {json.dumps(d)}\n' - yield s - yield 'data: [DONE]' - def build_kwargs(self, inputs, streamer): generate_kwargs = dict( **inputs, @@ -184,63 +115,70 @@ class BaseChatLLM: d['input_tokens'] = input_len yield d -class T2TChatLLM(BaseChatLLM): - def _build_assistant_message(self, prompt): - return { - "role":"assistant", - "content":prompt - } - - def _build_sys_message(self, prompt): - return { - "role":"system", - "content": prompt - } - - def _build_user_message(self, prompt, **kw): - return { - "role":"user", - "content": prompt - } - -class MMChatLLM(BaseChatLLM): - """ multiple modal chat LLM """ - def _build_assistant_message(self, prompt): - return { - "role":"assistant", - "content":[{"type": "text", "text": prompt}] - } - - def _build_sys_message(self, prompt): - return { - "role":"system", - "content":[{"type": "text", "text": prompt}] - } - - def _build_user_message(self, prompt, image_path=None, - video_path=None, audio_path=None): - contents = [ - { - "type":"text", "text": prompt - } - ] - if image_path: - contents.append({ - "type": "image", - "image": image_path - }) - if video_path: - contents.append({ - "type": "video", - "video":video_path - }) - if audio_path: - contents.append({ - "tyoe": "audio", - "audio": audio_path - }) - return { - "role": "user", - "content": contents - } + async def async_gen(self, messages): + async for d in stream.iterate(self._gen(messages)): + yield d + async def chat_completion_stream(self, messages): + async for d in self.async_gen(messages): + if d['choices'][0]['finish_reason']: + d['usage'] = { + 'prompt_tokens': d['input_tokens'], + 'completion_tokens': d['output_tokens'], + 'total_tokens': d['input_tokens'] + d['output_tokens'] + } + s = f'data: {json.dumps(d)}\n' + yield s + yield 'data: [DONE]\n' + + def reference(self, messages): + t1 = time() + inputs = self._messages2inputs(messages) + input_len = inputs["input_ids"].shape[-1] + streamer = self.get_streamer() + kwargs = self.build_kwargs(inputs, streamer) + thread = threading.Thread(target=self.model.generate, + kwargs=kwargs) + thread.start() + txt = '' + i = 0 + for d in self.output_generator(streamer): + if i == 0: + i = 1 + t1 = time() + if d['choices'][0]['finish_reason'] != 'stop': + txt += d['choices'][0]['delta']['content'] + else: + i_tokens = d['input_tokens'] + o_tokens = d['output_tokens'] + + t2 = time() + return { + 'id': f'chatcmpl-{getID()}', + "object":"chat.completion", + "created":t1, + "model":self.model_id, + "response_time": t2 - t1, + "finish_time": t3 - t1, + "output_token": output_tokens, + "choices":[ + { + "index":0, + "message":{ + "role": "assistant", + "content": txt + }, + "finish_reason":"stop" + } + ], + "usage": { + "prompt_tokens": i_tokens, + "completion_tokens": o_tokens, + "total_tokens": i_tokens + o_tokens + } + } + + async def chat_completion(self, messages): + f = awaitify(self.reference) + return await f(messages) + diff --git a/llmengine/gemma3_it.py b/llmengine/gemma3_it.py index 3936fd2..252e433 100644 --- a/llmengine/gemma3_it.py +++ b/llmengine/gemma3_it.py @@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter from PIL import Image import requests import torch -from llmengine.base_chat_llm import MMChatLLM, llm_register +from llmengine.base_chat_llm import BaseChatLLM, llm_register -class Gemma3LLM(MMChatLLM): +class Gemma3LLM(BaseChatLLM): def __init__(self, model_id): self.model = Gemma3ForConditionalGeneration.from_pretrained( model_id, device_map="auto" diff --git a/llmengine/medgemma3_it.py b/llmengine/medgemma3_it.py index db3d73a..d282b1b 100644 --- a/llmengine/medgemma3_it.py +++ b/llmengine/medgemma3_it.py @@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText from PIL import Image import requests import torch -from llmengine.base_chat_llm import MMChatLLM, llm_register +from llmengine.base_chat_llm import BaseChatLLM, llm_register model_id = "google/medgemma-4b-it" -class MedgemmaLLM(MMChatLLM): +class MedgemmaLLM(BaseChatLLM): def __init__(self, model_id): self.model = AutoModelForImageTextToText.from_pretrained( model_id, diff --git a/llmengine/qwen3.py b/llmengine/qwen3.py index a6585d5..ad0a1c2 100644 --- a/llmengine/qwen3.py +++ b/llmengine/qwen3.py @@ -7,9 +7,9 @@ from ahserver.serverenv import get_serverenv from transformers import AutoModelForCausalLM, AutoTokenizer from PIL import Image import torch -from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register +from llmengine.base_chat_llm import BaseChatLLM, llm_register -class Qwen3LLM(T2TChatLLM): +class Qwen3LLM(BaseChatLLM): def __init__(self, model_id): self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.model = AutoModelForCausalLM.from_pretrained( diff --git a/llmengine/server.py b/llmengine/server.py index 83af5d2..c429712 100644 --- a/llmengine/server.py +++ b/llmengine/server.py @@ -21,23 +21,17 @@ def init(): rf.register('chat_completions', chat_completions) async def chat_completions(request, params_kw, *params, **kw): + se = ServerEnv() + engine = se.engine async def gor(): - se = ServerEnv() - engine = se.engine - session = await get_session(request) - kwargs = { - } - if params_kw.image_path: - kwargs['image_path'] = fs.reapPath(params_kw.image_path) - if params_kw.video_path: - kwargs['video_path'] = fs.reapPath(params_kw.video_path) - if params_kw.audio_path: - kwargs['audio_path'] = fs.reapPath(params_kw.audio_path) - async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs): + async for d in engine.chat_completion_stream(params_kw.messages): debug(f'{d=}') yield d - return await stream_response(request, gor) + if params_kw.stream: + return await stream_response(request, gor) + else: + return await engine.chat_completion(params_kw.messages) def main(): parser = argparse.ArgumentParser(prog="Sage")