diff --git a/llmengine/base_reranker.py b/llmengine/base_reranker.py index 8b0aba5..37d6b37 100644 --- a/llmengine/base_reranker.py +++ b/llmengine/base_reranker.py @@ -45,8 +45,9 @@ classs BaseReranker: def rerank(self, query, docs, top_n=5, sys_prompt="", task=""): pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task) - inputs = self.process_inputs(pairs) - scores = self.compute_logits(inputs) + with torch.no_grad(): + inputs = self.process_inputs(pairs) + scores = self.compute_logits(inputs) data = [] for i, s in enumerate(scores): d = { @@ -68,4 +69,4 @@ classs BaseReranker: "total_tokens": 0 } } - + return ret diff --git a/llmengine/rerank.py b/llmengine/rerank.py index 5ccb4df..4d13454 100644 --- a/llmengine/rerank.py +++ b/llmengine/rerank.py @@ -75,7 +75,7 @@ async def rerank(request, params_kw, *params, **kw): return arr def main(): - parser = argparse.ArgumentParser(prog="Embedding") + parser = argparse.ArgumentParser(prog="Rerank") parser.add_argument('-w', '--workdir') parser.add_argument('-p', '--port') parser.add_argument('model_path')