bugfix
This commit is contained in:
parent b8104f5ca1
commit 8d812bf42f
51
llmengine/base_chat_llm.py
Normal file
@@ -0,0 +1,51 @@
from time import time

from ahserver.serverenv import get_serverenv
from transformers import TextIteratorStreamer


class BaseChatLLM:
    async def get_session_key(self):
        return self.model_id + ':messages'

    async def get_session_messages(self, request):
        f = get_serverenv('get_session')
        session = await f(request)
        key = await self.get_session_key()
        messages = session.get(key) or []
        return messages

    async def set_session_messages(self, request, messages):
        f = get_serverenv('get_session')
        session = await f(request)
        key = await self.get_session_key()
        session[key] = messages

    def get_streamer(self):
        return TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True
        )

    def output_generator(self, streamer):
        # Yield incremental chunks from the streamer, then a final record
        # carrying timing and token statistics.
        all_txt = ''
        t1 = time()
        t2 = t1
        i = 0
        for txt in streamer:
            if i == 0:
                t2 = time()
            i += 1
            all_txt += txt
            yield {
                'done': False,
                'text': txt
            }
        t3 = time()
        unk = self.tokenizer(all_txt, return_tensors="pt")
        print(f'{unk=};')  # debug output of the re-tokenized reply
        output_tokens = len(unk["input_ids"][0])
        yield {
            'done': True,
            'text': '',
            'response_time': t2 - t1,
            'finish_time': t3 - t1,
            'output_token': output_tokens
        }
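For illustration, a minimal sketch of the session layout these helpers assume: one mutable mapping per request, keyed by model_id + ':messages'. The plain dict below stands in for whatever get_serverenv('get_session') actually returns.

# Stand-in session: a plain dict in place of the real per-request session.
session = {}
key = 'google/gemma-3-4b-it' + ':messages'      # get_session_key()
messages = session.get(key) or []               # get_session_messages()
messages.append({'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]})
session[key] = messages                         # set_session_messages()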
@@ -106,7 +106,7 @@ class TransformersChatEngine:
         if not self.output_json:
             return text
         input_tokens = inputs["input_ids"].shape[1]
-        outputi_ids.sequences.shape[1] - input_tokens
+        output_tokens = len(self.tokenizer(text, return_tensors="pt")["input_ids"][0])
         return {
             'content': text,
             'input_tokens': input_tokens,
133
llmengine/gemma3_it.py
Normal file
@@ -0,0 +1,133 @@
#!/share/vllm-0.8.5/bin/python

# pip install accelerate
import threading
from time import time
from appPublic.worker import awaitify
from ahserver.serverenv import get_serverenv
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import BaseChatLLM


class Gemma3LLM(BaseChatLLM):
    def __init__(self, model_id):
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto"
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.tokenizer = self.processor.tokenizer
        self.messages = []
        self.model_id = model_id

    def _build_assistant_message(self, prompt):
        return {
            "role": "assistant",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_sys_message(self, prompt):
        return {
            "role": "system",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_user_message(self, prompt, image_path=None):
        contents = [
            {
                "type": "text", "text": prompt
            }
        ]
        if image_path:
            contents.append({
                "type": "image",
                "image": image_path
            })
        return {
            "role": "user",
            "content": contents
        }

    def _gen(self, messages):
        # Run model.generate in a background thread and yield chunks from the streamer.
        t1 = time()
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True,
            tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        streamer = self.get_streamer()
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        thread = threading.Thread(target=self.model.generate,
                                  kwargs=generate_kwargs)
        thread.start()
        for d in self.output_generator(streamer):
            if d['done']:
                d['input_tokens'] = input_len
            yield d

    async def generate(self, request, prompt,
                       image_path=None,
                       sys_prompt=None):
        messages = await self.get_session_messages(request)
        if sys_prompt and len(messages) == 0:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        all_txt = ''
        for d in self._gen(messages):
            all_txt += d['text']
        d['text'] = all_txt
        messages.append(self._build_assistant_message(all_txt))
        await self.set_session_messages(request, messages)
        return d

    async def stream_generate(self, request, prompt,
                              image_path=None,
                              sys_prompt=None):
        messages = await self.get_session_messages(request)
        if sys_prompt and len(messages) == 0:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        all_txt = ''
        for d in self._gen(messages):
            yield d
            all_txt += d['text']
        messages.append(self._build_assistant_message(all_txt))
        await self.set_session_messages(request, messages)

    def _generate(self, prompt, image_path=None, sys_prompt=None):
        messages = self.messages
        if sys_prompt and len(messages) == 0:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        all_txt = ''
        ld = None
        for d in self._gen(messages):
            all_txt += d['text']
            ld = d
        ld['text'] = all_txt
        messages.append(self._build_assistant_message(all_txt))
        return ld


if __name__ == '__main__':
    gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
    while True:
        print('input prompt')
        p = input()
        if p:
            if p == 'q':
                break
            print('input image path')
            imgpath = input()
            t = gemma3._generate(p, image_path=imgpath)
            print(t)
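A small usage sketch (not part of the commit) that streams tokens to stdout through the synchronous _gen() path, using the same model path as the __main__ block above; the prompt text is just an example value.

# Sketch: print tokens as they arrive, then the final statistics chunk.
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
msgs = [gemma3._build_user_message('Describe the Gemma model family in one sentence.')]
for chunk in gemma3._gen(msgs):
    if chunk['done']:
        print('\ninput tokens:', chunk['input_tokens'],
              'output tokens:', chunk['output_token'])
    else:
        print(chunk['text'], end='', flush=True)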
52
llmengine/medgemma3_it.py
Normal file
@@ -0,0 +1,52 @@
# pip install accelerate
import time
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch


model_id = "google/medgemma-4b-it"


class MedgemmaLLM:
    def __init__(self, model_id):
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model_id = model_id


llm = MedgemmaLLM(model_id)
model = llm.model
processor = llm.processor

# Image attribution: Stillwaterising, CC0, via Wikimedia Commons
image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
image = Image.open(requests.get(image_url, headers={"User-Agent": "example"}, stream=True).raw)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are an expert radiologist."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this X-ray"},
            {"type": "image", "image": image}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
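A hypothetical sketch (not part of the commit) of how the module-level demo above could be folded into the class as a reusable method; the subclass and the method name describe_image are invented for illustration, and the imports and MedgemmaLLM definition from the file above are assumed to be in scope.

import torch

# Hypothetical helper; mirrors the module-level demo above.
class MedgemmaLLMWithHelper(MedgemmaLLM):
    def describe_image(self, image, prompt="Describe this X-ray",
                       sys_prompt="You are an expert radiologist."):
        messages = [
            {"role": "system", "content": [{"type": "text", "text": sys_prompt}]},
            {"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "image": image}
            ]}
        ]
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generation = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
        return self.processor.decode(generation[0][input_len:], skip_special_tokens=True)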
@@ -6,7 +6,7 @@ import argparse
 def get_args():
     parser = argparse.ArgumentParser(description="Example script using argparse")
     parser.add_argument('--gpus', '-g', type=str, required=False, default='0', help='Identify GPU id, default is 0, comma split')
-    parser.add_argument("--stream", action="store_true", help="whether to stream the output")
+    parser.add_argument("--stream", action="store_true", help="whether to stream the output", default=True)
     parser.add_argument('modelpath', type=str, help='Path to model folder')
     args = parser.parse_args()
     return args
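For illustration, a minimal standalone sketch of how the changed --stream flag parses; the model path is just an example value.

import argparse

parser = argparse.ArgumentParser(description="Example script using argparse")
parser.add_argument("--stream", action="store_true", help="whether to stream the output", default=True)
parser.add_argument('modelpath', type=str, help='Path to model folder')

# With action="store_true" and default=True, args.stream is True
# whether or not --stream is given on the command line.
print(parser.parse_args(['/share/models/google/gemma-3-4b-it']).stream)              # True
print(parser.parse_args(['--stream', '/share/models/google/gemma-3-4b-it']).stream)  # True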