llmengine/test/gemma-3-4b-it
2025-06-06 08:48:37 +00:00

120 lines
3.3 KiB
Python
Executable File

#!/share/vllm-0.8.5/bin/python
# pip install accelerate
from time import time

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from ahserver.serverenv import get_serverenv
from appPublic.worker import awaitify
class Gemma3LLM:
    """Chat wrapper around a local Gemma 3 checkpoint (text + optional image).

    Two conversation modes are provided:

    * ``_generate`` keeps the history in ``self.messages`` (one conversation
      shared by the whole instance).
    * ``generate`` keeps the history in the caller's web session under the
      key returned by ``get_session_key()``, so each session has its own
      conversation.
    """

    def __init__(self, model_id):
        """Load the model and processor from ``model_id`` (hub id or local path)."""
        # device_map="auto" delegates weight placement to accelerate; eval()
        # switches the module to inference mode.
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto"
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)
        # History for the instance-level ``_generate`` path.
        self.messages = []
        self.model_id = model_id

    async def get_session_key(self):
        """Return the session key under which this model's history is stored."""
        return self.model_id + ':messages'

    async def get_session_messages(self, request):
        """Return the message history stored in ``request``'s session.

        Returns an empty list when nothing is stored yet. The result is a
        copy, so mutating it does not touch the stored history until
        ``set_session_messages`` is called.
        """
        # NOTE(review): assumes get_serverenv('get_session') yields an async
        # callable resolving to a dict-like session object — confirm.
        get_session = get_serverenv('get_session')
        session = await get_session(request)
        key = await self.get_session_key()
        # Copy in case the session hands back its stored list object; a
        # failed generation then cannot leave a half-finished turn behind.
        return list(session.get(key) or [])

    async def set_session_messages(self, request, messages):
        """Store ``messages`` as the history in ``request``'s session."""
        get_session = get_serverenv('get_session')
        session = await get_session(request)
        key = await self.get_session_key()
        session[key] = messages

    def _generate(self, request, prompt, image_path=None, sys_prompt=None):
        """Run one blocking turn against the instance history ``self.messages``.

        ``request`` is accepted for signature parity with ``generate`` and is
        not used. ``sys_prompt`` is only added when the history is empty, so
        repeated calls do not stack several system messages.

        Returns the result dict produced by ``_gen``.
        """
        if sys_prompt and not self.messages:
            self.messages.append(self._build_sys_message(sys_prompt))
        self.messages.append(self._build_user_message(prompt, image_path=image_path))
        data = self._gen(self.messages)
        self.messages.append(self._build_assistant_message(data['text']))
        return data

    def _build_assistant_message(self, prompt):
        """Wrap ``prompt`` as an assistant chat message with one text part."""
        return {
            "role": "assistant",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_sys_message(self, prompt):
        """Wrap ``prompt`` as a system chat message with one text part."""
        return {
            "role": "system",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_user_message(self, prompt, image_path=None):
        """Build a user chat message: a text part, plus an image part if given.

        A falsy ``image_path`` (``None`` or ``''``) yields a text-only message.
        """
        contents = [
            {
                "type": "text", "text": prompt
            }
        ]
        if image_path:
            contents.append({
                "type": "image",
                "image": image_path
            })
        return {
            "role": "user",
            "content": contents
        }

    def _gen(self, messages):
        """Generate one assistant reply for ``messages`` (blocking).

        Returns a dict with keys ``role``, ``input_tokens`` (prompt length),
        ``output_token`` (number of generated tokens), ``timecost`` (seconds)
        and ``text`` (the decoded reply).
        """
        t1 = time()
        # NOTE(review): inputs are moved to bfloat16 while the model is loaded
        # with from_pretrained's default dtype — confirm the two agree.
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True,
            tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generation = self.model.generate(**inputs, max_new_tokens=1000, do_sample=True)
            # Drop the echoed prompt: keep only the new tokens of sequence 0.
            generation = generation[0][input_len:]
        decoded = self.processor.decode(generation, skip_special_tokens=True)
        t2 = time()
        return {
            "role": "assistant",
            "input_tokens": input_len,
            "output_token": len(generation),
            "timecost": t2 - t1,
            "text": decoded
        }

    async def generate(self, request, prompt, image_path=None, sys_prompt=None):
        """Run one turn against the history stored in ``request``'s session.

        Loads the history, adds ``sys_prompt`` (only for a new conversation)
        and the user turn, runs ``_gen`` through ``awaitify``, then saves the
        history including the assistant reply.

        Returns the result dict produced by ``_gen``.
        """
        messages = await self.get_session_messages(request)
        if sys_prompt and not messages:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        # NOTE(review): awaitify presumably runs the blocking _gen off the
        # event loop — confirm in appPublic.worker.
        gen = awaitify(self._gen)
        data = await gen(messages)
        messages.append(self._build_assistant_message(data['text']))
        await self.set_session_messages(request, messages)
        return data
if __name__ == '__main__':
    # Interactive loop for manual testing: enter a prompt ('q' quits), then an
    # optional image path (leave it empty for a text-only turn).
    gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
    while True:
        print('input prompt')
        p = input()
        if not p:
            continue
        if p == 'q':
            break
        print('input image path')
        imgpath = input()
        # _generate's first parameter is the (unused) request, so pass None;
        # previously the prompt was bound to `request` and `prompt` was missing.
        t = gemma3._generate(None, p, image_path=imgpath)
        print(t)