This commit is contained in:
yumoqing 2025-06-20 22:39:54 +08:00
parent e35fc0545f
commit db6f2fa39a
9 changed files with 165 additions and 17 deletions

View File

@ -33,7 +33,7 @@ class BaseEmbedding:
return { return {
"object": "list", "object": "list",
"data": data, "data": data,
"model": self.model_id.split('/')[-1], "model": self.model_name,
"usage": { "usage": {
"prompt_tokens": 0, "prompt_tokens": 0,
"total_tokens": 0 "total_tokens": 0

View File

@ -35,10 +35,31 @@ class BaseReranker:
scores = batch_scores[:, 1].exp().tolist() scores = batch_scores[:, 1].exp().tolist()
return scores return scores
def rerank(self, query, docs, sys_prompt="", task=""): def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
sys_str = self.build_sys_prompt(sys_prompt) sys_str = self.build_sys_prompt(sys_prompt)
ass_str = self.build_assistant_prompt() ass_str = self.build_assistant_prompt()
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ] pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
inputs = self.process_inputs(pairs) inputs = self.process_inputs(pairs)
scores = self.compute_logits(inputs) scores = self.compute_logits(inputs)
data = []
for i, s in enumerate(scores):
d = {
'index':i,
'relevance_score': s
}
data.append(d)
data = sorted(data,
key=lambda x: x["relevance_score"],
reverse=True)
if len(data) > top_n:
data = data[:top_n]
ret = {
"data": data
"object": "rerank.result",
"model": self.model_name,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
}

View File

@ -14,12 +14,7 @@ from ahserver.webapp import webserver
from aiohttp_session import get_session from aiohttp_session import get_session
def init(): helptext = """embeddings api:
rf = RegisterFunction()
rf.register('embeddings', embeddings)
async def docs(request, params_kw, *params, **kw):
txt = """embeddings api:
path: /v1/embeddings path: /v1/embeddings
headers: { headers: {
"Content-Type": "application/json" "Content-Type": "application/json"
@ -33,6 +28,34 @@ data: {
"this is second setence" "this is second setence"
] ]
} }
response is a json
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [0.0123, -0.0456, ...]
}
],
"model": "text-embedding-3-small",
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
}
"""
def init():
rf = RegisterFunction()
rf.register('embeddings', embeddings)
rf.register('docs', docs)
async def docs(request, params_kw, *params, **kw):
return helptext
async def embeddings(request, params_kw, *params, **kw): async def embeddings(request, params_kw, *params, **kw):
debug(f'{params_kw.input=}') debug(f'{params_kw.input=}')
se = ServerEnv() se = ServerEnv()

View File

@ -1,10 +1,15 @@
import torch import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from llmengine.base_reranker import BaseReranker from llmengine.base_reranker import BaseReranker, llm_register
class Qwen3Reranker(BaseReranker): class Qwen3Reranker(BaseReranker):
def __init__(self, model_id, max_length=8096): def __init__(self, model_id, max_length=8096):
self.odel_id = model_id if 'Qwen3-Reranker' not in model_id:
e = Exception(f'{model_id} is not a Qwen3-Reranker')
raise e
self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left') self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval() self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
self.model_id = model_id
self.model_name = model_id.split('/')[-1]
llm_register('Qwen3-Reranker', Qwen3Reranker)

View File

@ -16,5 +16,7 @@ class Qwen3Embedding(BaseEmbedding):
# tokenizer_kwargs={"padding_side": "left"}, # tokenizer_kwargs={"padding_side": "left"},
# ) # )
self.max_length = max_length self.max_length = max_length
self.model_id = model_id
self.model_name = model_id.split('/')[-1]
llm_register('Qwen3-Embedding', Qwen3Embedding) llm_register('Qwen3-Embedding', Qwen3Embedding)

98
llmengine/rerank.py Normal file
View File

@ -0,0 +1,98 @@
from traceback import format_exc
import os
import sys
import argparse
from llmengine.qwen3_reranker import *
from llmengine.base_reranker import get_llm_class
from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.webapp import webserver
helptext = """rerank api:
path: /v1/rerand
headers: {
"Content-Type": "application/json"
}
data:
{
"model": "rerank-001",
"query": "什么是量子计算?",
"documents": [
"量子计算是一种使用量子比特进行计算的方式。",
"古典计算机使用的是二进制位。",
"天气预报依赖于统计模型。",
"量子计算与物理学密切相关。"
},
"top_n": 2
}
response is a json
{
"data": [
{
"index": 0,
"relevance_score": 0.95
},
{
"index": 3,
"relevance_score": 0.89
}
],
"object": "rerank.result",
"model": "rerank-001",
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
}
"""
def init():
	"""Register this module's HTTP handlers with the function registry."""
	registry = RegisterFunction()
	for path_name, handler in (('rerank', rerank), ('docs', docs)):
		registry.register(path_name, handler)
