From f3ce845388cf8a183a1f83f837dd9f1812ace607 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Tue, 10 Jun 2025 06:46:12 +0000 Subject: [PATCH] bugfix --- llmengine/base_chat_llm.py | 49 ++++++++++++++++---------------------- llmengine/qwen3.py | 2 ++ llmengine/server.py | 44 +++++++++++++++++++++------------- 3 files changed, 50 insertions(+), 45 deletions(-) diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py index cd4b6c2..1381b72 100644 --- a/llmengine/base_chat_llm.py +++ b/llmengine/base_chat_llm.py @@ -1,8 +1,11 @@ import threading +import asyncio +import json import torch from time import time from transformers import TextIteratorStreamer from appPublic.log import debug +from appPublic.worker import awaitify class BaseChatLLM: def get_session_key(self): @@ -53,11 +56,7 @@ class BaseChatLLM: # debug(f'{all_txt=}, {d=}') yield d - def _generator(self, session, prompt, - image_path=None, - video_path=None, - audio_path=None, - sys_prompt=None): + def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt): messages = self._get_session_messages(session) if sys_prompt: messages.append(self._build_sys_message(sys_prompt)) @@ -65,21 +64,21 @@ class BaseChatLLM: # debug(f'{messages=}') for d in self._gen(messages): if d['done']: - # debug(f'++++++++++++++{d=}') messages.append(self._build_assistant_message(d['text'])) yield d self._set_session_messages(session, messages) + async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt): + for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): + await asyncio.sleep(0) + yield d + def generate(self, session, prompt, image_path=None, video_path=None, audio_path=None, sys_prompt=None): - for d in self._generator(session, prompt, - image_path=image_path, - video_path=video_path, - audio_path=audio_path, - sys_prompt=sys_prompt): + for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): if d['done']: return d def stream_generate(self, session, prompt, @@ -87,34 +86,28 @@ class BaseChatLLM: video_path=None, audio_path=None, sys_prompt=None): - for d in self._generator(session, prompt, - image_path=image_path, - video_path=video_path, - audio_path=audio_path, - sys_prompt=sys_prompt): - yield d + for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt): + s = f'data {json.dumps(d)}\n' + yield s async def async_generate(self, session, prompt, image_path=None, video_path=None, audio_path=None, sys_prompt=None): - return self.generate(session, prompt, - image_path=image_path, - video_path=video_path, - audio_path=audio_path, - sys_prompt=sys_prompt) + async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt): + await asyncio.sleep(0) + if d['done']: + return d + async def async_stream_generate(self, session, prompt, image_path=None, video_path=None, audio_path=None, sys_prompt=None): - for d in self._generator(session, prompt, - image_path=image_path, - video_path=video_path, - audio_path=audio_path, - sys_prompt=sys_prompt): - yield d + async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt): + s = f'data {json.dumps(d)}\n' + yield s def build_kwargs(self, inputs, streamer): generate_kwargs = dict( diff --git a/llmengine/qwen3.py b/llmengine/qwen3.py index cd7a96f..7053e61 100644 --- a/llmengine/qwen3.py +++ b/llmengine/qwen3.py @@ -2,6 +2,7 @@ # pip install accelerate from appPublic.worker import awaitify +from appPublic.log import debug from ahserver.serverenv import get_serverenv from transformers import AutoModelForCausalLM, AutoTokenizer from PIL import Image @@ -29,6 +30,7 @@ class Qwen3LLM(T2TChatLLM): return generate_kwargs def _messages2inputs(self, messages): + debug(f'{messages=}') text = self.tokenizer.apply_chat_template( messages, tokenize=False, diff --git a/llmengine/server.py b/llmengine/server.py index c29d625..bf4c2a5 100644 --- a/llmengine/server.py +++ b/llmengine/server.py @@ -1,28 +1,36 @@ from traceback import format_exc +import os import sys -import argsparse -from llmengine.gemma3_it import Gemma3ChatLLM -from llmengine.qwen3 import Qwen3ChatLLM -from llmengine.medgemma3_it import Medgemma3ChatLLM -from ahserver.webapp import server -from from aiohttp_session import get_session +import argparse +from llmengine.gemma3_it import Gemma3LLM +from llmengine.qwen3 import Qwen3LLM +from llmengine.medgemma3_it import MedgemmaLLM + +from appPublic.registerfunction import RegisterFunction +from appPublic.log import debug +from ahserver.serverenv import ServerEnv +from ahserver.globalEnv import stream_response +from ahserver.webapp import webserver + +from aiohttp_session import get_session model_pathMap = { - "/share/models/Qwen/Qwen3-32B": Qwen3ChatLLM, - "/share/models/google/gemma-3-4b-it": Gemma3ChatLLM, - "/share/models/google/medgemma-4b-it": Medgemma3ChatLLM + "/share/models/Qwen/Qwen3-32B": Qwen3LLM, + "/share/models/google/gemma-3-4b-it": Gemma3LLM, + "/share/models/google/medgemma-4b-it": MedgemmaLLM } def register(model_path, Klass): model_pathMap[model_path] = Klass -def init(self): +def init(): rf = RegisterFunction() rf.register('chat_completions', chat_completions) async def chat_completions(request, params_kw, *params, **kw): + debug(f'{params_kw=}, {params=}, {kw=}') async def gor(): se = ServerEnv() engine = se.chat_engine - session = await get_session() + session = await get_session(request) kwargs = { } if params_kw.image_path: @@ -32,25 +40,27 @@ async def chat_completions(request, params_kw, *params, **kw): if params_kw.audio_path: kwargs['audio_path'] = fs.reapPath(params_kw.audio_path) async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs): + debug(f'{d=}') yield d return await stream_response(request, gor) def main(): parser = argparse.ArgumentParser(prog="Sage") - parser.add_argument('-w', '--workdir') - parser.add_argument('-p', '--port') + parser.add_argument('-w', '--workdir') + parser.add_argument('-p', '--port') parser.add_argument('model_path') - args = parser.parse_args() - Klass = model_pathMap.get(args.model_path) + args = parser.parse_args() + Klass = model_pathMap.get(args.model_path) if Klass is None: e = Exception(f'{model_path} has not mapping to a model class') exception(f'{e}, {format_exc()}') raise e + se = ServerEnv() se.chat_engine = Klass(args.model_path) workdir = args.workdir or os.getcwd() - port = args.port - server(init_func, workdir, port) + port = args.port + webserver(init, workdir, port) if __name__ == '__main__': main()