#
# Modified from LLaVA/predict.py
# Please see ACKNOWLEDGEMENTS for details about LICENSE
#
import os
import torch
import time
from PIL import Image

from llava.utils import disable_torch_init
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
from appPublic.jsonConfig import getConfig
from appPublic.log import debug, exception, error
from appPublic.worker import awaitify


class FastVLM:
    """Wrap a pretrained LLaVA-style vision-language model for image+prompt inference.

    The model location and target device are read from the application
    configuration (``getConfig()``): ``config.model_path`` and
    ``config.device``.
    """

    def __init__(self):
        """Load the tokenizer, model and image processor named by the config."""
        self.config = getConfig()
        model_path = self.config.model_path

        # NOTE: an earlier revision carried disabled code here that renamed
        # ``generation_config.json`` to ``.generation_config.json`` inside
        # ``model_path`` before loading. It was never executed and is omitted.

        # Load model
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        model_base = None
        device = self.config.device
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path, model_base, model_name, device=device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.context_len = context_len

    def _generate(self, image_file, prompt,
                  temperature=0.2,
                  top_p=None,
                  num_beams=1,
                  conv_mode='qwen_2'):
        """Run one blocking generation pass for a single image and prompt.

        Args:
            image_file: Anything ``PIL.Image.open`` accepts (path or file object).
            prompt: The user's text question about the image.
            temperature: Sampling temperature; sampling is enabled only when
                this is greater than 0.
            top_p: Forwarded unchanged to ``model.generate``.
            num_beams: Forwarded unchanged to ``model.generate``.
            conv_mode: Key into ``conv_templates`` selecting the chat template.

        Returns:
            dict with ``'timecost'`` (elapsed seconds for the whole call,
            including prompt construction and image loading) and
            ``'content'`` (the decoded, stripped model output).
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()

        # Prefix the question with the image placeholder token(s).
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + prompt

        # Build the conversation: user turn with the question, empty
        # assistant turn for the model to complete.
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        # Set the pad token id for generation
        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id

        # Tokenize prompt
        input_ids = tokenizer_image_token(
            full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.model.device)

        # Load and preprocess image. The context manager releases the
        # underlying file handle promptly; ``convert`` returns an
        # independent, fully loaded image that outlives the ``with`` block.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model.config)[0]

        # Run inference
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half(),
                image_sizes=[image.size],
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                max_new_tokens=256,
                use_cache=True)

        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return {
            'timecost': time.perf_counter() - started,
            'content': outputs
        }

    async def generate(self, image_file, prompt):
        """Awaitable wrapper around :meth:`_generate` using default settings.

        ``awaitify`` (from ``appPublic.worker``) turns the blocking call into
        an awaitable -- presumably by running it off the event loop; confirm
        against that helper's implementation.

        Returns:
            The dict produced by :meth:`_generate`.
        """
        f = awaitify(self._generate)
        # Fixed: the original passed the undefined name ``promot`` here,
        # which raised NameError on every call.
        return await f(image_file, prompt)


def init():
    """Create the model wrapper and expose its ``generate`` on ``ServerEnv``."""
    g = ServerEnv()
    k = FastVLM()
    g.generate = k.generate


if __name__ == "__main__":
    webapp(init)