#!/share/vllm-0.8.5/bin/python
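# Gemma 3 chat wrapper: loads a local gemma-3 checkpoint with transformers,
# builds chat-template messages (text plus an optional image), and keeps the
# conversation history either in-process (_generate) or in the web session
# provided by ahserver (generate). Run directly for an interactive test loop.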
# pip install accelerate
from time import time

import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from appPublic.worker import awaitify
from ahserver.serverenv import get_serverenv


class Gemma3LLM:
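    """Chat wrapper around a local Gemma 3 checkpoint.

    Loads Gemma3ForConditionalGeneration plus its AutoProcessor and keeps a
    per-session message history (via ahserver's get_session) so that
    generate() can carry a multi-turn conversation; _generate() keeps the
    history in-process for command-line use.
    """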
    def __init__(self, model_id):
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto"
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.messages = []
        self.model_id = model_id

    def get_session_key(self):
        # Key used to store this model's chat history in the web session.
        return self.model_id + ':messages'

    async def get_session_messages(self, request):
        # get_session comes from the ahserver environment and is awaited to
        # obtain a dict-like session object for this request.
        f = get_serverenv('get_session')
        session = await f(request)
        key = self.get_session_key()
        messages = session.get(key) or []
        return messages

    async def set_session_messages(self, request, messages):
        f = get_serverenv('get_session')
        session = await f(request)
        key = self.get_session_key()
        session[key] = messages

    def _generate(self, prompt, image_path=None, sys_prompt=None):
        # Synchronous variant used by the command-line loop below; history is
        # kept on self.messages instead of a web session.
        if sys_prompt:
            sys_message = self._build_sys_message(sys_prompt)
            self.messages.append(sys_message)
        user_message = self._build_user_message(prompt, image_path=image_path)
        self.messages.append(user_message)
        data = self._gen(self.messages)
        self.messages.append(self._build_assistant_message(data['text']))
        return data

    def _build_assistant_message(self, prompt):
        return {
            "role": "assistant",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_sys_message(self, prompt):
        return {
            "role": "system",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_user_message(self, prompt, image_path=None):
        contents = [
            {
                "type": "text", "text": prompt
            }
        ]
        if image_path:
            contents.append({
                "type": "image",
                "image": image_path
            })

        return {
            "role": "user",
            "content": contents
        }

    def _gen(self, messages):
        t1 = time()
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True,
            tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generation = self.model.generate(**inputs, max_new_tokens=1000, do_sample=True)
            # Drop the prompt tokens so only the newly generated text is decoded.
            generation = generation[0][input_len:]
        decoded = self.processor.decode(generation, skip_special_tokens=True)
        t2 = time()
        return {
            "role": "assistant",
            "input_tokens": input_len,
            "output_tokens": len(generation),
            "timecost": t2 - t1,
            "text": decoded
        }

    async def generate(self, request, prompt, image_path=None, sys_prompt=None):
        messages = await self.get_session_messages(request)
        if sys_prompt and len(messages) == 0:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        # awaitify wraps the blocking _gen call so it can be awaited here.
        f = awaitify(self._gen)
        data = await f(messages)
        messages.append(self._build_assistant_message(data['text']))
        await self.set_session_messages(request, messages)
        return data
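

# Interactive smoke test: type a prompt (q to quit) and an optional image
# path; the result dict from _gen() is printed. The model path below points
# at a local copy of gemma-3-4b-it and may differ on other machines.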
if __name__ == '__main__':
    gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
    while True:
        print('input prompt')
        p = input()
        if p:
            if p == 'q':
                break
            print('input image path')
            imgpath = input()
            t = gemma3._generate(p, image_path=imgpath)
            print(t)