Bugfix/refactor: extract build_pairs() from BaseReranker.rerank(); add llmengine/bge_reranker.py (BgeReranker); fix bare `tokenizer` reference in process_inputs
This commit is contained in:
parent
db6f2fa39a
commit
388ece4b18
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
|
||||
class BaseReranker:
    """Base class for reranker backends.

    Subclasses are expected to provide the prompt-building / scoring
    methods used by ``rerank`` (e.g. ``build_pairs``, ``process_inputs``,
    ``compute_logits``).
    """

    def __init__(self, model_id, **kw):
        """Store the model identifier.

        Args:
            model_id: Identifier (path or hub id) of the reranker model.
            **kw: Extra keyword arguments; accepted for subclass
                flexibility and currently ignored here.
        """
        # BUG FIX: original header read `classs BaseReranker:` — a typo
        # that is a SyntaxError and prevents the module from importing.
        self.model_id = model_id
|
||||
|
||||
def use_mps_if_prosible(self):
|
||||
if torch.backends.mps.is_available():
|
||||
@ -35,10 +37,14 @@ classs BaseReranker:
|
||||
scores = batch_scores[:, 1].exp().tolist()
|
||||
return scores
|
||||
|
||||
def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
|
||||
def build_pairs(self, query, docs, sys_prompt="", task=""):
    """Build one full prompt string per document.

    Each prompt is the system prompt, the per-document user prompt and
    the assistant prompt, joined by newlines.

    Args:
        query: The search query to rank documents against.
        docs: Iterable of candidate document strings.
        sys_prompt: Optional system-prompt text forwarded to
            ``build_sys_prompt``.
        task: Optional task description forwarded to
            ``build_user_prompt``.

    Returns:
        list[str]: One assembled prompt per document, in input order.
    """
    system_part = self.build_sys_prompt(sys_prompt)
    assistant_part = self.build_assistant_prompt()
    prompts = []
    for doc in docs:
        user_part = self.build_user_prompt(task, query, doc)
        prompts.append('\n'.join((system_part, user_part, assistant_part)))
    return prompts
|
||||
|
||||
def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
|
||||
pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
|
||||
inputs = self.process_inputs(pairs)
|
||||
scores = self.compute_logits(inputs)
|
||||
data = []
|
||||
|
30
llmengine/bge_reranker.py
Normal file
30
llmengine/bge_reranker.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from llmengine.base_reranker import BaseReranker, llm_register
|
||||
|
||||
class BgeReranker(BaseReranker):
    """Reranker backend for BAAI ``bge-reranker`` cross-encoder models.

    Loads a sequence-classification model and its tokenizer, and scores
    (query, document) pairs via the hooks ``build_pairs`` /
    ``process_inputs`` / ``compute_logits`` used by ``BaseReranker.rerank``.
    """

    def __init__(self, model_id, max_length=8096):
        """Load tokenizer and model.

        Args:
            model_id: Hub id or local path; must contain 'bge-reranker'.
            max_length: Stored for reference.

        Raises:
            Exception: If ``model_id`` does not look like a bge-reranker.
        """
        if 'bge-reranker' not in model_id:
            # Fail fast on an obviously wrong model family.
            raise Exception(f'{model_id} is not a bge-reranker')
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        model.eval()  # inference only
        self.model = model
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]
        # NOTE(review): max_length was accepted but never used — truncation
        # below hardcodes 512. Stored here so the intent is visible; confirm
        # whether process_inputs should use it instead of 512.
        self.max_length = max_length

    def build_pairs(self, query, docs, **kw):
        """Return raw (query, doc) pairs — bge cross-encoders take the pair
        directly, with no prompt template.
        """
        return [[query, doc] for doc in docs]

    def process_inputs(self, pairs):
        """Tokenize (query, doc) pairs into a padded/truncated batch of
        PyTorch tensors.
        """
        # BUG FIX: original called bare `tokenizer(...)`, a NameError —
        # the tokenizer is an instance attribute set in __init__.
        inputs = self.tokenizer(pairs, padding=True, truncation=True,
                                return_tensors='pt', max_length=512)
        return inputs

    def compute_logits(self, inputs):
        """Run the model and return one relevance score per pair as a
        flat float tensor.
        """
        # no_grad: pure inference, avoids building an autograd graph.
        with torch.no_grad():
            scores = self.model(**inputs,
                                return_dict=True).logits.view(-1, ).float()
        return scores
|
||||
|
||||
# Register this backend so model ids containing 'bge-reranker' resolve to BgeReranker.
llm_register('bge-reranker', BgeReranker)
|
Loading…
Reference in New Issue
Block a user