From 8d812bf42f3f0d43965bff7385f29a30c2273b5b Mon Sep 17 00:00:00 2001 From: yumoqing Date: Thu, 5 Jun 2025 15:15:29 +0000 Subject: [PATCH] bugfix --- llmengine/base_chat_llm.py | 51 ++++++++++++++ llmengine/chatllm.py | 2 +- llmengine/gemma3_it.py | 133 +++++++++++++++++++++++++++++++++++++ llmengine/medgemma3_it.py | 52 +++++++++++++++ test/chatllm | 2 +- 5 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 llmengine/base_chat_llm.py create mode 100644 llmengine/gemma3_it.py create mode 100644 llmengine/medgemma3_it.py diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py new file mode 100644 index 0000000..9d8f0c1 --- /dev/null +++ b/llmengine/base_chat_llm.py @@ -0,0 +1,51 @@ +from time import time +from transformers import TextIteratorStreamer + +class BaseChatLLM: + async def get_session_key(self): + return self.model_id + ':messages' + + async def get_session_messages(self, request): + f = get_serverenv('get_session') + session = await f(request) + key = self.get_session_key() + messages = session.get(key) or [] + return messages + + async def set_session_messages(self, request, messages): + f = get_serverenv('get_session') + session = await f(request) + key = self.get_session_key() + session[key] = messages + + def get_streamer(self): + return TextIteratorStreamer( + tokenizer=self.tokenizer, + skip_special_tokens=True, + skip_prompt=True + ) + + def output_generator(self, streamer): + all_txt = '' + t1 = time() + i = 0 + for txt in streamer: + if i == 0: + t2 = time() + i += 1 + yield { + 'done': False, + 'text': txt + } + t3 = time() + unk = self.tokenizer(all_txt, return_tensors="pt") + print(f'{unk=};') + output_tokens = len(unk["input_ids"][0]) + yield { + 'done': True, + 'text': '', + 'response_time': t2 - t1, + 'finish_time': t3 - t1, + 'output_token': output_tokens + } + diff --git a/llmengine/chatllm.py b/llmengine/chatllm.py index 3aaf3c6..1d6f683 100644 --- a/llmengine/chatllm.py +++ b/llmengine/chatllm.py @@ -106,7 +106,7 @@ class TransformersChatEngine: if not self.output_json: return text input_tokens = inputs["input_ids"].shape[1] - outputi_ids.sequences.shape[1] - input_tokens + output_tokens = len(self.tokenizer(text, return_tensors="pt")["input_ids"][0]) return { 'content':text, 'input_tokens': input_tokens, diff --git a/llmengine/gemma3_it.py b/llmengine/gemma3_it.py new file mode 100644 index 0000000..1dbb2dd --- /dev/null +++ b/llmengine/gemma3_it.py @@ -0,0 +1,133 @@ +#!/share/vllm-0.8.5/bin/python + +# pip install accelerate +import threading +from time import time +from appPublic.worker import awaitify +from ahserver.serverenv import get_serverenv +from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer +from PIL import Image +import requests +import torch +from llmengine.base_chat_llm import BaseChatLLM + +class Gemma3LLM(BaseChatLLM): + def __init__(self, model_id): + self.model = Gemma3ForConditionalGeneration.from_pretrained( + model_id, device_map="auto" + ).eval() + self.processor = AutoProcessor.from_pretrained(model_id) + self.tokenizer = self.processor.tokenizer + self.messages = [] + self.model_id = model_id + + def _build_assistant_message(self, prompt): + return { + "role":"assistant", + "content":[{"type": "text", "text": prompt}] + } + + def _build_sys_message(self, prompt): + return { + "role":"system", + "content":[{"type": "text", "text": prompt}] + } + + def _build_user_message(self, prompt, image_path=None): + contents = [ + { + "type":"text", "text": prompt + } + ] + if image_path: + contents.append({ + "type": "image", + "image": image_path + }) + + return { + "role": "user", + "content": contents + } + + def _gen(self, messages): + t1 = time() + inputs = self.processor.apply_chat_template( + messages, add_generation_prompt=True, + tokenize=True, + return_dict=True, return_tensors="pt" + ).to(self.model.device, dtype=torch.bfloat16) + input_len = inputs["input_ids"].shape[-1] + streamer = self.get_streamer() + generate_kwargs = dict( + **inputs, + streamer=streamer, + max_new_tokens=512, + do_sample=True, + eos_token_id=self.tokenizer.eos_token_id + ) + thread = threading.Thread(target=self.model.generate, + kwargs=generate_kwargs) + thread.start() + for d in self.output_generator(streamer): + if d['done']: + d['input_tokens'] = input_len + yield d + + async def generate(self, request, prompt, + image_path=None, + sys_prompt=None): + messages = self.get_session_messages(request) + if sys_prompt and len(messages) == 0: + messages.append(self._build_sys_message(sys_prompt)) + messages.append(self._build_user_message(prompt, image_path=image_path)) + all_txt = '' + for d in self._gen(messages): + all_txt += d['text'] + d['text'] = all_txt + messages.append(self._build_assistant_message(all_txt)) + self.set_session_message(request, messages) + return d + + async def strem_generate(self, request, prompt, + image_path=None, + sys_prompt=None): + messages = self.get_session_messages(request) + if sys_prompt and len(messages) == 0: + messages.append(self._build_sys_message(sys_prompt)) + messages.append(self._build_user_message(prompt, image_path=image_path)) + all_txt = '' + for d in self._gen(messages): + yield d + all_txt += d['text'] + data = await f(messages) + messages.append(self._build_assistant_message(all_txt)) + self.set_session_messages(request, messages) + + def _generate(self, prompt, image_path=None, sys_prompt=None): + messages = self.messages + if sys_prompt and len(messages) == 0: + messages.append(self._build_sys_message(sys_prompt)) + messages.append(self._build_user_message(prompt, image_path=image_path)) + all_txt = '' + ld = None + for d in self._gen(messages): + all_txt += d['text'] + ld = d + ld['text'] = all_txt + messages.append(self._build_assistant_message(all_txt)) + return ld + +if __name__ == '__main__': + gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it') + while True: + print('input prompt') + p = input() + if p: + if p == 'q': + break; + print('input image path') + imgpath=input() + t = gemma3._generate(p, image_path=imgpath) + print(t) + diff --git a/llmengine/medgemma3_it.py b/llmengine/medgemma3_it.py new file mode 100644 index 0000000..985265a --- /dev/null +++ b/llmengine/medgemma3_it.py @@ -0,0 +1,52 @@ +# pip install accelerate +import time +from transformers import AutoProcessor, AutoModelForImageTextToText +from PIL import Image +import requests +import torch + + +model_id = "google/medgemma-4b-it" + +class MedgemmaLLM: + def __init__(self, model_id): + self.model = AutoModelForImageTextToText.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + self.processor = AutoProcessor.from_pretrained(model_id) + self.model_id = model_id + +# Image attribution: Stillwaterising, CC0, via Wikimedia Commons +image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png" +image = Image.open(requests.get(image_url, headers={"User-Agent": "example"}, stream=True).raw) + +messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are an expert radiologist."}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this X-ray"}, + {"type": "image", "image": image} + ] + } +] + +inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, + return_dict=True, return_tensors="pt" +).to(model.device, dtype=torch.bfloat16) + +input_len = inputs["input_ids"].shape[-1] + +with torch.inference_mode(): + generation = model.generate(**inputs, max_new_tokens=200, do_sample=False) + generation = generation[0][input_len:] + +decoded = processor.decode(generation, skip_special_tokens=True) +print(decoded) + diff --git a/test/chatllm b/test/chatllm index 3a7c531..9b48dff 100755 --- a/test/chatllm +++ b/test/chatllm @@ -6,7 +6,7 @@ import argparse def get_args(): parser = argparse.ArgumentParser(description="Example script using argparse") parser.add_argument('--gpus', '-g', type=str, required=False, default='0', help='Identify GPU id, default is 0, comma split') - parser.add_argument("--stream", action="store_true", help="是否流式输出") + parser.add_argument("--stream", action="store_true", help="是否流式输出", default=True) parser.add_argument('modelpath', type=str, help='Path to model folder') args = parser.parse_args() return args