diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py index 1381b72..b0928c4 100644 --- a/llmengine/base_chat_llm.py +++ b/llmengine/base_chat_llm.py @@ -6,6 +6,15 @@ from time import time from transformers import TextIteratorStreamer from appPublic.log import debug from appPublic.worker import awaitify +from appPublic.uniqueID import getID + +model_pathMap = { +} +def llm_register(model_path, Klass): + model_pathMap[model_path] = Klass + +def get_llm_class(model_path): + return model_pathMap.get(model_path) class BaseChatLLM: def get_session_key(self): @@ -31,6 +40,7 @@ class BaseChatLLM: all_txt = '' t1 = time() i = 0 + id = f'chatllm-{getID}' for txt in streamer: if txt == '': continue @@ -39,22 +49,44 @@ class BaseChatLLM: i += 1 all_txt += txt yield { - 'done': False, - 'text': txt + "id":id, + "object":"chat.completion.chunk", + "created":time.time(), + "model":self.model_id, + "choices":[ + { + "index":0, + "delta":{ + "content":txt + }, + "logprobs":null, + "finish_reason":null + } + ] } t3 = time() t = all_txt unk = self.tokenizer(t, return_tensors="pt") output_tokens = len(unk["input_ids"][0]) - d = { - 'done': True, - 'text': all_txt, - 'response_time': t2 - t1, - 'finish_time': t3 - t1, - 'output_token': output_tokens + yield { + "id":id, + "object":"chat.completion.chunk", + "created":time.time(), + "model":self.model_id, + "response_time": t2 - t1, + "finish_time": t3 - t1, + "output_token": output_tokens, + "choices":[ + { + "index":0, + "delta":{ + "content":"" + }, + "logprobs":null, + "finish_reason":"stop" + } + ] } - # debug(f'{all_txt=}, {d=}') - yield d def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt): messages = self._get_session_messages(session) @@ -63,7 +95,7 @@ class BaseChatLLM: messages.append(self._build_user_message(prompt, image_path=image_path)) # debug(f'{messages=}') for d in self._gen(messages): - if d['done']: + if d['choices'][0]['finish_reason'] == 'stop': messages.append(self._build_assistant_message(d['text'])) yield d self._set_session_messages(session, messages) @@ -79,7 +111,7 @@ class BaseChatLLM: audio_path=None, sys_prompt=None): for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - if d['done']: + if d['choices'][0]['finish_reason'] == 'stop': return d def stream_generate(self, session, prompt, image_path=None, @@ -87,7 +119,7 @@ class BaseChatLLM: audio_path=None, sys_prompt=None): for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - s = f'data {json.dumps(d)}\n' + s = f'data: {json.dumps(d)}\n' yield s async def async_generate(self, session, prompt, @@ -97,7 +129,7 @@ class BaseChatLLM: sys_prompt=None): async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt): await asyncio.sleep(0) - if d['done']: + if d['choices'][0]['finish_reason'] == 'stop': return d async def async_stream_generate(self, session, prompt, @@ -106,8 +138,9 @@ class BaseChatLLM: audio_path=None, sys_prompt=None): async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt): - s = f'data {json.dumps(d)}\n' + s = f'data: {json.dumps(d)}\n' yield s + yield 'data: [done]' def build_kwargs(self, inputs, streamer): generate_kwargs = dict( @@ -135,9 +168,8 @@ class BaseChatLLM: kwargs=kwargs) thread.start() for d in self.output_generator(streamer): - if d['done']: + if d['choices'][0]['finish_reason'] == 'stop': d['input_tokens'] = input_len - # debug(f'{d=}\n') yield d class T2TChatLLM(BaseChatLLM): diff --git a/llmengine/gemma3_it.py b/llmengine/gemma3_it.py index 382d17b..635ca98 100644 --- a/llmengine/gemma3_it.py +++ b/llmengine/gemma3_it.py @@ -9,7 +9,7 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter from PIL import Image import requests import torch -from llmengine.base_chat_llm import MMChatLLM +from llmengine.base_chat_llm import MMChatLLM, llm_register class Gemma3LLM(MMChatLLM): def __init__(self, model_id): @@ -21,6 +21,8 @@ class Gemma3LLM(MMChatLLM): self.messages = [] self.model_id = model_id +llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM) + if __name__ == '__main__': gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it') session = {} diff --git a/llmengine/medgemma3_it.py b/llmengine/medgemma3_it.py index 4aeec0b..56c22f2 100644 --- a/llmengine/medgemma3_it.py +++ b/llmengine/medgemma3_it.py @@ -4,7 +4,7 @@ from transformers import AutoProcessor, AutoModelForImageTextToText from PIL import Image import requests import torch -from llmengine.base_chat_llm import MMChatLLM +from llmengine.base_chat_llm import MMChatLLM, llm_register model_id = "google/medgemma-4b-it" @@ -29,6 +29,8 @@ class MedgemmaLLM(MMChatLLM): ).to(self.model.device, dtype=torch.bfloat16) return inputs +llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM) + if __name__ == '__main__': med = MedgemmaLLM('/share/models/google/medgemma-4b-it') session = {} diff --git a/llmengine/qwen3.py b/llmengine/qwen3.py index 7053e61..72ffe7a 100644 --- a/llmengine/qwen3.py +++ b/llmengine/qwen3.py @@ -7,7 +7,7 @@ from ahserver.serverenv import get_serverenv from transformers import AutoModelForCausalLM, AutoTokenizer from PIL import Image import torch -from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM +from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register class Qwen3LLM(T2TChatLLM): def __init__(self, model_id): @@ -39,6 +39,8 @@ class Qwen3LLM(T2TChatLLM): ) return self.tokenizer([text], return_tensors="pt").to(self.model.device) +llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM) + if __name__ == '__main__': q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B') session = {} diff --git a/llmengine/server.py b/llmengine/server.py index bf4c2a5..beae951 100644 --- a/llmengine/server.py +++ b/llmengine/server.py @@ -2,9 +2,11 @@ from traceback import format_exc import os import sys import argparse +from llmengine.base_chat_llm import get_llm_class from llmengine.gemma3_it import Gemma3LLM from llmengine.qwen3 import Qwen3LLM from llmengine.medgemma3_it import MedgemmaLLM +from llmengine.devstral import DevstralLLM from appPublic.registerfunction import RegisterFunction from appPublic.log import debug @@ -13,13 +15,6 @@ from ahserver.globalEnv import stream_response from ahserver.webapp import webserver from aiohttp_session import get_session -model_pathMap = { - "/share/models/Qwen/Qwen3-32B": Qwen3LLM, - "/share/models/google/gemma-3-4b-it": Gemma3LLM, - "/share/models/google/medgemma-4b-it": MedgemmaLLM -} -def register(model_path, Klass): - model_pathMap[model_path] = Klass def init(): rf = RegisterFunction() @@ -51,7 +46,7 @@ def main(): parser.add_argument('-p', '--port') parser.add_argument('model_path') args = parser.parse_args() - Klass = model_pathMap.get(args.model_path) + Klass = get_llm_class(args.model_path) if Klass is None: e = Exception(f'{model_path} has not mapping to a model class') exception(f'{e}, {format_exc()}') diff --git a/pyproject.toml b/pyproject.toml index 3a428f9..f2df66d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ license = {text = "MIT"} dependencies = [ "torch", "transformers", + "mistral-common", "accelerate" ]