This commit is contained in:
yumoqing 2025-06-21 11:19:28 +08:00
parent c9a26de6ca
commit 652349f4fa
2 changed files with 4 additions and 2 deletions

View File

@ -55,7 +55,7 @@ class BaseReranker:
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ] pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
return pairs return pairs
def rerank(self, query, docs, top_n=5, sys_prompt="", task=""): def rerank(self, query, docs, top_n, sys_prompt="", task=""):
pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task) pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
with torch.no_grad(): with torch.no_grad():
inputs = self.process_inputs(pairs) inputs = self.process_inputs(pairs)

View File

@ -75,7 +75,9 @@ async def rerank(request, params_kw, *params, **kw):
if isinstance(documents, str): if isinstance(documents, str):
documents = [documents] documents = [documents]
top_n = params_kw.top_n top_n = params_kw.top_n
arr = await f(query, params_kw.documents, top_n = top_n) if top_n is None:
top_n = 5
arr = await f(query, params_kw.documents, top_n)
debug(f'{arr=}, type(arr)') debug(f'{arr=}, type(arr)')
return arr return arr