bugfix
This commit is contained in:
parent
f3ce845388
commit
50a14e7c5c
@ -6,6 +6,15 @@ from time import time
|
||||
from transformers import TextIteratorStreamer
|
||||
from appPublic.log import debug
|
||||
from appPublic.worker import awaitify
|
||||
from appPublic.uniqueID import getID
|
||||
|
||||
model_pathMap = {
|
||||
}
|
||||
def llm_register(model_path, Klass):
|
||||
model_pathMap[model_path] = Klass
|
||||
|
||||
def get_llm_class(model_path):
|
||||
return model_pathMap.get(model_path)
|
||||
|
||||
class BaseChatLLM:
|
||||
def get_session_key(self):
|
||||
@ -31,6 +40,7 @@ class BaseChatLLM:
|
||||
all_txt = ''
|
||||
t1 = time()
|
||||
i = 0
|
||||
id = f'chatllm-{getID}'
|
||||
for txt in streamer:
|
||||
if txt == '':
|
||||
continue
|
||||
@ -39,22 +49,44 @@ class BaseChatLLM:
|
||||
i += 1
|
||||
all_txt += txt
|
||||
yield {
|
||||
'done': False,
|
||||
'text': txt
|
||||
"id":id,
|
||||
"object":"chat.completion.chunk",
|
||||
"created":time.time(),
|
||||
"model":self.model_id,
|
||||
"choices":[
|
||||
{
|
||||
"index":0,
|
||||
"delta":{
|
||||
"content":txt
|
||||
},
|
||||
"logprobs":null,
|
||||
"finish_reason":null
|
||||
}
|
||||
]
|
||||
}
|
||||
t3 = time()
|
||||
t = all_txt
|
||||
unk = self.tokenizer(t, return_tensors="pt")
|
||||
output_tokens = len(unk["input_ids"][0])
|
||||
d = {
|
||||
'done': True,
|
||||
'text': all_txt,
|
||||
'response_time': t2 - t1,
|
||||
'finish_time': t3 - t1,
|
||||
'output_token': output_tokens
|
||||
yield {
|
||||
"id":id,
|
||||
"object":"chat.completion.chunk",
|
||||
"created":time.time(),
|
||||
"model":self.model_id,
|
||||
"response_time": t2 - t1,
|
||||
"finish_time": t3 - t1,
|
||||
"output_token": output_tokens,
|
||||
"choices":[
|
||||
{
|
||||
"index":0,
|
||||
"delta":{
|
||||
"content":""
|
||||
},
|
||||
"logprobs":null,
|
||||
"finish_reason":"stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
# debug(f'{all_txt=}, {d=}')
|
||||
yield d
|
||||
|
||||
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||
messages = self._get_session_messages(session)
|
||||
@ -63,7 +95,7 @@ class BaseChatLLM:
|
||||
messages.append(self._build_user_message(prompt, image_path=image_path))
|
||||
# debug(f'{messages=}')
|
||||
for d in self._gen(messages):
|
||||
if d['done']:
|
||||
if d['choices'][0]['finish_reason'] == 'stop':
|
||||
messages.append(self._build_assistant_message(d['text']))
|
||||
yield d
|
||||
self._set_session_messages(session, messages)
|
||||
@ -79,7 +111,7 @@ class BaseChatLLM:
|
||||
audio_path=None,
|
||||
sys_prompt=None):
|
||||
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||
if d['done']:
|
||||
if d['choices'][0]['finish_reason'] == 'stop':
|
||||
return d
|
||||
def stream_generate(self, session, prompt,
|
||||
image_path=None,
|
||||
@ -87,7 +119,7 @@ class BaseChatLLM:
|
||||
audio_path=None,
|
||||
sys_prompt=None):
|
||||
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||
s = f'data {json.dumps(d)}\n'
|
||||
s = f'data: {json.dumps(d)}\n'
|
||||
yield s
|
||||
|
||||
async def async_generate(self, session, prompt,
|
||||
@ -97,7 +129,7 @@ class BaseChatLLM:
|
||||
sys_prompt=None):
|
||||
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||
await asyncio.sleep(0)
|
||||
if d['done']:
|
||||
if d['choices'][0]['finish_reason'] == 'stop':
|
||||
return d
|
||||
|
||||
async def async_stream_generate(self, session, prompt,
|
||||
@ -106,8 +138,9 @@ class BaseChatLLM:
|
||||
audio_path=None,
|
||||
sys_prompt=None):
|
||||
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||
s = f'data {json.dumps(d)}\n'
|
||||
s = f'data: {json.dumps(d)}\n'
|
||||
yield s
|
||||
yield 'data: [done]'
|
||||
|
||||
def build_kwargs(self, inputs, streamer):
|
||||
generate_kwargs = dict(
|
||||
@ -135,9 +168,8 @@ class BaseChatLLM:
|
||||
kwargs=kwargs)
|
||||
thread.start()
|
||||
for d in self.output_generator(streamer):
|
||||
if d['done']:
|
||||
if d['choices'][0]['finish_reason'] == 'stop':
|
||||
d['input_tokens'] = input_len
|
||||
# debug(f'{d=}\n')
|
||||
yield d
|
||||
|
||||
class T2TChatLLM(BaseChatLLM):
|
||||
|
@ -9,7 +9,7 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
from llmengine.base_chat_llm import MMChatLLM
|
||||
from llmengine.base_chat_llm import MMChatLLM, llm_register
|
||||
|
||||
class Gemma3LLM(MMChatLLM):
|
||||
def __init__(self, model_id):
|
||||
@ -21,6 +21,8 @@ class Gemma3LLM(MMChatLLM):
|
||||
self.messages = []
|
||||
self.model_id = model_id
|
||||
|
||||
llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM)
|
||||
|
||||
if __name__ == '__main__':
|
||||
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
||||
session = {}
|
||||
|
@ -4,7 +4,7 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
from llmengine.base_chat_llm import MMChatLLM
|
||||
from llmengine.base_chat_llm import MMChatLLM, llm_register
|
||||
|
||||
model_id = "google/medgemma-4b-it"
|
||||
|
||||
@ -29,6 +29,8 @@ class MedgemmaLLM(MMChatLLM):
|
||||
).to(self.model.device, dtype=torch.bfloat16)
|
||||
return inputs
|
||||
|
||||
llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM)
|
||||
|
||||
if __name__ == '__main__':
|
||||
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
|
||||
session = {}
|
||||
|
@ -7,7 +7,7 @@ from ahserver.serverenv import get_serverenv
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from PIL import Image
|
||||
import torch
|
||||
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM
|
||||
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
|
||||
|
||||
class Qwen3LLM(T2TChatLLM):
|
||||
def __init__(self, model_id):
|
||||
@ -39,6 +39,8 @@ class Qwen3LLM(T2TChatLLM):
|
||||
)
|
||||
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
||||
|
||||
llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM)
|
||||
|
||||
if __name__ == '__main__':
|
||||
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
|
||||
session = {}
|
||||
|
@ -2,9 +2,11 @@ from traceback import format_exc
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from llmengine.base_chat_llm import get_llm_class
|
||||
from llmengine.gemma3_it import Gemma3LLM
|
||||
from llmengine.qwen3 import Qwen3LLM
|
||||
from llmengine.medgemma3_it import MedgemmaLLM
|
||||
from llmengine.devstral import DevstralLLM
|
||||
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
from appPublic.log import debug
|
||||
@ -13,13 +15,6 @@ from ahserver.globalEnv import stream_response
|
||||
from ahserver.webapp import webserver
|
||||
|
||||
from aiohttp_session import get_session
|
||||
model_pathMap = {
|
||||
"/share/models/Qwen/Qwen3-32B": Qwen3LLM,
|
||||
"/share/models/google/gemma-3-4b-it": Gemma3LLM,
|
||||
"/share/models/google/medgemma-4b-it": MedgemmaLLM
|
||||
}
|
||||
def register(model_path, Klass):
|
||||
model_pathMap[model_path] = Klass
|
||||
|
||||
def init():
|
||||
rf = RegisterFunction()
|
||||
@ -51,7 +46,7 @@ def main():
|
||||
parser.add_argument('-p', '--port')
|
||||
parser.add_argument('model_path')
|
||||
args = parser.parse_args()
|
||||
Klass = model_pathMap.get(args.model_path)
|
||||
Klass = get_llm_class(args.model_path)
|
||||
if Klass is None:
|
||||
e = Exception(f'{model_path} has not mapping to a model class')
|
||||
exception(f'{e}, {format_exc()}')
|
||||
|
@ -9,6 +9,7 @@ license = {text = "MIT"}
|
||||
dependencies = [
|
||||
"torch",
|
||||
"transformers",
|
||||
"mistral-common",
|
||||
"accelerate"
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user