# llmengine/llmengine/base_reranker.py
# (scraped page metadata: 2025-06-19 17:31:10 +08:00 · 41 lines · 1.4 KiB · Python)
import torch
class BaseReranker:
    """Shared scaffolding for prompt-based rerankers: tokenization, prompt
    assembly, and score extraction. Subclasses are expected to provide
    ``self.tokenizer``, ``self.model`` and ``self.max_length``."""

    def process_input(self, pairs):
        """Tokenize prompt strings and pad them into a model-ready batch.

        Args:
            pairs: list of fully formatted prompt strings, one per document.

        Returns:
            dict of padded tensors moved to ``self.model.device``.
        """
        # First pass tokenizes without padding so 'longest_first' truncation
        # is applied per example against max_length.
        inputs = self.tokenizer(
            pairs, padding=False, truncation='longest_first',
            return_attention_mask=False, max_length=self.max_length
        )
        # Second pass pads the whole batch to a common length and returns
        # PyTorch tensors.
        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    # Alias: rerank() calls self.process_inputs(), which was otherwise
    # undefined (AttributeError). Keeping both names is backward-compatible.
    process_inputs = process_input
def build_sys_prompt(self, sys_prompt):
return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"
def build_user_prompt(self, query, docs, instruct=''):
return f'<|im_start|>user\n<Instruct>: {instruct}\n<Query>:{query}\n<Document>:\n{doc}<|im_end|>'
def build_assistant_prompt(self):
return "<|im_start|>assistant\n<think>\n\n</think>\n\n"
def compute_logits(self, inputs, **kwargs):
batch_scores = self.model(**inputs).logits[:, -1, :]
# true_vector = batch_scores[:, token_true_id]
# false_vector = batch_scores[:, token_false_id]
# batch_scores = torch.stack([false_vector, true_vector], dim=1)
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
scores = batch_scores[:, 1].exp().tolist()
return scores
def rerank(self, query, docs, sys_prompt="", task=""):
sys_str = self.build_sys_prompt(sys_prompt)
ass_str = self.build_assistant_prompt()
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
inputs = self.process_inputs(pairs)
scores = self.compute_logits(inputs)