bugfix
This commit is contained in:
parent
f3ce845388
commit
50a14e7c5c
@ -6,6 +6,15 @@ from time import time
|
|||||||
from transformers import TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
from appPublic.log import debug
|
from appPublic.log import debug
|
||||||
from appPublic.worker import awaitify
|
from appPublic.worker import awaitify
|
||||||
|
from appPublic.uniqueID import getID
|
||||||
|
|
||||||
|
model_pathMap = {
|
||||||
|
}
|
||||||
|
def llm_register(model_path, Klass):
|
||||||
|
model_pathMap[model_path] = Klass
|
||||||
|
|
||||||
|
def get_llm_class(model_path):
|
||||||
|
return model_pathMap.get(model_path)
|
||||||
|
|
||||||
class BaseChatLLM:
|
class BaseChatLLM:
|
||||||
def get_session_key(self):
|
def get_session_key(self):
|
||||||
@ -31,6 +40,7 @@ class BaseChatLLM:
|
|||||||
all_txt = ''
|
all_txt = ''
|
||||||
t1 = time()
|
t1 = time()
|
||||||
i = 0
|
i = 0
|
||||||
|
id = f'chatllm-{getID}'
|
||||||
for txt in streamer:
|
for txt in streamer:
|
||||||
if txt == '':
|
if txt == '':
|
||||||
continue
|
continue
|
||||||
@ -39,22 +49,44 @@ class BaseChatLLM:
|
|||||||
i += 1
|
i += 1
|
||||||
all_txt += txt
|
all_txt += txt
|
||||||
yield {
|
yield {
|
||||||
'done': False,
|
"id":id,
|
||||||
'text': txt
|
"object":"chat.completion.chunk",
|
||||||
|
"created":time.time(),
|
||||||
|
"model":self.model_id,
|
||||||
|
"choices":[
|
||||||
|
{
|
||||||
|
"index":0,
|
||||||
|
"delta":{
|
||||||
|
"content":txt
|
||||||
|
},
|
||||||
|
"logprobs":null,
|
||||||
|
"finish_reason":null
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
t3 = time()
|
t3 = time()
|
||||||
t = all_txt
|
t = all_txt
|
||||||
unk = self.tokenizer(t, return_tensors="pt")
|
unk = self.tokenizer(t, return_tensors="pt")
|
||||||
output_tokens = len(unk["input_ids"][0])
|
output_tokens = len(unk["input_ids"][0])
|
||||||
d = {
|
yield {
|
||||||
'done': True,
|
"id":id,
|
||||||
'text': all_txt,
|
"object":"chat.completion.chunk",
|
||||||
'response_time': t2 - t1,
|
"created":time.time(),
|
||||||
'finish_time': t3 - t1,
|
"model":self.model_id,
|
||||||
'output_token': output_tokens
|
"response_time": t2 - t1,
|
||||||
|
"finish_time": t3 - t1,
|
||||||
|
"output_token": output_tokens,
|
||||||
|
"choices":[
|
||||||
|
{
|
||||||
|
"index":0,
|
||||||
|
"delta":{
|
||||||
|
"content":""
|
||||||
|
},
|
||||||
|
"logprobs":null,
|
||||||
|
"finish_reason":"stop"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
# debug(f'{all_txt=}, {d=}')
|
|
||||||
yield d
|
|
||||||
|
|
||||||
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
|
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
messages = self._get_session_messages(session)
|
messages = self._get_session_messages(session)
|
||||||
@ -63,7 +95,7 @@ class BaseChatLLM:
|
|||||||
messages.append(self._build_user_message(prompt, image_path=image_path))
|
messages.append(self._build_user_message(prompt, image_path=image_path))
|
||||||
# debug(f'{messages=}')
|
# debug(f'{messages=}')
|
||||||
for d in self._gen(messages):
|
for d in self._gen(messages):
|
||||||
if d['done']:
|
if d['choices'][0]['finish_reason'] == 'stop':
|
||||||
messages.append(self._build_assistant_message(d['text']))
|
messages.append(self._build_assistant_message(d['text']))
|
||||||
yield d
|
yield d
|
||||||
self._set_session_messages(session, messages)
|
self._set_session_messages(session, messages)
|
||||||
@ -79,7 +111,7 @@ class BaseChatLLM:
|
|||||||
audio_path=None,
|
audio_path=None,
|
||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
if d['done']:
|
if d['choices'][0]['finish_reason'] == 'stop':
|
||||||
return d
|
return d
|
||||||
def stream_generate(self, session, prompt,
|
def stream_generate(self, session, prompt,
|
||||||
image_path=None,
|
image_path=None,
|
||||||
@ -87,7 +119,7 @@ class BaseChatLLM:
|
|||||||
audio_path=None,
|
audio_path=None,
|
||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
s = f'data {json.dumps(d)}\n'
|
s = f'data: {json.dumps(d)}\n'
|
||||||
yield s
|
yield s
|
||||||
|
|
||||||
async def async_generate(self, session, prompt,
|
async def async_generate(self, session, prompt,
|
||||||
@ -97,7 +129,7 @@ class BaseChatLLM:
|
|||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
if d['done']:
|
if d['choices'][0]['finish_reason'] == 'stop':
|
||||||
return d
|
return d
|
||||||
|
|
||||||
async def async_stream_generate(self, session, prompt,
|
async def async_stream_generate(self, session, prompt,
|
||||||
@ -106,8 +138,9 @@ class BaseChatLLM:
|
|||||||
audio_path=None,
|
audio_path=None,
|
||||||
sys_prompt=None):
|
sys_prompt=None):
|
||||||
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
||||||
s = f'data {json.dumps(d)}\n'
|
s = f'data: {json.dumps(d)}\n'
|
||||||
yield s
|
yield s
|
||||||
|
yield 'data: [done]'
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
def build_kwargs(self, inputs, streamer):
|
||||||
generate_kwargs = dict(
|
generate_kwargs = dict(
|
||||||
@ -135,9 +168,8 @@ class BaseChatLLM:
|
|||||||
kwargs=kwargs)
|
kwargs=kwargs)
|
||||||
thread.start()
|
thread.start()
|
||||||
for d in self.output_generator(streamer):
|
for d in self.output_generator(streamer):
|
||||||
if d['done']:
|
if d['choices'][0]['finish_reason'] == 'stop':
|
||||||
d['input_tokens'] = input_len
|
d['input_tokens'] = input_len
|
||||||
# debug(f'{d=}\n')
|
|
||||||
yield d
|
yield d
|
||||||
|
|
||||||
class T2TChatLLM(BaseChatLLM):
|
class T2TChatLLM(BaseChatLLM):
|
||||||
|
@ -9,7 +9,7 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from llmengine.base_chat_llm import MMChatLLM
|
from llmengine.base_chat_llm import MMChatLLM, llm_register
|
||||||
|
|
||||||
class Gemma3LLM(MMChatLLM):
|
class Gemma3LLM(MMChatLLM):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
@ -21,6 +21,8 @@ class Gemma3LLM(MMChatLLM):
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
|
llm_register("/share/models/google/gemma-3-4b-it", Gemma3LLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
||||||
session = {}
|
session = {}
|
||||||
|
@ -4,7 +4,7 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from llmengine.base_chat_llm import MMChatLLM
|
from llmengine.base_chat_llm import MMChatLLM, llm_register
|
||||||
|
|
||||||
model_id = "google/medgemma-4b-it"
|
model_id = "google/medgemma-4b-it"
|
||||||
|
|
||||||
@ -29,6 +29,8 @@ class MedgemmaLLM(MMChatLLM):
|
|||||||
).to(self.model.device, dtype=torch.bfloat16)
|
).to(self.model.device, dtype=torch.bfloat16)
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
llm_register("/share/models/google/medgemma-4b-it", MedgemmaLLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
|
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
|
||||||
session = {}
|
session = {}
|
||||||
|
@ -7,7 +7,7 @@ from ahserver.serverenv import get_serverenv
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM
|
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
|
||||||
|
|
||||||
class Qwen3LLM(T2TChatLLM):
|
class Qwen3LLM(T2TChatLLM):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
@ -39,6 +39,8 @@ class Qwen3LLM(T2TChatLLM):
|
|||||||
)
|
)
|
||||||
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
||||||
|
|
||||||
|
llm_register("/share/models/Qwen/Qwen3-32B", Qwen3LLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
|
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
|
||||||
session = {}
|
session = {}
|
||||||
|
@ -2,9 +2,11 @@ from traceback import format_exc
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
from llmengine.base_chat_llm import get_llm_class
|
||||||
from llmengine.gemma3_it import Gemma3LLM
|
from llmengine.gemma3_it import Gemma3LLM
|
||||||
from llmengine.qwen3 import Qwen3LLM
|
from llmengine.qwen3 import Qwen3LLM
|
||||||
from llmengine.medgemma3_it import MedgemmaLLM
|
from llmengine.medgemma3_it import MedgemmaLLM
|
||||||
|
from llmengine.devstral import DevstralLLM
|
||||||
|
|
||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
from appPublic.log import debug
|
from appPublic.log import debug
|
||||||
@ -13,13 +15,6 @@ from ahserver.globalEnv import stream_response
|
|||||||
from ahserver.webapp import webserver
|
from ahserver.webapp import webserver
|
||||||
|
|
||||||
from aiohttp_session import get_session
|
from aiohttp_session import get_session
|
||||||
model_pathMap = {
|
|
||||||
"/share/models/Qwen/Qwen3-32B": Qwen3LLM,
|
|
||||||
"/share/models/google/gemma-3-4b-it": Gemma3LLM,
|
|
||||||
"/share/models/google/medgemma-4b-it": MedgemmaLLM
|
|
||||||
}
|
|
||||||
def register(model_path, Klass):
|
|
||||||
model_pathMap[model_path] = Klass
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
rf = RegisterFunction()
|
rf = RegisterFunction()
|
||||||
@ -51,7 +46,7 @@ def main():
|
|||||||
parser.add_argument('-p', '--port')
|
parser.add_argument('-p', '--port')
|
||||||
parser.add_argument('model_path')
|
parser.add_argument('model_path')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
Klass = model_pathMap.get(args.model_path)
|
Klass = get_llm_class(args.model_path)
|
||||||
if Klass is None:
|
if Klass is None:
|
||||||
e = Exception(f'{model_path} has not mapping to a model class')
|
e = Exception(f'{model_path} has not mapping to a model class')
|
||||||
exception(f'{e}, {format_exc()}')
|
exception(f'{e}, {format_exc()}')
|
||||||
|
@ -9,6 +9,7 @@ license = {text = "MIT"}
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"torch",
|
"torch",
|
||||||
"transformers",
|
"transformers",
|
||||||
|
"mistral-common",
|
||||||
"accelerate"
|
"accelerate"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user