diff --git a/llmengine/base_chat_llm.py b/llmengine/base_chat_llm.py
index 1559b86..c4634e8 100644
--- a/llmengine/base_chat_llm.py
+++ b/llmengine/base_chat_llm.py
@@ -17,13 +17,14 @@ def get_llm_class(model_path):
     for k,klass in model_pathMap.items():
         if len(model_path.split(k)) > 1:
             return klass
+    print(f'{model_pathMap=}')
     return None
 
 class BaseChatLLM:
     def use_mps_if_prosible(self):
         if torch.backends.mps.is_available():
             device = torch.device("mps")
-            self.model = self.model.to(devoce)
+            self.model = self.model.to(device)
 
     def get_session_key(self):
         return self.model_id + ':messages'
diff --git a/llmengine/base_embedding.py b/llmengine/base_embedding.py
index 342945c..7a2c91f 100644
--- a/llmengine/base_embedding.py
+++ b/llmengine/base_embedding.py
@@ -1,9 +1,28 @@
+import torch
+
+model_pathMap = {
+}
+def llm_register(model_key, Klass):
+    global model_pathMap
+    model_pathMap[model_key] = Klass
+
+def get_llm_class(model_path):
+    for k,klass in model_pathMap.items():
+        if len(model_path.split(k)) > 1:
+            return klass
+    print(f'{model_pathMap=}')
+    return None
+
 class BaseEmbedding:
+    def use_mps_if_prosible(self):
+        if torch.backends.mps.is_available():
+            device = torch.device("mps")
+            self.model = self.model.to(device)
+
     def embedding(self, doc):
-        es = self.model.encode([doc])
-        return es[0]
+        es = self.model.encode([doc])[0]
+        return es.tolist()
 
     def similarity(self, qvector, docvectors):
         s = self.model.similarity([qvector], docvectors)
diff --git a/llmengine/base_reranker.py b/llmengine/base_reranker.py
index 6c71e7e..d5567e8 100644
--- a/llmengine/base_reranker.py
+++ b/llmengine/base_reranker.py
@@ -1,8 +1,12 @@
-
 import torch
 
 class BaseReranker:
+    def use_mps_if_prosible(self):
+        if torch.backends.mps.is_available():
+            device = torch.device("mps")
+            self.model = self.model.to(device)
+
     def process_input(self, pairs):
         inputs = self.tokenizer(
             pairs, padding=False, truncation='longest_first',
diff --git a/llmengine/client/llmclient b/llmengine/client/llmclient
index 8fdc96f..f591efb 100755
--- a/llmengine/client/llmclient
+++ b/llmengine/client/llmclient
@@ -37,7 +37,7 @@ async def main():
     }
     i = 0
     buffer = ''
-    reco = hc('POST', args.url, headers=headers, data=json.dumps(d), timeout=3600)
+    reco = hc('POST', args.url, headers=headers, data=json.dumps(d))
     async for chunk in liner(reco):
         chunk = chunk[6:]
         if chunk != '[DONE]':
diff --git a/llmengine/embedding.py b/llmengine/embedding.py
new file mode 100644
index 0000000..41f3ebf
--- /dev/null
+++ b/llmengine/embedding.py
@@ -0,0 +1,51 @@
+from traceback import format_exc
+import os
+import sys
+import argparse
+from llmengine.qwen3embedding import *
+from llmengine.base_embedding import get_llm_class
+
+from appPublic.registerfunction import RegisterFunction
+from appPublic.worker import awaitify
+from appPublic.log import debug, exception
+from ahserver.serverenv import ServerEnv
+from ahserver.globalEnv import stream_response
+from ahserver.webapp import webserver
+
+from aiohttp_session import get_session
+
+def init():
+    rf = RegisterFunction()
+    rf.register('embedding', embedding)
+
+async def embedding(request, params_kw, *params, **kw):
+    debug(f'{params_kw.doc=}')
+    se = ServerEnv()
+    engine = se.engine
+    f = awaitify(engine.embedding)
+    arr = await f(params_kw.doc)
+    debug(f'{arr=}, {type(arr)=}')
+    return arr
+
+def main():
+    parser = argparse.ArgumentParser(prog="Embedding")
+    parser.add_argument('-w', '--workdir')
+    parser.add_argument('-p', '--port')
+    parser.add_argument('model_path')
+    args = parser.parse_args()
+    Klass = get_llm_class(args.model_path)
+    if Klass is None:
+        e = Exception(f'{args.model_path} has no mapping to a model class')
+        exception(f'{e}, {format_exc()}')
+        raise e
+    se = ServerEnv()
+    se.engine = Klass(args.model_path)
+    se.engine.use_mps_if_prosible()
+    workdir = args.workdir or os.getcwd()
+    port = args.port
+    debug(f'{args=}')
+    webserver(init, workdir, port)
+
+if __name__ == '__main__':
+    main()
+
diff --git a/llmengine/qwen3.py b/llmengine/qwen3.py
index 7e7c24a..a6585d5 100644
--- a/llmengine/qwen3.py
+++ b/llmengine/qwen3.py
@@ -19,7 +19,7 @@ class Qwen3LLM(T2TChatLLM):
         )
         if torch.backends.mps.is_available():
             device = torch.device("mps")
-            self.model = self.model.to(devoce)
+            self.model = self.model.to(device)
         self.model_id = model_id
 
     def build_kwargs(self, inputs, streamer):
@@ -45,7 +45,9 @@
 llm_register("Qwen/Qwen3", Qwen3LLM)
 
 if __name__ == '__main__':
-    q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
+    import sys
+    model_path = sys.argv[1]
+    q3 = Qwen3LLM(model_path)
     session = {}
     while True:
         print('input prompt')
@@ -54,10 +56,13 @@
         if p == 'q':
             break;
         for d in q3.stream_generate(session, p):
+            print(d)
+            """
             if not d['done']:
                 print(d['text'], end='', flush=True)
             else:
                 x = {k:v for k,v in d.items() if k != 'text'}
                 print(f'\n{x}\n')
+            """
diff --git a/llmengine/qwen3embedding.py b/llmengine/qwen3embedding.py
index ecc978d..9daa664 100644
--- a/llmengine/qwen3embedding.py
+++ b/llmengine/qwen3embedding.py
@@ -2,7 +2,7 @@
 # Requires sentence-transformers>=2.7.0
 
 from sentence_transformers import SentenceTransformer
-from llmengine.base_embedding import BaseEmbedding
+from llmengine.base_embedding import BaseEmbedding, llm_register
 
 class Qwen3Embedding(BaseEmbedding):
     def __init__(self, model_id, max_length=8096):
@@ -17,3 +17,4 @@
         # )
         self.max_length = max_length
 
+llm_register('Qwen3-Embedding', Qwen3Embedding)
diff --git a/llmengine/server.py b/llmengine/server.py
index b18b600..568f5ba 100644
--- a/llmengine/server.py
+++ b/llmengine/server.py
@@ -2,14 +2,11 @@ from traceback import format_exc
 import os
 import sys
 import argparse
-from llmengine.base_chat_llm import get_llm_class
-from llmengine.gemma3_it import Gemma3LLM
-from llmengine.qwen3 import Qwen3LLM
-from llmengine.medgemma3_it import MedgemmaLLM
-from llmengine.devstral import DevstralLLM
+from llmengine.base_embedding import get_llm_class
+from llmengine.qwen3embedding import Qwen3Embedding
 
 from appPublic.registerfunction import RegisterFunction
-from appPublic.log import debug
+from appPublic.log import debug, exception
 from ahserver.serverenv import ServerEnv
 from ahserver.globalEnv import stream_response
 from ahserver.webapp import webserver
@@ -20,7 +17,7 @@ def init():
     rf = RegisterFunction()
-    rf.register('chat_completions', chat_completions)
+    rf.register('embedding', embedding)
 
-async def chat_completions(request, params_kw, *params, **kw):
+async def embedding(request, params_kw, *params, **kw):
     async def gor():
         se = ServerEnv()
-        engine = se.chat_engine
+        engine = se.engine
@@ -47,12 +44,12 @@ def main():
     args = parser.parse_args()
     Klass = get_llm_class(args.model_path)
     if Klass is None:
-        e = Exception(f'{model_path} has not mapping to a model class')
+        e = Exception(f'{args.model_path} has no mapping to a model class')
         exception(f'{e}, {format_exc()}')
         raise e
     se = ServerEnv()
-    se.chat_engine = Klass(args.model_path)
-    se.chat_engine.use_mps_if_prosible()
+    se.engine = Klass(args.model_path)
+    se.engine.use_mps_if_prosible()
     workdir = args.workdir or os.getcwd()
     port = args.port
     webserver(init, workdir, port)
diff --git a/test/qwen3.sh b/test/qwen3.sh
index 624fca5..65e210b 100755
--- a/test/qwen3.sh
+++ b/test/qwen3.sh
@@ -1,3 +1,3 @@
 #!/usr/bin/bash
-CUDA_VISIBLE_DEVICES=2,3,4,5 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
+~/models/tsfm.env/bin/python -m llmengine.server ~/models/Qwen/Qwen3-0.6B
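
A note on the registry this diff duplicates into base_embedding.py: get_llm_class picks the first registered class whose key occurs anywhere in the model path (len(model_path.split(k)) > 1 is just a substring test). The sketch below replays that lookup with the key the diff registers; the model paths are illustrative only.

    # Standalone replay of the lookup logic added in base_embedding.py.
    model_pathMap = {}

    def llm_register(model_key, Klass):
        model_pathMap[model_key] = Klass

    def get_llm_class(model_path):
        for k, klass in model_pathMap.items():
            if len(model_path.split(k)) > 1:  # true iff k occurs in model_path
                return klass
        return None

    class Qwen3Embedding:  # stand-in for the real SentenceTransformer wrapper
        pass

    llm_register('Qwen3-Embedding', Qwen3Embedding)

    assert get_llm_class('/share/models/Qwen/Qwen3-Embedding-0.6B') is Qwen3Embedding
    assert get_llm_class('/share/models/Qwen/Qwen3-0.6B') is None

One consequence worth noting: as the diff stands, the plain Qwen3-0.6B path in the updated test/qwen3.sh contains no registered key, so get_llm_class would return None and server.py would raise at startup.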
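
For an end-to-end check of the embedding service, a client along these lines should do. This is a minimal sketch that assumes the server listens on http://localhost:8000, that ahserver exposes the function registered as 'embedding' at the path /embedding, and that it parses a JSON body into params_kw; all three are unverified deployment details, and only the 'doc' field name comes from the handler in embedding.py.

    # Hypothetical smoke test for the new embedding endpoint; URL, port
    # and path are assumptions, only the 'doc' field is from the diff.
    import requests

    def embed(doc, url='http://localhost:8000/embedding'):
        resp = requests.post(url, json={'doc': doc})
        resp.raise_for_status()
        # BaseEmbedding.embedding() returns es.tolist(), a plain list of
        # floats, so the response body should decode to that list.
        return resp.json()

    if __name__ == '__main__':
        vec = embed('hello world')
        print(len(vec), vec[:8])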