bugfix
This commit is contained in:
parent
c9a26de6ca
commit
652349f4fa
@ -55,7 +55,7 @@ class BaseReranker:
|
|||||||
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
|
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
|
def rerank(self, query, docs, top_n, sys_prompt="", task=""):
|
||||||
pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
|
pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = self.process_inputs(pairs)
|
inputs = self.process_inputs(pairs)
|
||||||
|
@ -75,7 +75,9 @@ async def rerank(request, params_kw, *params, **kw):
|
|||||||
if isinstance(documents, str):
|
if isinstance(documents, str):
|
||||||
documents = [documents]
|
documents = [documents]
|
||||||
top_n = params_kw.top_n
|
top_n = params_kw.top_n
|
||||||
arr = await f(query, params_kw.documents, top_n = top_n)
|
if top_n is None:
|
||||||
|
top_n = 5
|
||||||
|
arr = await f(query, params_kw.documents, top_n)
|
||||||
debug(f'{arr=}, type(arr)')
|
debug(f'{arr=}, type(arr)')
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user