bugfix: deduplicate the chat-message builders. _build_assistant_message, _build_sys_message, and _build_user_message move out of BaseChatLLM into two new subclasses in llmengine.base_chat_llm: T2TChatLLM (text-to-text, plain-string content) and MMChatLLM (multimodal, list-of-typed-parts content). Gemma3LLM and MedgemmaLLM now extend MMChatLLM, Qwen3LLM extends T2TChatLLM and drops its local copies of the builders, and the qwen3 launch script runs on GPUs 2-5 instead of 2-7.
This commit is contained in:
parent 789bc750a2
commit 9c72a83189
@@ -17,45 +17,6 @@ class BaseChatLLM:
         key = self.get_session_key()
         session[key] = messages
 
-    def _build_assistant_message(self, prompt):
-        return {
-            "role":"assistant",
-            "content":[{"type": "text", "text": prompt}]
-        }
-
-    def _build_sys_message(self, prompt):
-        return {
-            "role":"system",
-            "content":[{"type": "text", "text": prompt}]
-        }
-
-    def _build_user_message(self, prompt, image_path=None,
-                            video_path=None, audio_path=None):
-        contents = [
-            {
-                "type":"text", "text": prompt
-            }
-        ]
-        if image_path:
-            contents.append({
-                "type": "image",
-                "image": image_path
-            })
-        if video_path:
-            contents.append({
-                "type": "video",
-                "video":video_path
-            })
-        if audio_path:
-            contents.append({
-                "tyoe": "audio",
-                "audio": audio_path
-            })
-        return {
-            "role": "user",
-            "content": contents
-        }
-
     def get_streamer(self):
         return TextIteratorStreamer(
             tokenizer=self.tokenizer,
@@ -186,3 +147,63 @@ class BaseChatLLM:
             # debug(f'{d=}\n')
             yield d
 
+class T2TChatLLM(BaseChatLLM):
+    def _build_assistant_message(self, prompt):
+        return {
+            "role": "assistant",
+            "content": prompt
+        }
+
+    def _build_sys_message(self, prompt):
+        return {
+            "role": "system",
+            "content": prompt
+        }
+
+    def _build_user_message(self, prompt, **kw):
+        return {
+            "role": "user",
+            "content": prompt
+        }
+
+class MMChatLLM(BaseChatLLM):
+    """ multimodal chat LLM: message content is a list of typed parts """
+    def _build_assistant_message(self, prompt):
+        return {
+            "role": "assistant",
+            "content": [{"type": "text", "text": prompt}]
+        }
+
+    def _build_sys_message(self, prompt):
+        return {
+            "role": "system",
+            "content": [{"type": "text", "text": prompt}]
+        }
+
+    def _build_user_message(self, prompt, image_path=None,
+                            video_path=None, audio_path=None):
+        contents = [
+            {
+                "type": "text", "text": prompt
+            }
+        ]
+        if image_path:
+            contents.append({
+                "type": "image",
+                "image": image_path
+            })
+        if video_path:
+            contents.append({
+                "type": "video",
+                "video": video_path
+            })
+        if audio_path:
+            contents.append({
+                "type": "audio",
+                "audio": audio_path
+            })
+        return {
+            "role": "user",
+            "content": contents
+        }
+
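To make the new split concrete, here is a standalone illustration (not part of the diff) of the two content schemas the subclasses produce: T2TChatLLM keeps content as a plain string, while MMChatLLM wraps everything in a list of typed parts, the shape Hugging Face multimodal chat templates expect. All values are hypothetical.

# Standalone sketch of the two message schemas built by the new subclasses.
t2t_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

mm_messages = [
    {"role": "system",
     "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user",
     "content": [
         {"type": "text", "text": "Describe this picture."},
         # media parts are appended only when image_path/video_path/audio_path are given
         {"type": "image", "image": "./cat.png"},
     ]},
]
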
@@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import BaseChatLLM
+from llmengine.base_chat_llm import MMChatLLM
 
-class Gemma3LLM(BaseChatLLM):
+class Gemma3LLM(MMChatLLM):
     def __init__(self, model_id):
         self.model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, device_map="auto"
@@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import BaseChatLLM
+from llmengine.base_chat_llm import MMChatLLM
 
 model_id = "google/medgemma-4b-it"
 
-class MedgemmaLLM(BaseChatLLM):
+class MedgemmaLLM(MMChatLLM):
     def __init__(self, model_id):
         self.model = AutoModelForImageTextToText.from_pretrained(
             model_id,
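Both engines now inherit the list-style builders, whose output feeds directly into a Hugging Face processor's chat template. A minimal sketch of that hand-off, assuming the gemma-3-4b-it checkpoint (the diff does not name the exact variant) and a hypothetical local image:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")  # assumed variant
messages = [
    {"role": "user", "content": [
        {"type": "text", "text": "Describe this picture."},
        {"type": "image", "image": "cat.png"},  # hypothetical file
    ]},
]
# Renders the chat template and tokenizes text + image in one call.
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True,
    tokenize=True, return_dict=True, return_tensors="pt",
)
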
@@ -6,9 +6,9 @@ from ahserver.serverenv import get_serverenv
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 import torch
-from llmengine.base_chat_llm import BaseChatLLM
+from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM
 
-class Qwen3LLM(BaseChatLLM):
+class Qwen3LLM(T2TChatLLM):
     def __init__(self, model_id):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.model = AutoModelForCausalLM.from_pretrained(
@@ -18,25 +18,6 @@ class Qwen3LLM(BaseChatLLM):
         )
         self.model_id = model_id
 
-    def _build_assistant_message(self, prompt):
-        return {
-            "role":"assistant",
-            "content":prompt
-        }
-
-    def _build_sys_message(self, prompt):
-        return {
-            "role":"system",
-            "content": prompt
-        }
-
-    def _build_user_message(self, prompt, **kw):
-        return {
-            "role":"user",
-            "content": prompt
-        }
-
-
     def build_kwargs(self, inputs, streamer):
         generate_kwargs = dict(
             **inputs,
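Because the deleted Qwen3LLM builders were identical to the ones T2TChatLLM now provides, behavior should be unchanged; the methods simply resolve one step up the MRO. A hypothetical sanity check, assuming the llmengine package is importable and llmengine.qwen3 defines the class without import-time side effects:

from llmengine.base_chat_llm import T2TChatLLM
from llmengine.qwen3 import Qwen3LLM

# The subclass no longer shadows the base-class builders.
assert Qwen3LLM._build_user_message is T2TChatLLM._build_user_message
print(Qwen3LLM.__mro__)  # Qwen3LLM -> T2TChatLLM -> BaseChatLLM -> object
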
@@ -1,3 +1,3 @@
 #!/usr/bin/bash
 
-CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
+CUDA_VISIBLE_DEVICES=2,3,4,5 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
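The launch script drops from six GPUs to four. CUDA_VISIBLE_DEVICES remaps the listed physical devices to logical indices 0-3 inside the process, so frameworks that auto-shard the model see exactly those four cards. The equivalent from Python, for reference (the variable must be set before CUDA is initialized):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5"  # physical GPUs 2-5 become cuda:0..cuda:3

import torch
print(torch.cuda.device_count())  # 4, assuming the host exposes those GPUs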