bugfix
This commit is contained in:
parent
8d812bf42f
commit
789bc750a2
@ -1,23 +1,61 @@
|
|||||||
|
import threading
|
||||||
|
import torch
|
||||||
from time import time
|
from time import time
|
||||||
from transformers import TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
|
from appPublic.log import debug
|
||||||
|
|
||||||
class BaseChatLLM:
|
class BaseChatLLM:
|
||||||
async def get_session_key(self):
|
def get_session_key(self):
|
||||||
return self.model_id + ':messages'
|
return self.model_id + ':messages'
|
||||||
|
|
||||||
async def get_session_messages(self, request):
|
def _get_session_messages(self, session):
|
||||||
f = get_serverenv('get_session')
|
|
||||||
session = await f(request)
|
|
||||||
key = self.get_session_key()
|
key = self.get_session_key()
|
||||||
messages = session.get(key) or []
|
messages = session.get(key) or []
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def set_session_messages(self, request, messages):
|
def _set_session_messages(self, session, messages):
|
||||||
f = get_serverenv('get_session')
|
|
||||||
session = await f(request)
|
|
||||||
key = self.get_session_key()
|
key = self.get_session_key()
|
||||||
session[key] = messages
|
session[key] = messages
|
||||||
|
|
||||||
|
def _build_assistant_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"assistant",
|
||||||
|
"content":[{"type": "text", "text": prompt}]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_sys_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"system",
|
||||||
|
"content":[{"type": "text", "text": prompt}]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_user_message(self, prompt, image_path=None,
|
||||||
|
video_path=None, audio_path=None):
|
||||||
|
contents = [
|
||||||
|
{
|
||||||
|
"type":"text", "text": prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
if image_path:
|
||||||
|
contents.append({
|
||||||
|
"type": "image",
|
||||||
|
"image": image_path
|
||||||
|
})
|
||||||
|
if video_path:
|
||||||
|
contents.append({
|
||||||
|
"type": "video",
|
||||||
|
"video":video_path
|
||||||
|
})
|
||||||
|
if audio_path:
|
||||||
|
contents.append({
|
||||||
|
"tyoe": "audio",
|
||||||
|
"audio": audio_path
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"role": "user",
|
||||||
|
"content": contents
|
||||||
|
}
|
||||||
|
|
||||||
def get_streamer(self):
|
def get_streamer(self):
|
||||||
return TextIteratorStreamer(
|
return TextIteratorStreamer(
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
@ -30,22 +68,121 @@ class BaseChatLLM:
|
|||||||
t1 = time()
|
t1 = time()
|
||||||
i = 0
|
i = 0
|
||||||
for txt in streamer:
|
for txt in streamer:
|
||||||
|
if txt == '':
|
||||||
|
continue
|
||||||
if i == 0:
|
if i == 0:
|
||||||
t2 = time()
|
t2 = time()
|
||||||
i += 1
|
i += 1
|
||||||
|
all_txt += txt
|
||||||
yield {
|
yield {
|
||||||
'done': False,
|
'done': False,
|
||||||
'text': txt
|
'text': txt
|
||||||
}
|
}
|
||||||
t3 = time()
|
t3 = time()
|
||||||
unk = self.tokenizer(all_txt, return_tensors="pt")
|
t = all_txt
|
||||||
print(f'{unk=};')
|
unk = self.tokenizer(t, return_tensors="pt")
|
||||||
output_tokens = len(unk["input_ids"][0])
|
output_tokens = len(unk["input_ids"][0])
|
||||||
yield {
|
d = {
|
||||||
'done': True,
|
'done': True,
|
||||||
'text': '',
|
'text': all_txt,
|
||||||
'response_time': t2 - t1,
|
'response_time': t2 - t1,
|
||||||
'finish_time': t3 - t1,
|
'finish_time': t3 - t1,
|
||||||
'output_token': output_tokens
|
'output_token': output_tokens
|
||||||
}
|
}
|
||||||
|
# debug(f'{all_txt=}, {d=}')
|
||||||
|
yield d
|
||||||
|
|
||||||
|
def _generator(self, session, prompt,
|
||||||
|
image_path=None,
|
||||||
|
video_path=None,
|
||||||
|
audio_path=None,
|
||||||
|
sys_prompt=None):
|
||||||
|
messages = self._get_session_messages(session)
|
||||||
|
if sys_prompt:
|
||||||
|
messages.append(self._build_sys_message(sys_prompt))
|
||||||
|
messages.append(self._build_user_message(prompt, image_path=image_path))
|
||||||
|
# debug(f'{messages=}')
|
||||||
|
for d in self._gen(messages):
|
||||||
|
if d['done']:
|
||||||
|
# debug(f'++++++++++++++{d=}')
|
||||||
|
messages.append(self._build_assistant_message(d['text']))
|
||||||
|
yield d
|
||||||
|
self._set_session_messages(session, messages)
|
||||||
|
|
||||||
|
def generate(self, session, prompt,
|
||||||
|
image_path=None,
|
||||||
|
video_path=None,
|
||||||
|
audio_path=None,
|
||||||
|
sys_prompt=None):
|
||||||
|
for d in self._generator(session, prompt,
|
||||||
|
image_path=image_path,
|
||||||
|
video_path=video_path,
|
||||||
|
audio_path=audio_path,
|
||||||
|
sys_prompt=sys_prompt):
|
||||||
|
if d['done']:
|
||||||
|
return d
|
||||||
|
def stream_generate(self, session, prompt,
|
||||||
|
image_path=None,
|
||||||
|
video_path=None,
|
||||||
|
audio_path=None,
|
||||||
|
sys_prompt=None):
|
||||||
|
for d in self._generator(session, prompt,
|
||||||
|
image_path=image_path,
|
||||||
|
video_path=video_path,
|
||||||
|
audio_path=audio_path,
|
||||||
|
sys_prompt=sys_prompt):
|
||||||
|
yield d
|
||||||
|
|
||||||
|
async def async_generate(self, session, prompt,
|
||||||
|
image_path=None,
|
||||||
|
video_path=None,
|
||||||
|
audio_path=None,
|
||||||
|
sys_prompt=None):
|
||||||
|
return self.generate(session, prompt,
|
||||||
|
image_path=image_path,
|
||||||
|
video_path=video_path,
|
||||||
|
audio_path=audio_path,
|
||||||
|
sys_prompt=sys_prompt)
|
||||||
|
async def async_stream_generate(self, session, prompt,
|
||||||
|
image_path=None,
|
||||||
|
video_path=None,
|
||||||
|
audio_path=None,
|
||||||
|
sys_prompt=None):
|
||||||
|
for d in self._generator(session, prompt,
|
||||||
|
image_path=image_path,
|
||||||
|
video_path=video_path,
|
||||||
|
audio_path=audio_path,
|
||||||
|
sys_prompt=sys_prompt):
|
||||||
|
yield d
|
||||||
|
|
||||||
|
def build_kwargs(self, inputs, streamer):
|
||||||
|
generate_kwargs = dict(
|
||||||
|
**inputs,
|
||||||
|
streamer=streamer,
|
||||||
|
max_new_tokens=512,
|
||||||
|
do_sample=True,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
return generate_kwargs
|
||||||
|
|
||||||
|
def _messages2inputs(self, messages):
|
||||||
|
return self.processor.apply_chat_template(
|
||||||
|
messages, add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True, return_tensors="pt"
|
||||||
|
).to(self.model.device, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
def _gen(self, messages):
|
||||||
|
inputs = self._messages2inputs(messages)
|
||||||
|
input_len = inputs["input_ids"].shape[-1]
|
||||||
|
streamer = self.get_streamer()
|
||||||
|
kwargs = self.build_kwargs(inputs, streamer)
|
||||||
|
thread = threading.Thread(target=self.model.generate,
|
||||||
|
kwargs=kwargs)
|
||||||
|
thread.start()
|
||||||
|
for d in self.output_generator(streamer):
|
||||||
|
if d['done']:
|
||||||
|
d['input_tokens'] = input_len
|
||||||
|
# debug(f'{d=}\n')
|
||||||
|
yield d
|
||||||
|
|
||||||
|
@ -21,105 +21,9 @@ class Gemma3LLM(BaseChatLLM):
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
def _build_assistant_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"assistant",
|
|
||||||
"content":[{"type": "text", "text": prompt}]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_sys_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"system",
|
|
||||||
"content":[{"type": "text", "text": prompt}]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_user_message(self, prompt, image_path=None):
|
|
||||||
contents = [
|
|
||||||
{
|
|
||||||
"type":"text", "text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
if image_path:
|
|
||||||
contents.append({
|
|
||||||
"type": "image",
|
|
||||||
"image": image_path
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"role": "user",
|
|
||||||
"content": contents
|
|
||||||
}
|
|
||||||
|
|
||||||
def _gen(self, messages):
|
|
||||||
t1 = time()
|
|
||||||
inputs = self.processor.apply_chat_template(
|
|
||||||
messages, add_generation_prompt=True,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True, return_tensors="pt"
|
|
||||||
).to(self.model.device, dtype=torch.bfloat16)
|
|
||||||
input_len = inputs["input_ids"].shape[-1]
|
|
||||||
streamer = self.get_streamer()
|
|
||||||
generate_kwargs = dict(
|
|
||||||
**inputs,
|
|
||||||
streamer=streamer,
|
|
||||||
max_new_tokens=512,
|
|
||||||
do_sample=True,
|
|
||||||
eos_token_id=self.tokenizer.eos_token_id
|
|
||||||
)
|
|
||||||
thread = threading.Thread(target=self.model.generate,
|
|
||||||
kwargs=generate_kwargs)
|
|
||||||
thread.start()
|
|
||||||
for d in self.output_generator(streamer):
|
|
||||||
if d['done']:
|
|
||||||
d['input_tokens'] = input_len
|
|
||||||
yield d
|
|
||||||
|
|
||||||
async def generate(self, request, prompt,
|
|
||||||
image_path=None,
|
|
||||||
sys_prompt=None):
|
|
||||||
messages = self.get_session_messages(request)
|
|
||||||
if sys_prompt and len(messages) == 0:
|
|
||||||
messages.append(self._build_sys_message(sys_prompt))
|
|
||||||
messages.append(self._build_user_message(prompt, image_path=image_path))
|
|
||||||
all_txt = ''
|
|
||||||
for d in self._gen(messages):
|
|
||||||
all_txt += d['text']
|
|
||||||
d['text'] = all_txt
|
|
||||||
messages.append(self._build_assistant_message(all_txt))
|
|
||||||
self.set_session_message(request, messages)
|
|
||||||
return d
|
|
||||||
|
|
||||||
async def strem_generate(self, request, prompt,
|
|
||||||
image_path=None,
|
|
||||||
sys_prompt=None):
|
|
||||||
messages = self.get_session_messages(request)
|
|
||||||
if sys_prompt and len(messages) == 0:
|
|
||||||
messages.append(self._build_sys_message(sys_prompt))
|
|
||||||
messages.append(self._build_user_message(prompt, image_path=image_path))
|
|
||||||
all_txt = ''
|
|
||||||
for d in self._gen(messages):
|
|
||||||
yield d
|
|
||||||
all_txt += d['text']
|
|
||||||
data = await f(messages)
|
|
||||||
messages.append(self._build_assistant_message(all_txt))
|
|
||||||
self.set_session_messages(request, messages)
|
|
||||||
|
|
||||||
def _generate(self, prompt, image_path=None, sys_prompt=None):
|
|
||||||
messages = self.messages
|
|
||||||
if sys_prompt and len(messages) == 0:
|
|
||||||
messages.append(self._build_sys_message(sys_prompt))
|
|
||||||
messages.append(self._build_user_message(prompt, image_path=image_path))
|
|
||||||
all_txt = ''
|
|
||||||
ld = None
|
|
||||||
for d in self._gen(messages):
|
|
||||||
all_txt += d['text']
|
|
||||||
ld = d
|
|
||||||
ld['text'] = all_txt
|
|
||||||
messages.append(self._build_assistant_message(all_txt))
|
|
||||||
return ld
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
||||||
|
session = {}
|
||||||
while True:
|
while True:
|
||||||
print('input prompt')
|
print('input prompt')
|
||||||
p = input()
|
p = input()
|
||||||
@ -128,6 +32,11 @@ if __name__ == '__main__':
|
|||||||
break;
|
break;
|
||||||
print('input image path')
|
print('input image path')
|
||||||
imgpath=input()
|
imgpath=input()
|
||||||
t = gemma3._generate(p, image_path=imgpath)
|
for d in gemma3.stream_generate(session, p, image_path=imgpath):
|
||||||
print(t)
|
if not d['done']:
|
||||||
|
print(d['text'], end='', flush=True)
|
||||||
|
else:
|
||||||
|
x = {k:v for k,v in d.items() if k != 'text'}
|
||||||
|
print(f'\n{x}\n')
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,11 +4,11 @@ from transformers import AutoProcessor, AutoModelForImageTextToText
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from llmengine.base_chat_llm import BaseChatLLM
|
||||||
|
|
||||||
model_id = "google/medgemma-4b-it"
|
model_id = "google/medgemma-4b-it"
|
||||||
|
|
||||||
class MedgemmaLLM:
|
class MedgemmaLLM(BaseChatLLM):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
@ -16,37 +16,36 @@ class MedgemmaLLM:
|
|||||||
device_map="auto",
|
device_map="auto",
|
||||||
)
|
)
|
||||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||||
|
self.tokenizer = self.processor.tokenizer
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
# Image attribution: Stillwaterising, CC0, via Wikimedia Commons
|
def _messages2inputs(self, messages):
|
||||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
|
inputs = self.processor.apply_chat_template(
|
||||||
image = Image.open(requests.get(image_url, headers={"User-Agent": "example"}, stream=True).raw)
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt"
|
||||||
|
).to(self.model.device, dtype=torch.bfloat16)
|
||||||
|
return inputs
|
||||||
|
|
||||||
messages = [
|
if __name__ == '__main__':
|
||||||
{
|
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
|
||||||
"role": "system",
|
session = {}
|
||||||
"content": [{"type": "text", "text": "You are an expert radiologist."}]
|
while True:
|
||||||
},
|
print(f'chat with {med.model_id}')
|
||||||
{
|
print('input prompt')
|
||||||
"role": "user",
|
p = input()
|
||||||
"content": [
|
if p:
|
||||||
{"type": "text", "text": "Describe this X-ray"},
|
if p == 'q':
|
||||||
{"type": "image", "image": image}
|
break;
|
||||||
]
|
print('input image path')
|
||||||
}
|
imgpath=input()
|
||||||
]
|
for d in med.stream_generate(session, p, image_path=imgpath):
|
||||||
|
if not d['done']:
|
||||||
inputs = processor.apply_chat_template(
|
print(d['text'], end='', flush=True)
|
||||||
messages, add_generation_prompt=True, tokenize=True,
|
else:
|
||||||
return_dict=True, return_tensors="pt"
|
x = {k:v for k,v in d.items() if k != 'text'}
|
||||||
).to(model.device, dtype=torch.bfloat16)
|
print(f'\n{x}\n')
|
||||||
|
|
||||||
input_len = inputs["input_ids"].shape[-1]
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
|
||||||
generation = generation[0][input_len:]
|
|
||||||
|
|
||||||
decoded = processor.decode(generation, skip_special_tokens=True)
|
|
||||||
print(decoded)
|
|
||||||
|
|
||||||
|
75
llmengine/qwen3.py
Normal file
75
llmengine/qwen3.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
#!/share/vllm-0.8.5/bin/python
|
||||||
|
|
||||||
|
# pip install accelerate
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from ahserver.serverenv import get_serverenv
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from llmengine.base_chat_llm import BaseChatLLM
|
||||||
|
|
||||||
|
class Qwen3LLM(BaseChatLLM):
|
||||||
|
def __init__(self, model_id):
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map="auto"
|
||||||
|
)
|
||||||
|
self.model_id = model_id
|
||||||
|
|
||||||
|
def _build_assistant_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"assistant",
|
||||||
|
"content":prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_sys_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"system",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_user_message(self, prompt, **kw):
|
||||||
|
return {
|
||||||
|
"role":"user",
|
||||||
|
"content": prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_kwargs(self, inputs, streamer):
|
||||||
|
generate_kwargs = dict(
|
||||||
|
**inputs,
|
||||||
|
streamer=streamer,
|
||||||
|
max_new_tokens=32768,
|
||||||
|
do_sample=True,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
return generate_kwargs
|
||||||
|
|
||||||
|
def _messages2inputs(self, messages):
|
||||||
|
text = self.tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True
|
||||||
|
)
|
||||||
|
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
|
||||||
|
session = {}
|
||||||
|
while True:
|
||||||
|
print('input prompt')
|
||||||
|
p = input()
|
||||||
|
if p:
|
||||||
|
if p == 'q':
|
||||||
|
break;
|
||||||
|
for d in q3.stream_generate(session, p):
|
||||||
|
if not d['done']:
|
||||||
|
print(d['text'], end='', flush=True)
|
||||||
|
else:
|
||||||
|
x = {k:v for k,v in d.items() if k != 'text'}
|
||||||
|
print(f'\n{x}\n')
|
||||||
|
|
||||||
|
|
0
test/ds-r1-8b
Executable file
0
test/ds-r1-8b
Executable file
119
test/gemma-3-4b-it
Executable file
119
test/gemma-3-4b-it
Executable file
@ -0,0 +1,119 @@
|
|||||||
|
#!/share/vllm-0.8.5/bin/python
|
||||||
|
|
||||||
|
# pip install accelerate
|
||||||
|
import torch
|
||||||
|
lfrom time import time
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from ahserver.serverenv import get_serverenv
|
||||||
|
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class Gemma3LLM:
|
||||||
|
def __init__(self, model_id):
|
||||||
|
self.model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||||
|
model_id, device_map="auto"
|
||||||
|
).eval()
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||||
|
self.messages = []
|
||||||
|
self.model_id = model_id
|
||||||
|
|
||||||
|
async def get_session_key(self):
|
||||||
|
return self.model_id + ':messages'
|
||||||
|
|
||||||
|
async def get_session_messages(self, request):
|
||||||
|
f = get_serverenv('get_session')
|
||||||
|
session = await f(request)
|
||||||
|
key = self.get_session_key()
|
||||||
|
messages = session.get(key) or []
|
||||||
|
return messages
|
||||||
|
|
||||||
|
async def set_session_messages(self, request, messages):
|
||||||
|
f = get_serverenv('get_session')
|
||||||
|
session = await f(request)
|
||||||
|
key = self.get_session_key()
|
||||||
|
await session[key] = messages
|
||||||
|
|
||||||
|
def _generate(self, request, prompt, image_path=None, sys_prompt=None):
|
||||||
|
if sys_prompt:
|
||||||
|
sys_message = self._build_sys_message(sys_prompt)
|
||||||
|
self.messages.append(sys_message)
|
||||||
|
user_message = self._build_user_message(prompt, image_path=image_path)
|
||||||
|
self.messages.append(user_message)
|
||||||
|
data = self._gen(self.messages)
|
||||||
|
self.messages.append(self._build_assistant_message(data['text']))
|
||||||
|
|
||||||
|
def _build_assistant_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"assistant",
|
||||||
|
"content":[{"type": "text", "text": prompt}]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_sys_message(self, prompt):
|
||||||
|
return {
|
||||||
|
"role":"system",
|
||||||
|
"content":[{"type": "text", "text": prompt}]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_user_message(self, prompt, image_path=None):
|
||||||
|
contents = [
|
||||||
|
{
|
||||||
|
"type":"text", "text": prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
if image_path:
|
||||||
|
contents.append({
|
||||||
|
"type": "image",
|
||||||
|
"image": image_path
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "user",
|
||||||
|
"content": contents
|
||||||
|
}
|
||||||
|
|
||||||
|
def _gen(self, messages):
|
||||||
|
t1 = time()
|
||||||
|
inputs = self.processor.apply_chat_template(
|
||||||
|
messages, add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True, return_tensors="pt"
|
||||||
|
).to(self.model.device, dtype=torch.bfloat16)
|
||||||
|
input_len = inputs["input_ids"].shape[-1]
|
||||||
|
with torch.inference_mode():
|
||||||
|
generation = self.model.generate(**inputs, max_new_tokens=1000, do_sample=True)
|
||||||
|
generation = generation[0][input_len:]
|
||||||
|
decoded = self.processor.decode(generation, skip_special_tokens=True)
|
||||||
|
t2 = time()
|
||||||
|
return {
|
||||||
|
"role": "assistant",
|
||||||
|
"input_tokens": input_len,
|
||||||
|
"output_token": len(generation),
|
||||||
|
"timecost": t2 - t1,
|
||||||
|
"text": decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
async def generate(self, request, prompt, image_path=None, sys_prompt=None):
|
||||||
|
messages = self.get_session_messages(request)
|
||||||
|
if sys_prompt and len(messages) == 0:
|
||||||
|
messages.append(self._build_sys_message(sys_prompt))
|
||||||
|
messages.append(self._build_user_message(prompt, image_path=image_path))
|
||||||
|
f = awaitify(self._gen)
|
||||||
|
data = await f(messages)
|
||||||
|
messages.append(self._build_assistant_message(data['text']))
|
||||||
|
self.set_session_message(request, messages)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
||||||
|
while True:
|
||||||
|
print('input prompt')
|
||||||
|
p = input()
|
||||||
|
if p:
|
||||||
|
if p == 'q':
|
||||||
|
break;
|
||||||
|
print('input image path')
|
||||||
|
imgpath=input()
|
||||||
|
t = gemma3._generate(p, image_path=imgpath)
|
||||||
|
print(t)
|
||||||
|
|
3
test/gemma3.sh
Executable file
3
test/gemma3.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/usr/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=1 /share/vllm-0.8.5/bin/python -m llmengine.gemma3_it
|
3
test/medgemma3.sh
Executable file
3
test/medgemma3.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/usr/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 /share/vllm-0.8.5/bin/python -m llmengine.medgemma3_it
|
3
test/qwen3.sh
Executable file
3
test/qwen3.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/usr/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
|
Loading…
Reference in New Issue
Block a user