From 388ece4b1891582c2a29bcbb087e11d3d7ad2898 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Sat, 21 Jun 2025 09:56:42 +0800 Subject: [PATCH] bugfix --- llmengine/base_reranker.py | 10 ++++++++-- llmengine/bge_reranker.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 llmengine/bge_reranker.py diff --git a/llmengine/base_reranker.py b/llmengine/base_reranker.py index 4f3ea9c..8b0aba5 100644 --- a/llmengine/base_reranker.py +++ b/llmengine/base_reranker.py @@ -1,7 +1,9 @@ import torch classs BaseReranker: - + def __init__(self, model_id, **kw): + self.model_id = model_id + def use_mps_if_prosible(self): if torch.backends.mps.is_available(): device = torch.device("mps") @@ -35,10 +37,14 @@ classs BaseReranker: scores = batch_scores[:, 1].exp().tolist() return scores - def rerank(self, query, docs, top_n=5, sys_prompt="", task=""): + def build_pairs(self, query, docs, sys_prompt="", task=""): sys_str = self.build_sys_prompt(sys_prompt) ass_str = self.build_assistant_prompt() pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ] + return pairs + + def rerank(self, query, docs, top_n=5, sys_prompt="", task=""): + pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task) inputs = self.process_inputs(pairs) scores = self.compute_logits(inputs) data = [] diff --git a/llmengine/bge_reranker.py b/llmengine/bge_reranker.py new file mode 100644 index 0000000..9c830fd --- /dev/null +++ b/llmengine/bge_reranker.py @@ -0,0 +1,30 @@ +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from llmengine.base_reranker import BaseReranker, llm_register + +class BgeReranker(BaseReranker): + def __init__(self, model_id, max_length=8096): + if 'bge-reranker' not in model_id: + e = Exception(f'{model_id} is not a bge-reranker') + raise e + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForSequenceClassification.from_pretrained(model_id) + model.eval() + self.model = model + self.model_id = model_id + self.model_name = model_id.split('/')[-1] + + def build_pairs(self, query, docs, **kw): + return [[query, doc] for doc in docs] + + def process_inputs(self, pairs): + inputs = tokenizer(pairs, padding=True, + truncation=True, return_tensors='pt', max_length=512) + return inputs + + def compute_logits(self, inputs): + scores = self.model(**inputs, + return_dict=True).logits.view(-1, ).float() + return scores + +llm_register('bge-reranker', BgeReranker)