bugfix
This commit is contained in:
parent
e35fc0545f
commit
db6f2fa39a
@ -33,7 +33,7 @@ class BaseEmbedding:
|
|||||||
return {
|
return {
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": data,
|
"data": data,
|
||||||
"model": self.model_id.split('/')[-1],
|
"model": self.model_name,
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"total_tokens": 0
|
"total_tokens": 0
|
||||||
|
@ -35,10 +35,31 @@ classs BaseReranker:
|
|||||||
scores = batch_scores[:, 1].exp().tolist()
|
scores = batch_scores[:, 1].exp().tolist()
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def rerank(self, query, docs, sys_prompt="", task=""):
|
def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
|
||||||
sys_str = self.build_sys_prompt(sys_prompt)
|
sys_str = self.build_sys_prompt(sys_prompt)
|
||||||
ass_str = self.build_assistant_prompt()
|
ass_str = self.build_assistant_prompt()
|
||||||
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
|
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
|
||||||
inputs = self.process_inputs(pairs)
|
inputs = self.process_inputs(pairs)
|
||||||
scores = self.compute_logits(inputs)
|
scores = self.compute_logits(inputs)
|
||||||
|
data = []
|
||||||
|
for i, s in enumerate(scores):
|
||||||
|
d = {
|
||||||
|
'index':i,
|
||||||
|
'relevance_score': s
|
||||||
|
}
|
||||||
|
data.append(d)
|
||||||
|
data = sorted(data,
|
||||||
|
key=lambda x: x["relevance_score"],
|
||||||
|
reverse=True)
|
||||||
|
if len(data) > top_n:
|
||||||
|
data = data[:top_n]
|
||||||
|
ret = {
|
||||||
|
"data": data
|
||||||
|
"object": "rerank.result",
|
||||||
|
"model": self.model_name,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -14,12 +14,7 @@ from ahserver.webapp import webserver
|
|||||||
|
|
||||||
from aiohttp_session import get_session
|
from aiohttp_session import get_session
|
||||||
|
|
||||||
def init():
|
helptext = """embeddings api:
|
||||||
rf = RegisterFunction()
|
|
||||||
rf.register('embeddings', embeddings)
|
|
||||||
|
|
||||||
async def docs(request, params_kw, *params, **kw):
|
|
||||||
txt = """embeddings api:
|
|
||||||
path: /v1/embeddings
|
path: /v1/embeddings
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
@ -33,6 +28,34 @@ data: {
|
|||||||
"this is second setence"
|
"this is second setence"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
response is a json
|
||||||
|
{
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": 0,
|
||||||
|
"embedding": [0.0123, -0.0456, ...]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "text-embedding-3-small",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
rf = RegisterFunction()
|
||||||
|
rf.register('embeddings', embeddings)
|
||||||
|
rf.register('docs', docs)
|
||||||
|
|
||||||
|
async def docs(request, params_kw, *params, **kw):
|
||||||
|
return helptext
|
||||||
|
|
||||||
async def embeddings(request, params_kw, *params, **kw):
|
async def embeddings(request, params_kw, *params, **kw):
|
||||||
debug(f'{params_kw.input=}')
|
debug(f'{params_kw.input=}')
|
||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
|
@ -1,10 +1,15 @@
|
|||||||
import torch
|
import torch
|
||||||
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
||||||
from llmengine.base_reranker import BaseReranker
|
from llmengine.base_reranker import BaseReranker, llm_register
|
||||||
|
|
||||||
class Qwen3Reranker(BaseReranker):
|
class Qwen3Reranker(BaseReranker):
|
||||||
def __init__(self, model_id, max_length=8096):
|
def __init__(self, model_id, max_length=8096):
|
||||||
self.odel_id = model_id
|
if 'Qwen3-Reranker' not in model_id:
|
||||||
|
e = Exception(f'{model_id} is not a Qwen3-Reranker')
|
||||||
|
raise e
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
|
||||||
self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
|
self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
|
||||||
|
self.model_id = model_id
|
||||||
|
self.model_name = model_id.split('/')[-1]
|
||||||
|
|
||||||
|
llm_register('Qwen3-Reranker', Qwen3Reranker)
|
||||||
|
@ -16,5 +16,7 @@ class Qwen3Embedding(BaseEmbedding):
|
|||||||
# tokenizer_kwargs={"padding_side": "left"},
|
# tokenizer_kwargs={"padding_side": "left"},
|
||||||
# )
|
# )
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
self.model_id = model_id
|
||||||
|
self.model_name = model_id.split('/')[-1]
|
||||||
|
|
||||||
llm_register('Qwen3-Embedding', Qwen3Embedding)
|
llm_register('Qwen3-Embedding', Qwen3Embedding)
|
||||||
|
98
llmengine/rerank.py
Normal file
98
llmengine/rerank.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from traceback import format_exc
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from llmengine.qwen3_reranker import *
|
||||||
|
from llmengine.base_reranker import get_llm_class
|
||||||
|
|
||||||
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from appPublic.log import debug, exception
|
||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.webapp import webserver
|
||||||
|
|
||||||
|
helptext = """rerank api:
|
||||||
|
path: /v1/rerand
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
data:
|
||||||
|
{
|
||||||
|
"model": "rerank-001",
|
||||||
|
"query": "什么是量子计算?",
|
||||||
|
"documents": [
|
||||||
|
"量子计算是一种使用量子比特进行计算的方式。",
|
||||||
|
"古典计算机使用的是二进制位。",
|
||||||
|
"天气预报依赖于统计模型。",
|
||||||
|
"量子计算与物理学密切相关。"
|
||||||
|
},
|
||||||
|
"top_n": 2
|
||||||
|
}
|
||||||
|
|
||||||
|
response is a json
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"relevance_score": 0.95
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 3,
|
||||||
|
"relevance_score": 0.89
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"object": "rerank.result",
|
||||||
|
"model": "rerank-001",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
rf = RegisterFunction()
|
||||||
|
rf.register('rerank', rerank)
|
||||||
|
rf.register('docs', docs)
|
||||||
|
|
||||||
|
async def docs(request, params_kw, *params, **kw):
|
||||||
|
return helptext
|
||||||
|
|
||||||
|
async def rerank(request, params_kw, *params, **kw):
|
||||||
|
debug(f'{params_kw.input=}')
|
||||||
|
se = ServerEnv()
|
||||||
|
engine = se.engine
|
||||||
|
f = awaitify(engine.rerank)
|
||||||
|
query = params_kw.query
|
||||||
|
if query is None:
|
||||||
|
e = exception(f'query is None')
|
||||||
|
raise e
|
||||||
|
if isinstance(query, str):
|
||||||
|
input = [input]
|
||||||
|
arr = await f(input)
|
||||||
|
debug(f'{arr=}, type(arr)')
|
||||||
|
return arr
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(prog="Embedding")
|
||||||
|
parser.add_argument('-w', '--workdir')
|
||||||
|
parser.add_argument('-p', '--port')
|
||||||
|
parser.add_argument('model_path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
Klass = get_llm_class(args.model_path)
|
||||||
|
if Klass is None:
|
||||||
|
e = Exception(f'{args.model_path} has not mapping to a model class')
|
||||||
|
exception(f'{e}, {format_exc()}')
|
||||||
|
raise e
|
||||||
|
se = ServerEnv()
|
||||||
|
se.engine = Klass(args.model_path)
|
||||||
|
se.engine.use_mps_if_prosible()
|
||||||
|
workdir = args.workdir or os.getcwd()
|
||||||
|
port = args.port
|
||||||
|
debug(f'{args=}')
|
||||||
|
webserver(init, workdir, port)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
@ -1,3 +1,3 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
~/models/tsfm.env/bin/python -m llmengine.server -w ~/models/tsfm -p 9999 ~/models/Qwen/Qwen3-0.6B
|
~/models/tsfm.env/bin/python -m llmengine.server -p 9999 ~/models/Qwen/Qwen3-0.6B
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
~/models/tsfm.env/bin/python -m llmengine.embedding -w ~/models/embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B
|
~/models/tsfm.env/bin/python -m llmengine.embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B
|
||||||
|
@ -13,10 +13,6 @@
|
|||||||
"host":"0.0.0.0",
|
"host":"0.0.0.0",
|
||||||
"port":9995,
|
"port":9995,
|
||||||
"coding":"utf-8",
|
"coding":"utf-8",
|
||||||
"ssl_gg":{
|
|
||||||
"crtfile":"$[workdir]$/conf/www.bsppo.com.pem",
|
|
||||||
"keyfile":"$[workdir]$/conf/www.bsppo.com.key"
|
|
||||||
},
|
|
||||||
"indexes":[
|
"indexes":[
|
||||||
"index.html",
|
"index.html",
|
||||||
"index.ui"
|
"index.ui"
|
||||||
@ -28,6 +24,9 @@
|
|||||||
},{
|
},{
|
||||||
"leading": "/v1/embeddings",
|
"leading": "/v1/embeddings",
|
||||||
"registerfunction": "embeddings"
|
"registerfunction": "embeddings"
|
||||||
|
},{
|
||||||
|
"leading": "/docs",
|
||||||
|
"registerfunction": "docs"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"processors":[
|
"processors":[
|
||||||
|
Loading…
Reference in New Issue
Block a user