This commit is contained in:
yumoqing 2025-06-09 07:04:12 +00:00
parent 789bc750a2
commit 9c72a83189
5 changed files with 67 additions and 65 deletions

View File

@ -17,45 +17,6 @@ class BaseChatLLM:
key = self.get_session_key() key = self.get_session_key()
session[key] = messages session[key] = messages
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":[{"type": "text", "text": prompt}]
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content":[{"type": "text", "text": prompt}]
}
def _build_user_message(self, prompt, image_path=None,
video_path=None, audio_path=None):
contents = [
{
"type":"text", "text": prompt
}
]
if image_path:
contents.append({
"type": "image",
"image": image_path
})
if video_path:
contents.append({
"type": "video",
"video":video_path
})
if audio_path:
contents.append({
"tyoe": "audio",
"audio": audio_path
})
return {
"role": "user",
"content": contents
}
def get_streamer(self): def get_streamer(self):
return TextIteratorStreamer( return TextIteratorStreamer(
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
@ -186,3 +147,63 @@ class BaseChatLLM:
# debug(f'{d=}\n') # debug(f'{d=}\n')
yield d yield d
class T2TChatLLM(BaseChatLLM):
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":prompt
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content": prompt
}
def _build_user_message(self, prompt, **kw):
return {
"role":"user",
"content": prompt
}
class MMChatLLM(BaseChatLLM):
""" multiple modal chat LLM """
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":[{"type": "text", "text": prompt}]
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content":[{"type": "text", "text": prompt}]
}
def _build_user_message(self, prompt, image_path=None,
video_path=None, audio_path=None):
contents = [
{
"type":"text", "text": prompt
}
]
if image_path:
contents.append({
"type": "image",
"image": image_path
})
if video_path:
contents.append({
"type": "video",
"video":video_path
})
if audio_path:
contents.append({
"tyoe": "audio",
"audio": audio_path
})
return {
"role": "user",
"content": contents
}

View File

@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
from PIL import Image from PIL import Image
import requests import requests
import torch import torch
from llmengine.base_chat_llm import BaseChatLLM from llmengine.base_chat_llm import MMChatLLM
class Gemma3LLM(BaseChatLLM): class Gemma3LLM(MMChatLLM):
def __init__(self, model_id): def __init__(self, model_id):
self.model = Gemma3ForConditionalGeneration.from_pretrained( self.model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto" model_id, device_map="auto"

View File

@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image from PIL import Image
import requests import requests
import torch import torch
from llmengine.base_chat_llm import BaseChatLLM from llmengine.base_chat_llm import MMChatLLM
model_id = "google/medgemma-4b-it" model_id = "google/medgemma-4b-it"
class MedgemmaLLM(BaseChatLLM): class MedgemmaLLM(MMChatLLM):
def __init__(self, model_id): def __init__(self, model_id):
self.model = AutoModelForImageTextToText.from_pretrained( self.model = AutoModelForImageTextToText.from_pretrained(
model_id, model_id,

View File

@ -6,9 +6,9 @@ from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image from PIL import Image
import torch import torch
from llmengine.base_chat_llm import BaseChatLLM from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM
class Qwen3LLM(BaseChatLLM): class Qwen3LLM(T2TChatLLM):
def __init__(self, model_id): def __init__(self, model_id):
self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
@ -18,25 +18,6 @@ class Qwen3LLM(BaseChatLLM):
) )
self.model_id = model_id self.model_id = model_id
def _build_assistant_message(self, prompt):
return {
"role":"assistant",
"content":prompt
}
def _build_sys_message(self, prompt):
return {
"role":"system",
"content": prompt
}
def _build_user_message(self, prompt, **kw):
return {
"role":"user",
"content": prompt
}
def build_kwargs(self, inputs, streamer): def build_kwargs(self, inputs, streamer):
generate_kwargs = dict( generate_kwargs = dict(
**inputs, **inputs,

View File

@ -1,3 +1,3 @@
#!/usr/bin/bash #!/usr/bin/bash
CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 /share/vllm-0.8.5/bin/python -m llmengine.qwen3 CUDA_VISIBLE_DEVICES=2,3,4,5 /share/vllm-0.8.5/bin/python -m llmengine.qwen3