bugfix
This commit is contained in:
parent
e35fc0545f
commit
db6f2fa39a
@ -33,7 +33,7 @@ class BaseEmbedding:
|
||||
return {
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": self.model_id.split('/')[-1],
|
||||
"model": self.model_name,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0
|
||||
|
@ -35,10 +35,31 @@ classs BaseReranker:
|
||||
scores = batch_scores[:, 1].exp().tolist()
|
||||
return scores
|
||||
|
||||
def rerank(self, query, docs, sys_prompt="", task=""):
|
||||
def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
|
||||
sys_str = self.build_sys_prompt(sys_prompt)
|
||||
ass_str = self.build_assistant_prompt()
|
||||
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
|
||||
inputs = self.process_inputs(pairs)
|
||||
scores = self.compute_logits(inputs)
|
||||
|
||||
data = []
|
||||
for i, s in enumerate(scores):
|
||||
d = {
|
||||
'index':i,
|
||||
'relevance_score': s
|
||||
}
|
||||
data.append(d)
|
||||
data = sorted(data,
|
||||
key=lambda x: x["relevance_score"],
|
||||
reverse=True)
|
||||
if len(data) > top_n:
|
||||
data = data[:top_n]
|
||||
ret = {
|
||||
"data": data
|
||||
"object": "rerank.result",
|
||||
"model": self.model_name,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,12 +14,7 @@ from ahserver.webapp import webserver
|
||||
|
||||
from aiohttp_session import get_session
|
||||
|
||||
def init():
|
||||
rf = RegisterFunction()
|
||||
rf.register('embeddings', embeddings)
|
||||
|
||||
async def docs(request, params_kw, *params, **kw):
|
||||
txt = """embeddings api:
|
||||
helptext = """embeddings api:
|
||||
path: /v1/embeddings
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
@ -33,6 +28,34 @@ data: {
|
||||
"this is second setence"
|
||||
]
|
||||
}
|
||||
|
||||
response is a json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [0.0123, -0.0456, ...]
|
||||
}
|
||||
],
|
||||
"model": "text-embedding-3-small",
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def init():
|
||||
rf = RegisterFunction()
|
||||
rf.register('embeddings', embeddings)
|
||||
rf.register('docs', docs)
|
||||
|
||||
async def docs(request, params_kw, *params, **kw):
|
||||
return helptext
|
||||
|
||||
async def embeddings(request, params_kw, *params, **kw):
|
||||
debug(f'{params_kw.input=}')
|
||||
se = ServerEnv()
|
||||
|
@ -1,10 +1,15 @@
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
||||
from llmengine.base_reranker import BaseReranker
|
||||
from llmengine.base_reranker import BaseReranker, llm_register
|
||||
|
||||
class Qwen3Reranker(BaseReranker):
|
||||
def __init__(self, model_id, max_length=8096):
|
||||
self.odel_id = model_id
|
||||
if 'Qwen3-Reranker' not in model_id:
|
||||
e = Exception(f'{model_id} is not a Qwen3-Reranker')
|
||||
raise e
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
|
||||
self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
|
||||
self.model_id = model_id
|
||||
self.model_name = model_id.split('/')[-1]
|
||||
|
||||
llm_register('Qwen3-Reranker', Qwen3Reranker)
|
||||
|
@ -16,5 +16,7 @@ class Qwen3Embedding(BaseEmbedding):
|
||||
# tokenizer_kwargs={"padding_side": "left"},
|
||||
# )
|
||||
self.max_length = max_length
|
||||
self.model_id = model_id
|
||||
self.model_name = model_id.split('/')[-1]
|
||||
|
||||
llm_register('Qwen3-Embedding', Qwen3Embedding)
|
||||
|
98
llmengine/rerank.py
Normal file
98
llmengine/rerank.py
Normal file
@ -0,0 +1,98 @@
|
||||
from traceback import format_exc
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from llmengine.qwen3_reranker import *
|
||||
from llmengine.base_reranker import get_llm_class
|
||||
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
from appPublic.worker import awaitify
|
||||
from appPublic.log import debug, exception
|
||||
from ahserver.serverenv import ServerEnv
|
||||
from ahserver.webapp import webserver
|
||||
|
||||
helptext = """rerank api:
|
||||
path: /v1/rerand
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data:
|
||||
{
|
||||
"model": "rerank-001",
|
||||
"query": "什么是量子计算?",
|
||||
"documents": [
|
||||
"量子计算是一种使用量子比特进行计算的方式。",
|
||||
"古典计算机使用的是二进制位。",
|
||||
"天气预报依赖于统计模型。",
|
||||
"量子计算与物理学密切相关。"
|
||||
},
|
||||
"top_n": 2
|
||||
}
|
||||
|
||||
response is a json
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"index": 0,
|
||||
"relevance_score": 0.95
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"relevance_score": 0.89
|
||||
}
|
||||
],
|
||||
"object": "rerank.result",
|
||||
"model": "rerank-001",
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def init():
|
||||
rf = RegisterFunction()
|
||||
rf.register('rerank', rerank)
|
||||
rf.register('docs', docs)
|
||||
|
||||
async def docs(request, params_kw, *params, **kw):
|
||||
return helptext
|
||||
|
||||
async def rerank(request, params_kw, *params, **kw):
|
||||
debug(f'{params_kw.input=}')
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
f = awaitify(engine.rerank)
|
||||
query = params_kw.query
|
||||
if query is None:
|
||||
e = exception(f'query is None')
|
||||
raise e
|
||||
if isinstance(query, str):
|
||||
input = [input]
|
||||
arr = await f(input)
|
||||
debug(f'{arr=}, type(arr)')
|
||||
return arr
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="Embedding")
|
||||
parser.add_argument('-w', '--workdir')
|
||||
parser.add_argument('-p', '--port')
|
||||
parser.add_argument('model_path')
|
||||
args = parser.parse_args()
|
||||
Klass = get_llm_class(args.model_path)
|
||||
if Klass is None:
|
||||
e = Exception(f'{args.model_path} has not mapping to a model class')
|
||||
exception(f'{e}, {format_exc()}')
|
||||
raise e
|
||||
se = ServerEnv()
|
||||
se.engine = Klass(args.model_path)
|
||||
se.engine.use_mps_if_prosible()
|
||||
workdir = args.workdir or os.getcwd()
|
||||
port = args.port
|
||||
debug(f'{args=}')
|
||||
webserver(init, workdir, port)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -1,3 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
~/models/tsfm.env/bin/python -m llmengine.server -w ~/models/tsfm -p 9999 ~/models/Qwen/Qwen3-0.6B
|
||||
~/models/tsfm.env/bin/python -m llmengine.server -p 9999 ~/models/Qwen/Qwen3-0.6B
|
||||
|
@ -1,3 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
~/models/tsfm.env/bin/python -m llmengine.embedding -w ~/models/embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B
|
||||
~/models/tsfm.env/bin/python -m llmengine.embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B
|
||||
|
@ -13,10 +13,6 @@
|
||||
"host":"0.0.0.0",
|
||||
"port":9995,
|
||||
"coding":"utf-8",
|
||||
"ssl_gg":{
|
||||
"crtfile":"$[workdir]$/conf/www.bsppo.com.pem",
|
||||
"keyfile":"$[workdir]$/conf/www.bsppo.com.key"
|
||||
},
|
||||
"indexes":[
|
||||
"index.html",
|
||||
"index.ui"
|
||||
@ -28,6 +24,9 @@
|
||||
},{
|
||||
"leading": "/v1/embeddings",
|
||||
"registerfunction": "embeddings"
|
||||
},{
|
||||
"leading": "/docs",
|
||||
"registerfunction": "docs"
|
||||
}
|
||||
],
|
||||
"processors":[
|
||||
|
Loading…
Reference in New Issue
Block a user