bugfix
This commit is contained in:
parent
ca11f86ee4
commit
f3ce845388
@ -1,8 +1,11 @@
|
|||||||
import threading
|
import threading
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
from time import time
|
from time import time
|
||||||
from transformers import TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
from appPublic.log import debug
|
from appPublic.log import debug
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
|
||||||
class BaseChatLLM:
|
class BaseChatLLM:
|
||||||
def get_session_key(self):
|
def get_session_key(self):
|
||||||
@ -53,11 +56,7 @@ class BaseChatLLM:
|
|||||||
# debug(f'{all_txt=}, {d=}')
|
# debug(f'{all_txt=}, {d=}')
|
||||||
yield d
|
yield d
|
||||||
|
|
||||||
def _generator(self, session, prompt,
|
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
image_path=None,
|
|
||||||
video_path=None,
|
|
||||||
audio_path=None,
|
|
||||||
sys_prompt=None):
|
|
||||||
messages = self._get_session_messages(session)
|
messages = self._get_session_messages(session)
|
||||||
if sys_prompt:
|
if sys_prompt:
|
||||||
messages.append(self._build_sys_message(sys_prompt))
|
messages.append(self._build_sys_message(sys_prompt))
|
||||||
@ -65,21 +64,21 @@ class BaseChatLLM:
|
|||||||
# debug(f'{messages=}')
|
# debug(f'{messages=}')
|
||||||
for d in self._gen(messages):
|
for d in self._gen(messages):
|
||||||
if d['done']:
|
if d['done']:
|
||||||
# debug(f'++++++++++++++{d=}')
|
|
||||||
messages.append(self._build_assistant_message(d['text']))
|
messages.append(self._build_assistant_message(d['text']))
|
||||||
yield d
|
yield d
|
||||||
self._set_session_messages(session, messages)
|
self._set_session_messages(session, messages)
|
||||||
|
|
||||||
|
async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
|
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
yield d
|
||||||
|
|
||||||
def generate(self, session, prompt,
|
def generate(self, session, prompt,
|
||||||
image_path=None,
|
image_path=None,
|
||||||
video_path=None,
|
video_path=None,
|
||||||
audio_path=None,
|
audio_path=None,
|
||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
for d in self._generator(session, prompt,
|
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
image_path=image_path,
|
|
||||||
video_path=video_path,
|
|
||||||
audio_path=audio_path,
|
|
||||||
sys_prompt=sys_prompt):
|
|
||||||
if d['done']:
|
if d['done']:
|
||||||
return d
|
return d
|
||||||
def stream_generate(self, session, prompt,
|
def stream_generate(self, session, prompt,
|
||||||
@ -87,34 +86,28 @@ class BaseChatLLM:
|
|||||||
video_path=None,
|
video_path=None,
|
||||||
audio_path=None,
|
audio_path=None,
|
||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
for d in self._generator(session, prompt,
|
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
image_path=image_path,
|
s = f'data {json.dumps(d)}\n'
|
||||||
video_path=video_path,
|
yield s
|
||||||
audio_path=audio_path,
|
|
||||||
sys_prompt=sys_prompt):
|
|
||||||
yield d
|
|
||||||
|
|
||||||
async def async_generate(self, session, prompt,
|
async def async_generate(self, session, prompt,
|
||||||
image_path=None,
|
image_path=None,
|
||||||
video_path=None,
|
video_path=None,
|
||||||
audio_path=None,
|
audio_path=None,
|
||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
return self.generate(session, prompt,
|
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
image_path=image_path,
|
await asyncio.sleep(0)
|
||||||
video_path=video_path,
|
if d['done']:
|
||||||
audio_path=audio_path,
|
return d
|
||||||
sys_prompt=sys_prompt)
|
|
||||||
async def async_stream_generate(self, session, prompt,
|
async def async_stream_generate(self, session, prompt,
|
||||||
image_path=None,
|
image_path=None,
|
||||||
video_path=None,
|
video_path=None,
|
||||||
audio_path=None,
|
audio_path=None,
|
||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
for d in self._generator(session, prompt,
|
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
image_path=image_path,
|
s = f'data {json.dumps(d)}\n'
|
||||||
video_path=video_path,
|
yield s
|
||||||
audio_path=audio_path,
|
|
||||||
sys_prompt=sys_prompt):
|
|
||||||
yield d
|
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
def build_kwargs(self, inputs, streamer):
|
||||||
generate_kwargs = dict(
|
generate_kwargs = dict(
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
# pip install accelerate
|
# pip install accelerate
|
||||||
from appPublic.worker import awaitify
|
from appPublic.worker import awaitify
|
||||||
|
from appPublic.log import debug
|
||||||
from ahserver.serverenv import get_serverenv
|
from ahserver.serverenv import get_serverenv
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -29,6 +30,7 @@ class Qwen3LLM(T2TChatLLM):
|
|||||||
return generate_kwargs
|
return generate_kwargs
|
||||||
|
|
||||||
def _messages2inputs(self, messages):
|
def _messages2inputs(self, messages):
|
||||||
|
debug(f'{messages=}')
|
||||||
text = self.tokenizer.apply_chat_template(
|
text = self.tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
|
@ -1,28 +1,36 @@
|
|||||||
from traceback import format_exc
|
from traceback import format_exc
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argsparse
|
import argparse
|
||||||
from llmengine.gemma3_it import Gemma3ChatLLM
|
from llmengine.gemma3_it import Gemma3LLM
|
||||||
from llmengine.qwen3 import Qwen3ChatLLM
|
from llmengine.qwen3 import Qwen3LLM
|
||||||
from llmengine.medgemma3_it import Medgemma3ChatLLM
|
from llmengine.medgemma3_it import MedgemmaLLM
|
||||||
from ahserver.webapp import server
|
|
||||||
from from aiohttp_session import get_session
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
from appPublic.log import debug
|
||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.globalEnv import stream_response
|
||||||
|
from ahserver.webapp import webserver
|
||||||
|
|
||||||
|
from aiohttp_session import get_session
|
||||||
model_pathMap = {
|
model_pathMap = {
|
||||||
"/share/models/Qwen/Qwen3-32B": Qwen3ChatLLM,
|
"/share/models/Qwen/Qwen3-32B": Qwen3LLM,
|
||||||
"/share/models/google/gemma-3-4b-it": Gemma3ChatLLM,
|
"/share/models/google/gemma-3-4b-it": Gemma3LLM,
|
||||||
"/share/models/google/medgemma-4b-it": Medgemma3ChatLLM
|
"/share/models/google/medgemma-4b-it": MedgemmaLLM
|
||||||
}
|
}
|
||||||
def register(model_path, Klass):
|
def register(model_path, Klass):
|
||||||
model_pathMap[model_path] = Klass
|
model_pathMap[model_path] = Klass
|
||||||
|
|
||||||
def init(self):
|
def init():
|
||||||
rf = RegisterFunction()
|
rf = RegisterFunction()
|
||||||
rf.register('chat_completions', chat_completions)
|
rf.register('chat_completions', chat_completions)
|
||||||
|
|
||||||
async def chat_completions(request, params_kw, *params, **kw):
|
async def chat_completions(request, params_kw, *params, **kw):
|
||||||
|
debug(f'{params_kw=}, {params=}, {kw=}')
|
||||||
async def gor():
|
async def gor():
|
||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
engine = se.chat_engine
|
engine = se.chat_engine
|
||||||
session = await get_session()
|
session = await get_session(request)
|
||||||
kwargs = {
|
kwargs = {
|
||||||
}
|
}
|
||||||
if params_kw.image_path:
|
if params_kw.image_path:
|
||||||
@ -32,25 +40,27 @@ async def chat_completions(request, params_kw, *params, **kw):
|
|||||||
if params_kw.audio_path:
|
if params_kw.audio_path:
|
||||||
kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
|
kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
|
||||||
async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
|
async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
|
||||||
|
debug(f'{d=}')
|
||||||
yield d
|
yield d
|
||||||
|
|
||||||
return await stream_response(request, gor)
|
return await stream_response(request, gor)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(prog="Sage")
|
parser = argparse.ArgumentParser(prog="Sage")
|
||||||
parser.add_argument('-w', '--workdir')
|
parser.add_argument('-w', '--workdir')
|
||||||
parser.add_argument('-p', '--port')
|
parser.add_argument('-p', '--port')
|
||||||
parser.add_argument('model_path')
|
parser.add_argument('model_path')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
Klass = model_pathMap.get(args.model_path)
|
Klass = model_pathMap.get(args.model_path)
|
||||||
if Klass is None:
|
if Klass is None:
|
||||||
e = Exception(f'{model_path} has not mapping to a model class')
|
e = Exception(f'{model_path} has not mapping to a model class')
|
||||||
exception(f'{e}, {format_exc()}')
|
exception(f'{e}, {format_exc()}')
|
||||||
raise e
|
raise e
|
||||||
|
se = ServerEnv()
|
||||||
se.chat_engine = Klass(args.model_path)
|
se.chat_engine = Klass(args.model_path)
|
||||||
workdir = args.workdir or os.getcwd()
|
workdir = args.workdir or os.getcwd()
|
||||||
port = args.port
|
port = args.port
|
||||||
server(init_func, workdir, port)
|
webserver(init, workdir, port)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user