llmengine/llmengine/medgemma3_it.py
2025-06-09 07:04:12 +00:00

52 lines
1.3 KiB
Python

# pip install accelerate
import time
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM
# Hugging Face Hub id of the MedGemma 4B checkpoint. MedgemmaLLM passes its
# model_id straight to `from_pretrained`, so a local checkpoint directory
# works there as well.
model_id = "google/medgemma-4b-it"
class MedgemmaLLM(MMChatLLM):
    """Image + text chat wrapper around a MedGemma checkpoint.

    Loads the model and its processor with Hugging Face ``transformers`` and
    turns chat messages into model-ready tensors. Generation methods such as
    ``stream_generate`` are not defined here; they come from ``MMChatLLM``.
    """

    def __init__(self, model_id=model_id, *, torch_dtype=torch.bfloat16,
                 device_map="auto"):
        """Load the model weights and the matching processor.

        Args:
            model_id: Hugging Face Hub id or local checkpoint directory.
                Defaults to the module-level ``model_id``.
            torch_dtype: dtype the weights are loaded in (default bfloat16).
                ``_messages2inputs`` casts floating-point inputs to the same
                dtype, so the two always agree.
            device_map: forwarded to ``from_pretrained``. ``"auto"`` relies on
                the ``accelerate`` package for weight placement.
        """
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        # Exposed separately for callers that only need text tokenization.
        self.tokenizer = self.processor.tokenizer
        self.model_id = model_id
        # NOTE(review): MMChatLLM.__init__ is never called here; confirm the
        # base class does not rely on its own initialisation.

    def _messages2inputs(self, messages):
        """Render chat ``messages`` into tensors ready for ``model.generate``.

        Args:
            messages: chat history in the format accepted by
                ``processor.apply_chat_template``. Presumably a list of
                role/content dicts that may include images; verify against
                the caller in MMChatLLM.

        Returns:
            The processor's tokenized batch, moved to the model's device.
        """
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,  # end with the assistant-turn prefix
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            # BatchFeature.to casts only floating-point tensors (image
            # features) to the dtype; integer token ids are just moved. Use
            # the model's actual dtype instead of a second hard-coded one.
        ).to(self.model.device, dtype=self.model.dtype)
        return inputs
if __name__ == '__main__':
    # Interactive console chat: read a prompt and an image path each turn,
    # then stream the model's reply. Enter 'q' (or close stdin) to quit.
    import sys

    # Optional first CLI argument overrides the default local checkpoint.
    model_path = (sys.argv[1] if len(sys.argv) > 1
                  else '/share/models/google/medgemma-4b-it')
    med = MedgemmaLLM(model_path)
    # Passed to every stream_generate call; presumably carries conversation
    # state across turns inside MMChatLLM. TODO confirm.
    session = {}
    while True:
        print(f'chat with {med.model_id}')
        print('input prompt')
        try:
            p = input()
        except EOFError:
            # stdin closed (Ctrl-D / piped input exhausted): exit like 'q'.
            break
        if not p:
            # Empty line: ask again.
            continue
        if p == 'q':
            break
        print('input image path')
        try:
            imgpath = input()
        except EOFError:
            break
        for d in med.stream_generate(session, p, image_path=imgpath):
            if not d['done']:
                # Partial chunk: echo text immediately without a newline.
                print(d['text'], end='', flush=True)
            else:
                # Final chunk: show its remaining fields, excluding 'text'.
                x = {k: v for k, v in d.items() if k != 'text'}
                print(f'\n{x}\n')