bugfix
This commit is contained in:
parent
789bc750a2
commit
9c72a83189
@ -17,45 +17,6 @@ class BaseChatLLM:
|
|||||||
key = self.get_session_key()
|
key = self.get_session_key()
|
||||||
session[key] = messages
|
session[key] = messages
|
||||||
|
|
||||||
def _build_assistant_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"assistant",
|
|
||||||
"content":[{"type": "text", "text": prompt}]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_sys_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"system",
|
|
||||||
"content":[{"type": "text", "text": prompt}]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_user_message(self, prompt, image_path=None,
|
|
||||||
video_path=None, audio_path=None):
|
|
||||||
contents = [
|
|
||||||
{
|
|
||||||
"type":"text", "text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
if image_path:
|
|
||||||
contents.append({
|
|
||||||
"type": "image",
|
|
||||||
"image": image_path
|
|
||||||
})
|
|
||||||
if video_path:
|
|
||||||
contents.append({
|
|
||||||
"type": "video",
|
|
||||||
"video":video_path
|
|
||||||
})
|
|
||||||
if audio_path:
|
|
||||||
contents.append({
|
|
||||||
"tyoe": "audio",
|
|
||||||
"audio": audio_path
|
|
||||||
})
|
|
||||||
return {
|
|
||||||
"role": "user",
|
|
||||||
"content": contents
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_streamer(self):
|
def get_streamer(self):
|
||||||
return TextIteratorStreamer(
|
return TextIteratorStreamer(
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
@ -186,3 +147,63 @@ class BaseChatLLM:
|
|||||||
# debug(f'{d=}\n')
|
# debug(f'{d=}\n')
|
||||||
yield d
|
yield d
|
||||||
|
|
||||||
|
class T2TChatLLM(BaseChatLLM):
|
||||||
|
def _build_assistant_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"assistant",
|
||||||
|
"content":prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_sys_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"system",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_user_message(self, prompt, **kw):
|
||||||
|
return {
|
||||||
|
"role":"user",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
class MMChatLLM(BaseChatLLM):
|
||||||
|
""" multiple modal chat LLM """
|
||||||
|
def _build_assistant_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"assistant",
|
||||||
|
"content":[{"type": "text", "text": prompt}]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_sys_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"system",
|
||||||
|
"content":[{"type": "text", "text": prompt}]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_user_message(self, prompt, image_path=None,
|
||||||
|
video_path=None, audio_path=None):
|
||||||
|
contents = [
|
||||||
|
{
|
||||||
|
"type":"text", "text": prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
if image_path:
|
||||||
|
contents.append({
|
||||||
|
"type": "image",
|
||||||
|
"image": image_path
|
||||||
|
})
|
||||||
|
if video_path:
|
||||||
|
contents.append({
|
||||||
|
"type": "video",
|
||||||
|
"video":video_path
|
||||||
|
})
|
||||||
|
if audio_path:
|
||||||
|
contents.append({
|
||||||
|
"tyoe": "audio",
|
||||||
|
"audio": audio_path
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"role": "user",
|
||||||
|
"content": contents
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from llmengine.base_chat_llm import BaseChatLLM
|
from llmengine.base_chat_llm import MMChatLLM
|
||||||
|
|
||||||
class Gemma3LLM(BaseChatLLM):
|
class Gemma3LLM(MMChatLLM):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
self.model = Gemma3ForConditionalGeneration.from_pretrained(
|
self.model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||||
model_id, device_map="auto"
|
model_id, device_map="auto"
|
||||||
|
@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from llmengine.base_chat_llm import BaseChatLLM
|
from llmengine.base_chat_llm import MMChatLLM
|
||||||
|
|
||||||
model_id = "google/medgemma-4b-it"
|
model_id = "google/medgemma-4b-it"
|
||||||
|
|
||||||
class MedgemmaLLM(BaseChatLLM):
|
class MedgemmaLLM(MMChatLLM):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -6,9 +6,9 @@ from ahserver.serverenv import get_serverenv
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
from llmengine.base_chat_llm import BaseChatLLM
|
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM
|
||||||
|
|
||||||
class Qwen3LLM(BaseChatLLM):
|
class Qwen3LLM(T2TChatLLM):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
@ -18,25 +18,6 @@ class Qwen3LLM(BaseChatLLM):
|
|||||||
)
|
)
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
def _build_assistant_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"assistant",
|
|
||||||
"content":prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_sys_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"system",
|
|
||||||
"content": prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_user_message(self, prompt, **kw):
|
|
||||||
return {
|
|
||||||
"role":"user",
|
|
||||||
"content": prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
def build_kwargs(self, inputs, streamer):
|
||||||
generate_kwargs = dict(
|
generate_kwargs = dict(
|
||||||
**inputs,
|
**inputs,
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
#!/usr/bin/bash
|
#!/usr/bin/bash
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
|
CUDA_VISIBLE_DEVICES=2,3,4,5 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
|
||||||
|
Loading…
Reference in New Issue
Block a user