diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py
index c6811da..cd4b6c2 100644
--- a/llmengine/base_chat_llm.py
+++ b/llmengine/base_chat_llm.py
@@ -17,45 +17,6 @@ class BaseChatLLM:
         key = self.get_session_key()
         session[key] = messages
 
-    def _build_assistant_message(self, prompt):
-        return {
-            "role":"assistant",
-            "content":[{"type": "text", "text": prompt}]
-        }
-
-    def _build_sys_message(self, prompt):
-        return {
-            "role":"system",
-            "content":[{"type": "text", "text": prompt}]
-        }
-
-    def _build_user_message(self, prompt, image_path=None,
-                            video_path=None, audio_path=None):
-        contents = [
-            {
-                "type":"text", "text": prompt
-            }
-        ]
-        if image_path:
-            contents.append({
-                "type": "image",
-                "image": image_path
-            })
-        if video_path:
-            contents.append({
-                "type": "video",
-                "video":video_path
-            })
-        if audio_path:
-            contents.append({
-                "tyoe": "audio",
-                "audio": audio_path
-            })
-        return {
-            "role": "user",
-            "content": contents
-        }
-
     def get_streamer(self):
         return TextIteratorStreamer(
             tokenizer=self.tokenizer,
@@ -186,3 +147,63 @@ class BaseChatLLM:
             # debug(f'{d=}\n')
             yield d
 
+class T2TChatLLM(BaseChatLLM):
+    def _build_assistant_message(self, prompt):
+        return {
+            "role":"assistant",
+            "content":prompt
+        }
+
+    def _build_sys_message(self, prompt):
+        return {
+            "role":"system",
+            "content": prompt
+        }
+
+    def _build_user_message(self, prompt, **kw):
+        return {
+            "role":"user",
+            "content": prompt
+        }
+
+class MMChatLLM(BaseChatLLM):
+    """ Multimodal chat LLM """
+    def _build_assistant_message(self, prompt):
+        return {
+            "role":"assistant",
+            "content":[{"type": "text", "text": prompt}]
+        }
+
+    def _build_sys_message(self, prompt):
+        return {
+            "role":"system",
+            "content":[{"type": "text", "text": prompt}]
+        }
+
+    def _build_user_message(self, prompt, image_path=None,
+                            video_path=None, audio_path=None):
+        contents = [
+            {
+                "type":"text", "text": prompt
+            }
+        ]
+        if image_path:
+            contents.append({
+                "type": "image",
+                "image": image_path
+            })
+        if video_path:
+            contents.append({
+                "type": "video",
+                "video":video_path
+            })
+        if audio_path:
+            contents.append({
+                "type": "audio",
+                "audio": audio_path
+            })
+        return {
+            "role": "user",
+            "content": contents
+        }
+
diff --git a/llmengine/gemma3_it.py b/llmengine/gemma3_it.py
index 47c02b6..382d17b 100644
--- a/llmengine/gemma3_it.py
+++ b/llmengine/gemma3_it.py
@@ -9,9 +9,9 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import BaseChatLLM
+from llmengine.base_chat_llm import MMChatLLM
 
-class Gemma3LLM(BaseChatLLM):
+class Gemma3LLM(MMChatLLM):
     def __init__(self, model_id):
         self.model = Gemma3ForConditionalGeneration.from_pretrained(
             model_id, device_map="auto"
diff --git a/llmengine/medgemma3_it.py b/llmengine/medgemma3_it.py
index 7d8a71e..4aeec0b 100644
--- a/llmengine/medgemma3_it.py
+++ b/llmengine/medgemma3_it.py
@@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
 import requests
 import torch
-from llmengine.base_chat_llm import BaseChatLLM
+from llmengine.base_chat_llm import MMChatLLM
 
 model_id = "google/medgemma-4b-it"
 
-class MedgemmaLLM(BaseChatLLM):
+class MedgemmaLLM(MMChatLLM):
     def __init__(self, model_id):
         self.model = AutoModelForImageTextToText.from_pretrained(
             model_id,
diff --git a/llmengine/qwen3.py b/llmengine/qwen3.py
index d4ba20b..cd7a96f 100644
--- a/llmengine/qwen3.py
+++ b/llmengine/qwen3.py
@@ -6,9 +6,9 @@ from ahserver.serverenv import get_serverenv
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 import torch
-from llmengine.base_chat_llm import BaseChatLLM
+from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM
 
-class Qwen3LLM(BaseChatLLM):
+class Qwen3LLM(T2TChatLLM):
     def __init__(self, model_id):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
         self.model = AutoModelForCausalLM.from_pretrained(
@@ -18,25 +18,6 @@ class Qwen3LLM(BaseChatLLM):
         )
         self.model_id = model_id
 
-    def _build_assistant_message(self, prompt):
-        return {
-            "role":"assistant",
-            "content":prompt
-        }
-
-    def _build_sys_message(self, prompt):
-        return {
-            "role":"system",
-            "content": prompt
-        }
-
-    def _build_user_message(self, prompt, **kw):
-        return {
-            "role":"user",
-            "content": prompt
-        }
-
-
     def build_kwargs(self, inputs, streamer):
         generate_kwargs = dict(
             **inputs,
diff --git a/test/qwen3.sh b/test/qwen3.sh
index 03090be..624fca5 100755
--- a/test/qwen3.sh
+++ b/test/qwen3.sh
@@ -1,3 +1,3 @@
 #!/usr/bin/bash
-CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
+CUDA_VISIBLE_DEVICES=2,3,4,5 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
 
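For reference, below is a minimal standalone sketch of the two message shapes this refactor separates (plain Python, not importing llmengine; the function names are hypothetical, not part of the package). T2TChatLLM passes content through as a plain string, which suits text-to-text models such as Qwen3, while MMChatLLM wraps it in a list of typed parts so image/video/audio entries can be appended for multimodal models such as Gemma3 and MedGemma.

# Illustrative sketch only; these helpers mirror the subclass behavior
# and are not llmengine APIs.
def build_t2t_user_message(prompt):
    # T2TChatLLM style: content is a plain string.
    return {"role": "user", "content": prompt}

def build_mm_user_message(prompt, image_path=None):
    # MMChatLLM style: content is a list of typed parts, so media
    # entries can be appended after the text part.
    contents = [{"type": "text", "text": prompt}]
    if image_path:
        contents.append({"type": "image", "image": image_path})
    return {"role": "user", "content": contents}

print(build_t2t_user_message("hello"))
# {'role': 'user', 'content': 'hello'}
print(build_mm_user_message("describe this image", image_path="cat.png"))
# {'role': 'user', 'content': [{'type': 'text', 'text': 'describe this image'},
#                              {'type': 'image', 'image': 'cat.png'}]}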