yumoqing 2025-06-10 06:46:12 +00:00
parent ca11f86ee4
commit f3ce845388
3 changed files with 50 additions and 45 deletions

View File

@@ -1,8 +1,11 @@
 import threading
+import asyncio
+import json
 import torch
 from time import time
 from transformers import TextIteratorStreamer
 from appPublic.log import debug
+from appPublic.worker import awaitify

 class BaseChatLLM:
     def get_session_key(self):
@@ -53,11 +56,7 @@ class BaseChatLLM:
             # debug(f'{all_txt=}, {d=}')
             yield d

-    def _generator(self, session, prompt,
-                image_path=None,
-                video_path=None,
-                audio_path=None,
-                sys_prompt=None):
+    def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
         messages = self._get_session_messages(session)
         if sys_prompt:
             messages.append(self._build_sys_message(sys_prompt))
@@ -65,21 +64,21 @@ class BaseChatLLM:
         # debug(f'{messages=}')
         for d in self._gen(messages):
             if d['done']:
+                # debug(f'++++++++++++++{d=}')
                 messages.append(self._build_assistant_message(d['text']))
             yield d
         self._set_session_messages(session, messages)

+    async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
+        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            await asyncio.sleep(0)
+            yield d
+
     def generate(self, session, prompt,
                 image_path=None,
                 video_path=None,
                 audio_path=None,
                 sys_prompt=None):
-        for d in self._generator(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt):
+        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
             if d['done']:
                 return d

     def stream_generate(self, session, prompt,
@@ -87,34 +86,28 @@ class BaseChatLLM:
                 video_path=None,
                 audio_path=None,
                 sys_prompt=None):
-        for d in self._generator(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt):
-            yield d
+        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            s = f'data {json.dumps(d)}\n'
+            yield s

     async def async_generate(self, session, prompt,
                 image_path=None,
                 video_path=None,
                 audio_path=None,
                 sys_prompt=None):
-        return self.generate(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt)
+        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            await asyncio.sleep(0)
+            if d['done']:
+                return d

     async def async_stream_generate(self, session, prompt,
                 image_path=None,
                 video_path=None,
                 audio_path=None,
                 sys_prompt=None):
-        for d in self._generator(session, prompt,
-                image_path=image_path,
-                video_path=video_path,
-                audio_path=audio_path,
-                sys_prompt=sys_prompt):
-            yield d
+        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
+            s = f'data {json.dumps(d)}\n'
+            yield s

     def build_kwargs(self, inputs, streamer):
         generate_kwargs = dict(
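
Note: the pattern this file converges on — a synchronous generator wrapped by `_async_generator`, which awaits `asyncio.sleep(0)` between chunks so other event-loop tasks can run, with the streaming entry points framing each chunk as a `data {json}` line — can be sketched in isolation as follows. This is a minimal illustration, not part of the commit; `sync_token_stream` is a hypothetical stand-in for `BaseChatLLM._gen`.

import asyncio
import json

# Hypothetical stand-in for BaseChatLLM._gen: a blocking generator
# that yields chunk dicts and finishes with a done marker.
def sync_token_stream():
    for tok in ['Hello', ' world']:
        yield {'done': False, 'text': tok}
    yield {'done': True, 'text': 'Hello world'}

async def async_token_stream():
    # Same trick as _async_generator: sleep(0) between chunks hands
    # control back to the event loop without adding real latency.
    for d in sync_token_stream():
        await asyncio.sleep(0)
        yield d

async def main():
    async for d in async_token_stream():
        # async_stream_generate emits each chunk in this
        # server-sent-event-like 'data {json}' framing.
        print(f'data {json.dumps(d)}')

asyncio.run(main())

One caveat: `sleep(0)` only yields control between chunks; each step of the underlying synchronous generator still blocks the loop while it runs, which is presumably why generation itself goes through a TextIteratorStreamer fed from a separate thread.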

View File

@@ -2,6 +2,7 @@
 # pip install accelerate
 from appPublic.worker import awaitify
+from appPublic.log import debug
 from ahserver.serverenv import get_serverenv
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
@@ -29,6 +30,7 @@ class Qwen3LLM(T2TChatLLM):
         return generate_kwargs

     def _messages2inputs(self, messages):
+        debug(f'{messages=}')
         text = self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
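
For context on the call this hunk truncates: `apply_chat_template` with `tokenize=False` renders the message list to a prompt string, which is tokenized separately when building model inputs. A minimal sketch of that standard transformers API; the model id is illustrative, and `add_generation_prompt=True` is an assumption since the diff cuts off before the remaining arguments.

from transformers import AutoTokenizer

# Illustrative model id; the server file below maps local paths such
# as /share/models/Qwen/Qwen3-32B to engine classes instead.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-0.6B')

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello'},
]

# tokenize=False mirrors the diff: render the chat template to text
# first; tokenization happens in a later, separate step.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(text)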

View File

@@ -1,28 +1,36 @@
 from traceback import format_exc
+import os
 import sys
-import argsparse
+import argparse
-from llmengine.gemma3_it import Gemma3ChatLLM
-from llmengine.qwen3 import Qwen3ChatLLM
-from llmengine.medgemma3_it import Medgemma3ChatLLM
-from ahserver.webapp import server
-from from aiohttp_session import get_session
+from llmengine.gemma3_it import Gemma3LLM
+from llmengine.qwen3 import Qwen3LLM
+from llmengine.medgemma3_it import MedgemmaLLM
+from appPublic.registerfunction import RegisterFunction
+from appPublic.log import debug
+from ahserver.serverenv import ServerEnv
+from ahserver.globalEnv import stream_response
+from ahserver.webapp import webserver
+from aiohttp_session import get_session

 model_pathMap = {
-    "/share/models/Qwen/Qwen3-32B": Qwen3ChatLLM,
-    "/share/models/google/gemma-3-4b-it": Gemma3ChatLLM,
-    "/share/models/google/medgemma-4b-it": Medgemma3ChatLLM
+    "/share/models/Qwen/Qwen3-32B": Qwen3LLM,
+    "/share/models/google/gemma-3-4b-it": Gemma3LLM,
+    "/share/models/google/medgemma-4b-it": MedgemmaLLM
 }
 def register(model_path, Klass):
     model_pathMap[model_path] = Klass

-def init(self):
+def init():
     rf = RegisterFunction()
     rf.register('chat_completions', chat_completions)

 async def chat_completions(request, params_kw, *params, **kw):
+    debug(f'{params_kw=}, {params=}, {kw=}')
     async def gor():
         se = ServerEnv()
         engine = se.chat_engine
-        session = await get_session()
+        session = await get_session(request)
         kwargs = {
         }
         if params_kw.image_path:
@@ -32,25 +40,27 @@ async def chat_completions(request, params_kw, *params, **kw):
         if params_kw.audio_path:
             kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
         async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
+            debug(f'{d=}')
             yield d
     return await stream_response(request, gor)

 def main():
     parser = argparse.ArgumentParser(prog="Sage")
     parser.add_argument('-w', '--workdir')
     parser.add_argument('-p', '--port')
     parser.add_argument('model_path')
     args = parser.parse_args()
     Klass = model_pathMap.get(args.model_path)
     if Klass is None:
         e = Exception(f'{args.model_path} has no mapping to a model class')
         exception(f'{e}, {format_exc()}')
         raise e
+    se = ServerEnv()
     se.chat_engine = Klass(args.model_path)
     workdir = args.workdir or os.getcwd()
     port = args.port
-    server(init_func, workdir, port)
+    webserver(init, workdir, port)

 if __name__ == '__main__':
     main()
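
Putting the third file together: `main` now resolves the engine class from `model_pathMap`, stores the instance on `ServerEnv` (shared state that `chat_completions` reads back as `se.chat_engine`), and passes `init` to `webserver` so the `chat_completions` handler is registered at startup. Assuming the file is saved as `server.py` (the diff page does not show filenames) and that `webserver` honors the `-p` port argument, a launch would look like:

    python server.py -w /path/to/workdir -p 9999 /share/models/Qwen/Qwen3-32B

Model paths not listed in `model_pathMap` can be wired up through the module's `register(model_path, Klass)` helper before `main` is invoked.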