# File-viewer metadata: 66 lines, 2.0 KiB, Python.
import torch


class BaseReranker:
    """Shared prompt-building and scoring logic for causal-LM rerankers.

    Prompts are wrapped in ``<|im_start|>`` / ``<|im_end|>`` markers and end
    with an empty ``<think>`` block. Relevance is read from the model's logits
    at the final position of each sequence.

    This base class defines no ``__init__``. The methods below read these
    attributes, which are presumably set by a subclass (TODO confirm):

    * ``self.model`` -- called as ``self.model(**inputs)``. Its output has a
      ``.logits`` tensor indexed ``[batch, position, vocab]``. The model also
      exposes ``.device`` and ``.to(device)``.
    * ``self.tokenizer`` -- callable on a list of strings, and provides ``.pad``.
    * ``self.max_length`` -- token limit passed to the tokenizer.
    * ``self.model_name`` -- echoed back in the ``rerank`` result.
    * optionally ``self.token_true_id`` / ``self.token_false_id`` --
      vocabulary ids of the "yes" / "no" answer tokens (see ``compute_logits``).
    """

    def use_mps_if_possible(self):
        """Move ``self.model`` to the MPS device when torch reports it available.

        Otherwise the model is left where it is. This includes torch builds that
        have no ``mps`` backend at all.
        """
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            self.model = self.model.to(torch.device("mps"))

    def use_mps_if_prosible(self):
        """Backward-compatible alias of :meth:`use_mps_if_possible` (original misspelling)."""
        return self.use_mps_if_possible()

    def process_input(self, pairs):
        """Tokenize prompt strings and pad them into one batch on the model's device.

        Args:
            pairs: list of fully formatted prompt strings, one per document.

        Returns:
            The padded batch from ``self.tokenizer.pad``, with every entry
            moved to ``self.model.device``.
        """
        # Tokenize without padding first, so each sequence is truncated to
        # max_length on its own. Then pad the whole batch in one step.
        # NOTE(review): with right-side truncation this trims the end of the
        # prompt, which is where the assistant suffix lives. Over-long
        # documents may lose it -- confirm the tokenizer's truncation side.
        # NOTE(review): compute_logits reads position -1. That is only the last
        # real token if padding is on the left -- confirm tokenizer.padding_side.
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation='longest_first',
            return_attention_mask=False,
            max_length=self.max_length,
        )
        inputs = self.tokenizer.pad(
            inputs, padding=True, return_tensors="pt", max_length=self.max_length
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    def process_inputs(self, pairs):
        """Alias of :meth:`process_input`; ``rerank`` calls this name."""
        return self.process_input(pairs)

    def build_sys_prompt(self, sys_prompt):
        """Return the system turn wrapping ``sys_prompt``."""
        return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"

    def build_user_prompt(self, query, docs, instruct=''):
        """Return the user turn holding the instruction, the query and one document.

        Args:
            query: the search query.
            docs: text of a single document. The name is kept for backward
                compatibility; only one document is formatted.
            instruct: optional task instruction.
        """
        return (
            f'<|im_start|>user\n<Instruct>: {instruct}\n'
            f'<Query>:{query}\n<Document>:\n{docs}<|im_end|>'
        )

    def build_assistant_prompt(self):
        """Return the assistant-turn opener with an empty ``<think>`` block."""
        return "<|im_start|>assistant\n<think>\n\n</think>\n\n"

    def compute_logits(self, inputs, **kwargs):
        """Score each sequence in the batch from its final-position logits.

        If the ids of the "yes" and "no" answer tokens are known, the score is
        the softmax probability of "yes" over the pair ("no", "yes"). The ids
        can be passed as the keyword arguments ``token_true_id`` /
        ``token_false_id`` or set as attributes with the same names.

        Otherwise the original behavior is kept: the probability of vocabulary
        id 1 under a softmax over the full vocabulary.

        Args:
            inputs: batch as returned by :meth:`process_input`; forwarded to
                ``self.model`` as keyword arguments.

        Returns:
            A list of floats, one per sequence.
        """
        true_id = kwargs.get("token_true_id", getattr(self, "token_true_id", None))
        false_id = kwargs.get("token_false_id", getattr(self, "token_false_id", None))
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            last_logits = self.model(**inputs).logits[:, -1, :].float()
        if true_id is not None and false_id is not None:
            # Column 0 = "no", column 1 = "yes", so index 1 below is P(yes).
            last_logits = torch.stack(
                [last_logits[:, false_id], last_logits[:, true_id]], dim=1
            )
        # NOTE(review): without the answer-token ids this is P(token id 1) over
        # the whole vocabulary, which is unlikely to be a meaningful relevance score.
        log_probs = torch.nn.functional.log_softmax(last_logits, dim=1)
        return log_probs[:, 1].exp().tolist()

    def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
        """Score ``docs`` against ``query`` and return the best matches.

        Args:
            query: the search query.
            docs: iterable of document strings to score.
            top_n: maximum number of results to keep; ``None`` keeps all.
            sys_prompt: text placed in the system turn.
            task: instruction placed in the ``<Instruct>`` field of every prompt.

        Returns:
            A dict with these keys:

            * ``"data"`` -- list of ``{"index", "relevance_score"}`` entries,
              sorted by descending score. ``index`` is the position in ``docs``.
            * ``"object"`` and ``"model"``.
            * ``"usage"`` -- token usage is not counted and is always 0.
        """
        docs = list(docs)
        data = []
        if docs:
            sys_str = self.build_sys_prompt(sys_prompt)
            ass_str = self.build_assistant_prompt()
            pairs = [
                sys_str + '\n'
                + self.build_user_prompt(query, doc, instruct=task)
                + '\n' + ass_str
                for doc in docs
            ]
            scores = self.compute_logits(self.process_inputs(pairs))
            data = [{'index': i, 'relevance_score': s} for i, s in enumerate(scores)]
            # Stable sort: ties keep their original document order.
            data.sort(key=lambda item: item["relevance_score"], reverse=True)
            if top_n is not None:
                data = data[:top_n]
        return {
            "data": data,
            "object": "rerank.result",
            "model": self.model_name,
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            },
        }