bugfix
This commit is contained in:
parent
9c72a83189
commit
ca11f86ee4
57
llmengine/server.py
Normal file
57
llmengine/server.py
Normal file
@ -0,0 +1,57 @@
|
||||
from traceback import format_exc
|
||||
import sys
|
||||
import argsparse
|
||||
from llmengine.gemma3_it import Gemma3ChatLLM
|
||||
from llmengine.qwen3 import Qwen3ChatLLM
|
||||
from llmengine.medgemma3_it import Medgemma3ChatLLM
|
||||
from ahserver.webapp import server
|
||||
from from aiohttp_session import get_session
|
||||
model_pathMap = {
|
||||
"/share/models/Qwen/Qwen3-32B": Qwen3ChatLLM,
|
||||
"/share/models/google/gemma-3-4b-it": Gemma3ChatLLM,
|
||||
"/share/models/google/medgemma-4b-it": Medgemma3ChatLLM
|
||||
}
|
||||
def register(model_path, Klass):
|
||||
model_pathMap[model_path] = Klass
|
||||
|
||||
def init(self):
|
||||
rf = RegisterFunction()
|
||||
rf.register('chat_completions', chat_completions)
|
||||
|
||||
async def chat_completions(request, params_kw, *params, **kw):
|
||||
async def gor():
|
||||
se = ServerEnv()
|
||||
engine = se.chat_engine
|
||||
session = await get_session()
|
||||
kwargs = {
|
||||
}
|
||||
if params_kw.image_path:
|
||||
kwargs['image_path'] = fs.reapPath(params_kw.image_path)
|
||||
if params_kw.video_path:
|
||||
kwargs['video_path'] = fs.reapPath(params_kw.video_path)
|
||||
if params_kw.audio_path:
|
||||
kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
|
||||
async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
|
||||
yield d
|
||||
|
||||
return await stream_response(request, gor)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="Sage")
|
||||
parser.add_argument('-w', '--workdir')
|
||||
parser.add_argument('-p', '--port')
|
||||
parser.add_argument('model_path')
|
||||
args = parser.parse_args()
|
||||
Klass = model_pathMap.get(args.model_path)
|
||||
if Klass is None:
|
||||
e = Exception(f'{model_path} has not mapping to a model class')
|
||||
exception(f'{e}, {format_exc()}')
|
||||
raise e
|
||||
se.chat_engine = Klass(args.model_path)
|
||||
workdir = args.workdir or os.getcwd()
|
||||
port = args.port
|
||||
server(init_func, workdir, port)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user