#
# Modified from LLaVA/predict.py
# Please see ACKNOWLEDGEMENTS for details about LICENSE
#
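"""
FastVLM inference service: loads a LLaVA-style vision-language model and
exposes it through an ahserver web application. init() registers the model's
generate() coroutine on the server environment so request handlers can call
it with an image file and a text prompt.
"""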
import os
import time

import torch
from PIL import Image

from llava.utils import disable_torch_init
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
from appPublic.jsonConfig import getConfig
from appPublic.log import debug, exception, error
from appPublic.worker import awaitify


class FastVLM:
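    """Wrap a FastVLM/LLaVA checkpoint for image+prompt inference.

    Loads the tokenizer, model and image processor from config.model_path on
    the configured device, and offers a blocking _generate() plus an async
    generate() wrapper.
    """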
    def __init__(self):
        self.config = getConfig()
        model_path = self.config.model_path
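        # Disabled code: it would move generation_config.json aside (renaming
        # it to a hidden '.generation_config.json') before loading the model.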
"""
|
|
generation_config = None
|
|
if os.path.exists(os.path.join(model_path, 'generation_config.json')):
|
|
generation_config = os.path.join(model_path, '.generation_config.json')
|
|
os.rename(os.path.join(model_path, 'generation_config.json'),
|
|
generation_config)
|
|
"""
|
|
|
|
# Load model
|
|
disable_torch_init()
|
|
model_name = get_model_name_from_path(model_path)
|
|
model_base = None
|
|
device = self.config.device
|
|
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, device=device)
|
|
self.tokenizer = tokenizer
|
|
self.model = model
|
|
self.image_processor = image_processor
|
|
self.context_len = context_len
|
|
|
|

    def _generate(self, image_file, prompt,
                  temperature=0.2,
                  top_p=None,
                  num_beams=1,
                  conv_mode='qwen_2'):
        qs = prompt
        t1 = time.time()
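        # Prepend the image placeholder token(s) in the form the model expects,
        # with or without the im_start/im_end delimiters.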
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Set the pad token id for generation
        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id

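        # tokenizer_image_token splits the prompt around the <image> placeholder
        # and splices IMAGE_TOKEN_INDEX between the tokenized text chunks.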
        # Tokenize prompt
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') \
            .unsqueeze(0).to(self.model.device)

        # Load and preprocess image
        image = Image.open(image_file).convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model.config)[0]

        # Run inference; failures are logged and reported as None so the
        # caller can distinguish an error from a normal result.
        try:
            with torch.inference_mode():
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor.unsqueeze(0).half(),
                    image_sizes=[image.size],
                    do_sample=temperature > 0,
                    temperature=temperature,
                    top_p=top_p,
                    num_beams=num_beams,
                    max_new_tokens=256,
                    use_cache=True)

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            t2 = time.time()
            return {
                'timecost': t2 - t1,
                'content': outputs
            }
        except Exception as e:
            exception(f'Exception happened: {e}')
            return None
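
    # The async wrapper below uses awaitify so the blocking _generate() can
    # be awaited from async request handlers.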
    async def generate(self, image_file, prompt):
        f = awaitify(self._generate)
        return await f(image_file, prompt)


fastvlm = None

def init():
    global fastvlm
    g = ServerEnv()
    fastvlm = FastVLM()
    g.fastvlm = fastvlm
    g.generate = fastvlm.generate


if __name__ == "__main__":
    webapp(init)
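
# Minimal usage sketch, assuming an ahserver config file that supplies
# model_path and device; the image path and prompt below are illustrative:
#
#   result = await fastvlm.generate('cat.jpg', 'Describe this image.')
#   # -> {'timecost': 1.7, 'content': 'A cat sitting on a windowsill ...'}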