From 652349f4fad2e3b8e325e943c3672d78592785e7 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Sat, 21 Jun 2025 11:19:28 +0800 Subject: [PATCH] bugfix --- llmengine/base_reranker.py | 2 +- llmengine/rerank.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/llmengine/base_reranker.py b/llmengine/base_reranker.py index 892ec91..8ef86a2 100644 --- a/llmengine/base_reranker.py +++ b/llmengine/base_reranker.py @@ -55,7 +55,7 @@ class BaseReranker: pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ] return pairs - def rerank(self, query, docs, top_n=5, sys_prompt="", task=""): + def rerank(self, query, docs, top_n, sys_prompt="", task=""): pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task) with torch.no_grad(): inputs = self.process_inputs(pairs) diff --git a/llmengine/rerank.py b/llmengine/rerank.py index a817a6b..c194859 100644 --- a/llmengine/rerank.py +++ b/llmengine/rerank.py @@ -75,7 +75,9 @@ async def rerank(request, params_kw, *params, **kw): if isinstance(documents, str): documents = [documents] top_n = params_kw.top_n - arr = await f(query, params_kw.documents, top_n = top_n) + if top_n is None: + top_n = 5 + arr = await f(query, params_kw.documents, top_n) debug(f'{arr=}, type(arr)') return arr