diff --git a/llmengine/bge_reranker.py b/llmengine/bge_reranker.py index 24aa44f..38486c4 100644 --- a/llmengine/bge_reranker.py +++ b/llmengine/bge_reranker.py @@ -25,6 +25,7 @@ class BgeReranker(BaseReranker): def compute_logits(self, inputs): scores = self.model(**inputs, return_dict=True).logits.view(-1, ).float() + scores = [ s.item() for s in scores ] return scores llm_register('bge-reranker', BgeReranker)