This commit is contained in:
yumoqing 2025-06-10 06:46:12 +00:00
parent ca11f86ee4
commit f3ce845388
3 changed files with 50 additions and 45 deletions

View File

@@ -1,8 +1,11 @@
 import threading
+import asyncio
+import json
 import torch
 from time import time
 from transformers import TextIteratorStreamer
 from appPublic.log import debug
+from appPublic.worker import awaitify

 class BaseChatLLM:
     def get_session_key(self):
@@ -53,11 +56,7 @@ class BaseChatLLM:
             # debug(f'{all_txt=}, {d=}')
             yield d

-    def _generator(self, session, prompt,
-            image_path=None,
-            video_path=None,
-            audio_path=None,
-            sys_prompt=None):
+    def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
         messages = self._get_session_messages(session)
         if sys_prompt:
             messages.append(self._build_sys_message(sys_prompt))
@@ -65,21 +64,21 @@ class BaseChatLLM:
         # debug(f'{messages=}')
         for d in self._gen(messages):
             if d['done']:
                 # debug(f'++++++++++++++{d=}')
                 messages.append(self._build_assistant_message(d['text']))
             yield d
         self._set_session_messages(session, messages)
+
+    async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
+        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            await asyncio.sleep(0)
+            yield d
     def generate(self, session, prompt,
             image_path=None,
             video_path=None,
             audio_path=None,
             sys_prompt=None):
-        for d in self._generator(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt):
+        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
             if d['done']:
                 return d

     def stream_generate(self, session, prompt,
@@ -87,34 +86,28 @@ class BaseChatLLM:
             video_path=None,
             audio_path=None,
             sys_prompt=None):
-        for d in self._generator(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt):
-            yield d
+        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            s = f'data {json.dumps(d)}\n'
+            yield s

     async def async_generate(self, session, prompt,
             image_path=None,
             video_path=None,
             audio_path=None,
             sys_prompt=None):
-        return self.generate(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt)
+        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            await asyncio.sleep(0)
+            if d['done']:
+                return d

     async def async_stream_generate(self, session, prompt,
             image_path=None,
             video_path=None,
             audio_path=None,
             sys_prompt=None):
-        for d in self._generator(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt):
-            yield d
+        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            s = f'data {json.dumps(d)}\n'
+            yield s

     def build_kwargs(self, inputs, streamer):
         generate_kwargs = dict(
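
The net effect of this file's changes: `_generator` is the single synchronous source of chunks, `_async_generator` wraps it with an `asyncio.sleep(0)` per chunk so the event loop can interleave other work, and the public methods differ only in whether they return the final dict or yield serialized lines. A minimal driver sketch, assuming a concrete subclass and a dict-like session (both hypothetical here); note the `data {json.dumps(d)}\n` framing is close to, but not exactly, the SSE `data: ...` plus blank-line convention:

    import asyncio

    async def demo(engine, session):
        # async_stream_generate yields pre-serialized 'data {...}\n' strings,
        # one per chunk produced by the underlying _generator
        async for line in engine.async_stream_generate(session, 'Hello'):
            print(line, end='')
        # async_generate drains the same stream and returns the final dict,
        # i.e. the first item whose 'done' flag is true
        d = await engine.async_generate(session, 'Summarize that.')
        print(d['text'])

    # asyncio.run(demo(SomeChatLLM('/path/to/model'), {}))  # hypothetical subclass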

View File

@@ -2,6 +2,7 @@
 # pip install accelerate
 from appPublic.worker import awaitify
+from appPublic.log import debug
 from ahserver.serverenv import get_serverenv
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
@@ -29,6 +30,7 @@ class Qwen3LLM(T2TChatLLM):
         return generate_kwargs

     def _messages2inputs(self, messages):
+        debug(f'{messages=}')
         text = self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
View File

@@ -1,28 +1,36 @@
 from traceback import format_exc
 import os
 import sys
-import argsparse
-from llmengine.gemma3_it import Gemma3ChatLLM
-from llmengine.qwen3 import Qwen3ChatLLM
-from llmengine.medgemma3_it import Medgemma3ChatLLM
-from ahserver.webapp import server
-from from aiohttp_session import get_session
+import argparse
+from llmengine.gemma3_it import Gemma3LLM
+from llmengine.qwen3 import Qwen3LLM
+from llmengine.medgemma3_it import MedgemmaLLM
+from appPublic.registerfunction import RegisterFunction
+from appPublic.log import debug, exception
+from ahserver.serverenv import ServerEnv
+from ahserver.globalEnv import stream_response
+from ahserver.webapp import webserver
+from aiohttp_session import get_session

 model_pathMap = {
-    "/share/models/Qwen/Qwen3-32B": Qwen3ChatLLM,
-    "/share/models/google/gemma-3-4b-it": Gemma3ChatLLM,
-    "/share/models/google/medgemma-4b-it": Medgemma3ChatLLM
+    "/share/models/Qwen/Qwen3-32B": Qwen3LLM,
+    "/share/models/google/gemma-3-4b-it": Gemma3LLM,
+    "/share/models/google/medgemma-4b-it": MedgemmaLLM
 }
 def register(model_path, Klass):
     model_pathMap[model_path] = Klass

-def init(self):
+def init():
     rf = RegisterFunction()
     rf.register('chat_completions', chat_completions)

 async def chat_completions(request, params_kw, *params, **kw):
+    debug(f'{params_kw=}, {params=}, {kw=}')
     async def gor():
         se = ServerEnv()
         engine = se.chat_engine
-        session = await get_session()
+        session = await get_session(request)
         kwargs = {
         }
         if params_kw.image_path:
@@ -32,25 +40,27 @@ async def chat_completions(request, params_kw, *params, **kw):
         if params_kw.audio_path:
             kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
         async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
+            debug(f'{d=}')
             yield d
     return await stream_response(request, gor)

 def main():
     parser = argparse.ArgumentParser(prog="Sage")
-    parser.add_argument('-w', '--workdir')
-    parser.add_argument('-p', '--port')
+    parser.add_argument('-w', '--workdir')
+    parser.add_argument('-p', '--port')
     parser.add_argument('model_path')
-    args = parser.parse_args()
-    Klass = model_pathMap.get(args.model_path)
+    args = parser.parse_args()
+    Klass = model_pathMap.get(args.model_path)
     if Klass is None:
         e = Exception(f'{args.model_path} has no mapping to a model class')
         exception(f'{e}, {format_exc()}')
         raise e
     se = ServerEnv()
     se.chat_engine = Klass(args.model_path)
     workdir = args.workdir or os.getcwd()
-    port = args.port
-    server(init_func, workdir, port)
+    port = args.port
+    webserver(init, workdir, port)

 if __name__ == '__main__':
     main()
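
A usage sketch under the mapping above (file and helper names outside this diff are assumptions; omitting `-p` leaves `port` as None, deferring to whatever ahserver does with an unset port):

    # launch the service for one of the mapped checkpoints
    # $ python server.py -w /path/to/workdir -p 9090 /share/models/Qwen/Qwen3-32B

    # or register an extra checkpoint before starting (MyLLM is hypothetical)
    # from server import register, main
    # register('/share/models/my/finetune-4b', MyLLM)
    # main()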