import torch


class BaseReranker:
    """Score (query, document) pairs with a causal LM using a chat-style prompt.

    This base class sets no state of its own. Before its methods are used, the
    instance (typically a subclass ``__init__``) must provide:

    * ``self.tokenizer`` -- callable on a list of strings and exposing ``pad``
      (presumably a Hugging Face tokenizer; confirm against subclasses).
    * ``self.model`` -- callable as ``self.model(**inputs)``, returning an
      object whose ``.logits`` tensor is indexable as ``[:, -1, :]``, and
      exposing ``.device``.
    * ``self.max_length`` -- token limit used for truncation and padding.
    * optionally ``self.token_true_id`` / ``self.token_false_id`` -- vocabulary
      ids whose final-position logits are compared in ``compute_logits``.
    """

    def process_input(self, pairs):
        """Tokenize prompt strings, pad them into one batch and move it to the model device.

        Args:
            pairs: list of fully formatted prompt strings.

        Returns:
            The padded batch produced by ``self.tokenizer.pad`` (PyTorch
            tensors), with every entry moved to ``self.model.device``.
        """
        # Tokenize without padding first so truncation applies per sequence,
        # then pad the whole batch in a single call.
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length,
        )
        inputs = self.tokenizer.pad(
            inputs, padding=True, return_tensors="pt", max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    # ``rerank`` looks this method up as ``process_inputs``; keep both spellings
    # so existing callers of either name (and overrides of ``process_inputs``)
    # keep working.
    process_inputs = process_input

    def build_sys_prompt(self, sys_prompt):
        """Wrap ``sys_prompt`` in a system turn delimited by ``<|im_start|>``/``<|im_end|>``."""
        return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"

    def build_user_prompt(self, query, docs, instruct=''):
        """Build the user turn holding the instruction, the query and one document.

        Args:
            query: the search query.
            docs: a single document string. The parameter name is kept for
                backward compatibility with keyword callers.
            instruct: optional task instruction placed before the query.

        Returns:
            The formatted user turn.
        """
        # NOTE(review): each field is introduced only by ':' -- it looks like
        # field labels may be missing; confirm against the prompt format the
        # model was trained with.
        return f'<|im_start|>user\n: {instruct}\n:{query}\n:\n{docs}<|im_end|>'

    def build_assistant_prompt(self):
        """Return the opening of the assistant turn; the model's next token is scored."""
        return "<|im_start|>assistant\n\n\n\n\n"

    def compute_logits(self, inputs, **kwargs):
        """Return one relevance probability per batch row.

        When ``self.token_true_id`` and ``self.token_false_id`` are both set,
        the score is ``softmax([logit_false, logit_true])[1]`` taken at the
        last position, i.e. P(true) restricted to those two tokens.

        Without those attributes, the legacy behaviour is kept: a softmax over
        the whole vocabulary, reading column 1.

        Args:
            inputs: batch as returned by ``process_input``; passed to
                ``self.model`` as keyword arguments.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            list[float]: one score per row of ``inputs``.
        """
        last_logits = self.model(**inputs).logits[:, -1, :]
        true_id = getattr(self, "token_true_id", None)
        false_id = getattr(self, "token_false_id", None)
        if true_id is not None and false_id is not None:
            # Compare only the two answer tokens; column 1 is the "true" token.
            candidates = torch.stack(
                [last_logits[:, false_id], last_logits[:, true_id]], dim=1
            )
        else:
            # NOTE(review): legacy path -- column 1 here is vocabulary id 1,
            # not a yes/no answer. Set token_true_id / token_false_id to get
            # meaningful scores.
            candidates = last_logits
        log_probs = torch.nn.functional.log_softmax(candidates, dim=1)
        return log_probs[:, 1].exp().tolist()

    def rerank(self, query, docs, sys_prompt="", task=""):
        """Score every document in ``docs`` against ``query``.

        Args:
            query: the search query.
            docs: iterable of document strings.
            sys_prompt: text placed in the system turn.
            task: instruction placed in the user turn before the query.

        Returns:
            list[float]: one score per document, in input order; ``[]`` when
            ``docs`` is empty.
        """
        docs = list(docs)
        if not docs:
            # Avoid calling the tokenizer/model on an empty batch.
            return []
        sys_str = self.build_sys_prompt(sys_prompt)
        ass_str = self.build_assistant_prompt()
        pairs = [
            sys_str + '\n' + self.build_user_prompt(query, doc, instruct=task) + '\n' + ass_str
            for doc in docs
        ]
        inputs = self.process_inputs(pairs)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            return self.compute_logits(inputs)