diff --git a/llmengine/base_embedding.py b/llmengine/base_embedding.py
index c92939c..68c6ba6 100644
--- a/llmengine/base_embedding.py
+++ b/llmengine/base_embedding.py
@@ -33,7 +33,7 @@ class BaseEmbedding:
 		return {
 			"object": "list",
 			"data": data,
-			"model": self.model_id.split('/')[-1],
+			"model": self.model_name,
 			"usage": {
 				"prompt_tokens": 0,
 				"total_tokens": 0
diff --git a/llmengine/base_reranker.py b/llmengine/base_reranker.py
index d5567e8..4f3ea9c 100644
--- a/llmengine/base_reranker.py
+++ b/llmengine/base_reranker.py
@@ -35,10 +35,32 @@ class BaseReranker:
 		scores = batch_scores[:, 1].exp().tolist()
 		return scores
-	def rerank(self, query, docs, sys_prompt="", task=""):
+	def rerank(self, query, docs, top_n=5, sys_prompt="", task=""):
 		sys_str = self.build_sys_prompt(sys_prompt)
 		ass_str = self.build_assistant_prompt()
 		pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str
 					for doc in docs ]
 		inputs = self.process_inputs(pairs)
 		scores = self.compute_logits(inputs)
-
+		data = []
+		for i, s in enumerate(scores):
+			d = {
+				'index':i,
+				'relevance_score': s
+			}
+			data.append(d)
+		data = sorted(data,
+				key=lambda x: x["relevance_score"],
+				reverse=True)
+		if len(data) > top_n:
+			data = data[:top_n]
+		ret = {
+			"data": data,
+			"object": "rerank.result",
+			"model": self.model_name,
+			"usage": {
+				"prompt_tokens": 0,
+				"total_tokens": 0
+			}
+		}
+		return ret
+
diff --git a/llmengine/embedding.py b/llmengine/embedding.py
index 6d42237..a3cf731 100644
--- a/llmengine/embedding.py
+++ b/llmengine/embedding.py
@@ -14,12 +14,7 @@ from ahserver.webapp import webserver
 from aiohttp_session import get_session
 
-def init():
-	rf = RegisterFunction()
-	rf.register('embeddings', embeddings)
-
-async def docs(request, params_kw, *params, **kw):
-	txt = """embeddings api:
+helptext = """embeddings api:
 path: /v1/embeddings
 headers: {
 	"Content-Type": "application/json"
 }
@@ -33,6 +28,34 @@ data: {
 		"this is second setence"
 	]
 }
