#!/share/vllm-0.8.5/bin/python
# pip install accelerate
import torch
from time import time
from appPublic.worker import awaitify
from ahserver.serverenv import get_serverenv
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch


class Gemma3LLM:
    """Chat wrapper around a local Gemma 3 multimodal checkpoint.

    Two conversation modes are provided:

    * ``_generate``: synchronous. Keeps a single in-process history in
      ``self.messages``. It is used by the interactive CLI at the bottom of
      this file.
    * ``generate``: a coroutine. Keeps a per-session history in the server
      session returned by ``get_serverenv('get_session')``.
    """

    def __init__(self, model_id):
        """Load the model and its processor from *model_id*.

        Args:
            model_id: Path or identifier passed to ``from_pretrained``. It is
                also used to build the session key for stored history.
        """
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto"
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)
        # History used only by the synchronous ``_generate`` path.
        self.messages = []
        self.model_id = model_id

    async def get_session_key(self):
        """Return the session key under which this model's history is stored."""
        return self.model_id + ':messages'

    async def get_session_messages(self, request):
        """Return the stored chat history for *request*, or ``[]`` if none.

        The session object is treated as dict-like (``.get``). This assumes
        the callable registered as ``'get_session'`` returns such a mapping;
        TODO: confirm.
        """
        f = get_serverenv('get_session')
        session = await f(request)
        # get_session_key is a coroutine function and must be awaited;
        # otherwise the key would be a coroutine object.
        key = await self.get_session_key()
        messages = session.get(key) or []
        return messages

    async def set_session_messages(self, request, messages):
        """Store *messages* as the chat history for *request*'s session.

        NOTE(review): whether the session is persisted automatically after
        item assignment depends on the session backend. Verify this.
        """
        f = get_serverenv('get_session')
        session = await f(request)
        key = await self.get_session_key()
        # Plain item assignment; this is not an awaitable expression.
        session[key] = messages

    def _generate(self, request, prompt, image_path=None, sys_prompt=None):
        """Run one chat turn against the in-process history ``self.messages``.

        Args:
            request: Unused. Kept so the signature mirrors ``generate``.
            prompt: User text for this turn.
            image_path: Optional image reference attached to the user turn.
                Falsy values such as ``''`` mean no image.
            sys_prompt: Optional system prompt. It is added only when the
                history is empty, because the chat template expects at most
                one leading system message.

        Returns:
            The result dict produced by ``_gen``.
        """
        if sys_prompt and not self.messages:
            self.messages.append(self._build_sys_message(sys_prompt))
        self.messages.append(
            self._build_user_message(prompt, image_path=image_path)
        )
        data = self._gen(self.messages)
        self.messages.append(self._build_assistant_message(data['text']))
        return data

    def _build_assistant_message(self, prompt):
        """Wrap *prompt* as an assistant chat message with a single text part."""
        return {
            "role": "assistant",
            "content": [{"type": "text", "text": prompt}],
        }

    def _build_sys_message(self, prompt):
        """Wrap *prompt* as a system chat message with a single text part."""
        return {
            "role": "system",
            "content": [{"type": "text", "text": prompt}],
        }

    def _build_user_message(self, prompt, image_path=None):
        """Build a user chat message containing text plus an optional image part."""
        contents = [
            {
                "type": "text",
                "text": prompt,
            }
        ]
        if image_path:
            contents.append({
                "type": "image",
                "image": image_path,
            })
        return {
            "role": "user",
            "content": contents,
        }

    def _gen(self, messages, max_new_tokens=1000):
        """Generate a reply for the chat *messages* and return stats.

        Args:
            messages: Chat messages in the shape built by the
                ``_build_*_message`` helpers.
            max_new_tokens: Upper bound on generated tokens. The default of
                1000 matches the previous hard-coded value.

        Returns:
            A dict with these keys:
            ``role``: always ``"assistant"``.
            ``input_tokens``: prompt length in tokens.
            ``output_token``: number of generated tokens (the key name is
            kept as-is for existing consumers).
            ``timecost``: wall-clock seconds.
            ``text``: the decoded reply.
        """
        t1 = time()
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generation = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, do_sample=True
            )
            # Keep only the newly generated tokens of the first sequence.
            generation = generation[0][input_len:]
        decoded = self.processor.decode(generation, skip_special_tokens=True)
        t2 = time()
        return {
            "role": "assistant",
            "input_tokens": input_len,
            "output_token": len(generation),
            "timecost": t2 - t1,
            "text": decoded,
        }

    async def generate(self, request, prompt, image_path=None, sys_prompt=None):
        """Run one chat turn using the history stored in *request*'s session.

        Args:
            request: Server request used to look up the session.
            prompt: User text for this turn.
            image_path: Optional image reference attached to the user turn.
            sys_prompt: Optional system prompt, added only for a new
                (empty) history.

        Returns:
            The result dict produced by ``_gen``.
        """
        messages = await self.get_session_messages(request)
        if sys_prompt and not messages:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        # awaitify wraps the synchronous _gen in an awaitable. Presumably it
        # offloads the call so the event loop is not blocked; TODO: confirm.
        f = awaitify(self._gen)
        data = await f(messages)
        messages.append(self._build_assistant_message(data['text']))
        await self.set_session_messages(request, messages)
        return data


if __name__ == '__main__':
    gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
    while True:
        print('input prompt')
        p = input()
        if not p:
            continue
        if p == 'q':
            break
        print('input image path')
        imgpath = input()
        # There is no HTTP request in the CLI; pass None for the unused argument.
        t = gemma3._generate(None, p, image_path=imgpath)
        print(t)