async def docs(request, params_kw, *params, **kw):
	"""Return the module-level help text describing the rerank API."""
	return helptext
async def rerank(request, params_kw, *params, **kw):
	"""HTTP handler for /v1/rerank.

	Reads `query`, `documents` and optional `top_n` from the request
	parameters, scores each document against the query with the engine's
	reranker, and returns the engine's result.

	Raises:
		Exception: if `query` or `documents` is missing from the request.
	"""
	debug(f'{params_kw.query=}, {params_kw.documents=}')
	se = ServerEnv()
	engine = se.engine
	# engine.rerank is synchronous; awaitify runs it off the event loop.
	f = awaitify(engine.rerank)
	query = params_kw.query
	if query is None:
		# exception() only logs; build and raise a real Exception.
		e = Exception('query is None')
		exception(f'{e}')
		raise e
	docs = params_kw.documents
	if docs is None:
		e = Exception('documents is None')
		exception(f'{e}')
		raise e
	if isinstance(docs, str):
		docs = [docs]
	# Default top_n mirrors BaseReranker.rerank's default of 5.
	top_n = params_kw.top_n or 5
	ret = await f(query, docs, top_n=top_n)
	debug(f'{ret=}, {type(ret)=}')
	return ret
def main():
	"""Command-line entry point: load the reranker model and start the server.

	Usage: rerank [-w WORKDIR] [-p PORT] model_path
	Raises an Exception when model_path does not map to a registered
	reranker class.
	"""
	# prog was "Embedding" — copy-paste from the embedding module; this is the rerank server.
	parser = argparse.ArgumentParser(prog="Rerank")
	parser.add_argument('-w', '--workdir')
	parser.add_argument('-p', '--port')
	parser.add_argument('model_path')
	args = parser.parse_args()
	Klass = get_llm_class(args.model_path)
	if Klass is None:
		e = Exception(f'{args.model_path} has not mapping to a model class')
		exception(f'{e}, {format_exc()}')
		raise e
	se = ServerEnv()
	se.engine = Klass(args.model_path)
	se.engine.use_mps_if_prosible()
	# Fall back to the current directory when no workdir is given.
	workdir = args.workdir or os.getcwd()
	port = args.port
	debug(f'{args=}')
	webserver(init, workdir, port)

if __name__ == '__main__':
	main()

View File

@ -1,3 +1,3 @@
#!/bin/bash #!/bin/bash
~/models/tsfm.env/bin/python -m llmengine.server -w ~/models/tsfm -p 9999 ~/models/Qwen/Qwen3-0.6B ~/models/tsfm.env/bin/python -m llmengine.server -p 9999 ~/models/Qwen/Qwen3-0.6B

View File

@ -1,3 +1,3 @@
#!/bin/bash #!/bin/bash
~/models/tsfm.env/bin/python -m llmengine.embedding -w ~/models/embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B ~/models/tsfm.env/bin/python -m llmengine.embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B

View File

@ -13,10 +13,6 @@
"host":"0.0.0.0", "host":"0.0.0.0",
"port":9995, "port":9995,
"coding":"utf-8", "coding":"utf-8",
"ssl_gg":{
"crtfile":"$[workdir]$/conf/www.bsppo.com.pem",
"keyfile":"$[workdir]$/conf/www.bsppo.com.key"
},
"indexes":[ "indexes":[
"index.html", "index.html",
"index.ui" "index.ui"
@ -28,6 +24,9 @@
},{ },{
"leading": "/v1/embeddings", "leading": "/v1/embeddings",
"registerfunction": "embeddings" "registerfunction": "embeddings"
},{
"leading": "/docs",
"registerfunction": "docs"
} }
], ],
"processors":[ "processors":[