+
+response is a json
+{
+	"object": "list",
+	"data": [
+		{
+			"object": "embedding",
+			"index": 0,
+			"embedding": [0.0123, -0.0456, ...]
+		}
+	],
+	"model": "text-embedding-3-small",
+	"usage": {
+		"prompt_tokens": 0,
+		"total_tokens": 0
+	}
+}
+"""
+
+
+def init():
+	rf = RegisterFunction()
+	rf.register('embeddings', embeddings)
+	rf.register('docs', docs)
+
+async def docs(request, params_kw, *params, **kw):
+	return helptext
+
 async def embeddings(request, params_kw, *params, **kw):
 	debug(f'{params_kw.input=}')
 	se = ServerEnv()
diff --git a/llmengine/qwen3_reranker.py b/llmengine/qwen3_reranker.py
index 1232536..1fb9231 100644
--- a/llmengine/qwen3_reranker.py
+++ b/llmengine/qwen3_reranker.py
@@ -1,10 +1,15 @@
 import torch
 from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
-from llmengine.base_reranker import BaseReranker
+from llmengine.base_reranker import BaseReranker, llm_register
 
 class Qwen3Reranker(BaseReranker):
 	def __init__(self, model_id, max_length=8096):
-		self.odel_id = model_id
+		if 'Qwen3-Reranker' not in model_id:
+			e = Exception(f'{model_id} is not a Qwen3-Reranker')
+			raise e
 		self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
-		self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
+		self.model = AutoModelForCausalLM.from_pretrained(model_id).eval()
+		self.model_id = model_id
+		self.model_name = model_id.split('/')[-1]
 
+llm_register('Qwen3-Reranker', Qwen3Reranker)
diff --git a/llmengine/qwen3embedding.py b/llmengine/qwen3embedding.py
index 9daa664..4b82b61 100644
--- a/llmengine/qwen3embedding.py
+++ b/llmengine/qwen3embedding.py
@@ -16,5 +16,7 @@
 		# 	tokenizer_kwargs={"padding_side": "left"},
 		# )
 		self.max_length = max_length
+		self.model_id = model_id
+		self.model_name = model_id.split('/')[-1]
 
 llm_register('Qwen3-Embedding', Qwen3Embedding)
diff --git a/llmengine/rerank.py b/llmengine/rerank.py
new file mode 100644
index 0000000..5ccb4df
--- /dev/null
+++ b/llmengine/rerank.py
@@ -0,0 +1,98 @@
+from traceback import format_exc
+import os
+import sys
+import argparse
+from llmengine.qwen3_reranker import *
+from llmengine.base_reranker import get_llm_class
+
+from appPublic.registerfunction import RegisterFunction
+from appPublic.worker import awaitify
+from appPublic.log import debug, exception
+from ahserver.serverenv import ServerEnv
+from ahserver.webapp import webserver
+
+helptext = """rerank api:
+path: /v1/rerank
+headers: {
+	"Content-Type": "application/json"
+}
+data:
+{
+	"model": "rerank-001",
+	"query": "什么是量子计算?",
+	"documents": [
+		"量子计算是一种使用量子比特进行计算的方式。",
+		"古典计算机使用的是二进制位。",
+		"天气预报依赖于统计模型。",
+		"量子计算与物理学密切相关。"
+	],
+	"top_n": 2
+}
+
+response is a json
+{
+	"data": [
+		{
+			"index": 0,
+			"relevance_score": 0.95
+		},
+		{
+			"index": 3,
+			"relevance_score": 0.89
+		}
+	],
+	"object": "rerank.result",
+	"model": "rerank-001",
+	"usage": {
+		"prompt_tokens": 0,
+		"total_tokens": 0
+	}
+}
+"""
+
+
+def init():
+	rf = RegisterFunction()
+	rf.register('rerank', rerank)
+	rf.register('docs', docs)
+
+async def docs(request, params_kw, *params, **kw):
+	return helptext
+
+async def rerank(request, params_kw, *params, **kw):
+	debug(f'{params_kw.query=}')
+	se = ServerEnv()
+	engine = se.engine
+	f = awaitify(engine.rerank)
+	query = params_kw.query
+	if query is None:
+		e = Exception(f'query is None')
+		raise e
+	docs = params_kw.documents or []
+	top_n = params_kw.top_n or 5
+	arr = await f(query, docs, top_n=top_n)
+	debug(f'{arr=}, {type(arr)=}')
+	return arr
+
+def main():
+	parser = argparse.ArgumentParser(prog="Rerank")
+	parser.add_argument('-w', '--workdir')
+	parser.add_argument('-p', '--port')
+	parser.add_argument('model_path')
+	args = parser.parse_args()
+	Klass = get_llm_class(args.model_path)
+	if Klass is None:
+		e = Exception(f'{args.model_path} has not mapping to a model class')
+		exception(f'{e}, {format_exc()}')
+		raise e
+	se = ServerEnv()
+	se.engine = Klass(args.model_path)
+	se.engine.use_mps_if_prosible()
+	workdir = args.workdir or os.getcwd()
+	port = args.port
+	debug(f'{args=}')
+	webserver(init,
	workdir, port)
+if __name__ == '__main__':
+	main()
+
diff --git a/test/chat/Qwen3-0.6B b/test/chat/Qwen3-0.6B
index 9e9d27d..af7b5be 100755
--- a/test/chat/Qwen3-0.6B
+++ b/test/chat/Qwen3-0.6B
@@ -1,3 +1,3 @@
 #!/bin/bash
-~/models/tsfm.env/bin/python -m llmengine.server -w ~/models/tsfm -p 9999 ~/models/Qwen/Qwen3-0.6B
+~/models/tsfm.env/bin/python -m llmengine.server -p 9999 ~/models/Qwen/Qwen3-0.6B
 
diff --git a/test/embeddings/Qwen3-Embedding-0.6B b/test/embeddings/Qwen3-Embedding-0.6B
index fa4ad7e..2799341 100755
--- a/test/embeddings/Qwen3-Embedding-0.6B
+++ b/test/embeddings/Qwen3-Embedding-0.6B
@@ -1,3 +1,3 @@
 #!/bin/bash
-~/models/tsfm.env/bin/python -m llmengine.embedding -w ~/models/embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B
+~/models/tsfm.env/bin/python -m llmengine.embedding -p 9998 ~/models/Qwen/Qwen3-Embedding-0.6B
 
diff --git a/test/embeddings/conf/config.json b/test/embeddings/conf/config.json
index 33a465b..3c644f2 100644
--- a/test/embeddings/conf/config.json
+++ b/test/embeddings/conf/config.json
@@ -13,10 +13,6 @@
 	"host":"0.0.0.0",
 	"port":9995,
 	"coding":"utf-8",
-	"ssl_gg":{
-		"crtfile":"$[workdir]$/conf/www.bsppo.com.pem",
-		"keyfile":"$[workdir]$/conf/www.bsppo.com.key"
-	},
 	"indexes":[
 		"index.html",
 		"index.ui"
@@ -28,6 +24,9 @@
 	},{
 		"leading": "/v1/embeddings",
 		"registerfunction": "embeddings"
+	},{
+		"leading": "/docs",
+		"registerfunction": "docs"
 	}
 ],
 "processors":[