yumoqing 2025-06-20 18:56:30 +08:00
parent 4b859e235b
commit fa20715d63
9 changed files with 97 additions and 19 deletions

llmengine/base_chat_llm.py
View File

@@ -17,13 +17,14 @@ def get_llm_class(model_path):
     for k,klass in model_pathMap.items():
         if len(model_path.split(k)) > 1:
             return klass
+    print(f'{model_pathMap=}')
     return None
 
 class BaseChatLLM:
     def use_mps_if_prosible(self):
         if torch.backends.mps.is_available():
             device = torch.device("mps")
-            self.model = self.model.to(devoce)
+            self.model = self.model.to(device)
     def get_session_key(self):
         return self.model_id + ':messages'

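The hunk above fixes a NameError in the MPS helper (devoce -> device). Below is a minimal, hedged sketch of the device-selection pattern it implements; the CUDA and CPU branches are illustrative additions, not part of this commit, which only checks MPS.

import torch

def pick_device() -> torch.device:
    # prefer the Apple-silicon GPU when available, as use_mps_if_prosible() does
    if torch.backends.mps.is_available():
        return torch.device('mps')
    # the CUDA branch is an assumption for completeness; not in the commit
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

# usage: model = model.to(pick_device())
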
llmengine/base_embedding.py
View File

@@ -1,9 +1,28 @@
+import torch
+
+model_pathMap = {
+}
+
+def llm_register(model_key, Klass):
+    global model_pathMap
+    model_pathMap[model_key] = Klass
+
+def get_llm_class(model_path):
+    for k,klass in model_pathMap.items():
+        if len(model_path.split(k)) > 1:
+            return klass
+    print(f'{model_pathMap=}')
+    return None
+
 class BaseEmbedding:
+    def use_mps_if_prosible(self):
+        if torch.backends.mps.is_available():
+            device = torch.device("mps")
+            self.model = self.model.to(device)
     def embedding(self, doc):
-        es = self.model.encode([doc])
-        return es[0]
+        es = self.model.encode([doc])[0]
+        return es.tolist()
     def similarity(self, qvector, dcovectors):
         s = self.model.similarity([qvector], docvectors)

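The new module-level registry above resolves a model class by substring match: get_llm_class() splits the model path on each registered key, and more than one resulting piece means the key occurs in the path. A standalone sketch of the mechanism (DemoEmbedding and the paths are hypothetical, modeled on paths elsewhere in this commit):

model_pathMap = {}

def llm_register(model_key, Klass):
    model_pathMap[model_key] = Klass

def get_llm_class(model_path):
    for k, klass in model_pathMap.items():
        # splitting on k yields more than one piece iff k occurs in model_path
        if len(model_path.split(k)) > 1:
            return klass
    return None

class DemoEmbedding:
    pass

llm_register('Qwen3-Embedding', DemoEmbedding)
assert get_llm_class('/share/models/Qwen/Qwen3-Embedding-0.6B') is DemoEmbedding
assert get_llm_class('/share/models/other') is None
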
llmengine/base_reranker.py
View File

@@ -1,8 +1,12 @@
 import torch
 
 classs BaseReranker:
+    def use_mps_if_prosible(self):
+        if torch.backends.mps.is_available():
+            device = torch.device("mps")
+            self.model = self.model.to(device)
     def process_input(self, pairs):
         inputs = self.tokenizer(
             pairs, padding=False, truncation='longest_first',

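The tokenizer call in process_input() relies on truncation='longest_first', which trims the longer member of each query/document pair until the encoding fits max_length. A standalone, hedged illustration with a stand-in tokenizer (bert-base-uncased is not used by this repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('bert-base-uncased')  # stand-in model
enc = tok('a short query',
          'a much longer candidate document ' * 50,
          truncation='longest_first', max_length=32)
print(len(enc['input_ids']))  # 32: only the longer side was trimmed
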
View File

@@ -37,7 +37,7 @@ async def main():
         }
         i = 0
         buffer = ''
-        reco = hc('POST', args.url, headers=headers, data=json.dumps(d), timeout=3600)
+        reco = hc('POST', args.url, headers=headers, data=json.dumps(d))
         async for chunk in liner(reco):
             chunk = chunk[6:]
             if chunk != '[DONE]':

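The client loop above consumes a server-sent-events stream: each payload line starts with the 6-character prefix "data: " (hence chunk[6:]), and the stream ends with a literal [DONE] marker. A minimal sketch of that framing (hc and liner are this repo's HTTP helpers; parse_sse_line is a hypothetical name):

import json

def parse_sse_line(line: str):
    # SSE data lines look like: "data: {...json...}" or "data: [DONE]"
    if not line.startswith('data: '):
        return None          # ignore comments and keep-alives
    payload = line[6:]       # the same slice as chunk[6:] above
    if payload == '[DONE]':
        return None          # end-of-stream marker
    return json.loads(payload)
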
llmengine/embedding.py Normal file (51 lines)
View File

@@ -0,0 +1,51 @@
+from traceback import format_exc
+import os
+import sys
+import argparse
+from llmengine.qwen3embedding import *
+from llmengine.base_embedding import get_llm_class
+from appPublic.registerfunction import RegisterFunction
+from appPublic.worker import awaitify
+from appPublic.log import debug, exception
+from ahserver.serverenv import ServerEnv
+from ahserver.globalEnv import stream_response
+from ahserver.webapp import webserver
+from aiohttp_session import get_session
+
+def init():
+    rf = RegisterFunction()
+    rf.register('embedding', embedding)
+
+async def embedding(request, params_kw, *params, **kw):
+    debug(f'{params_kw.doc=}')
+    se = ServerEnv()
+    engine = se.engine
+    f = awaitify(engine.embedding)
+    arr = await f(params_kw.doc)
+    debug(f'{arr=}, type(arr)')
+    return arr
+
+def main():
+    parser = argparse.ArgumentParser(prog="Embedding")
+    parser.add_argument('-w', '--workdir')
+    parser.add_argument('-p', '--port')
+    parser.add_argument('model_path')
+    args = parser.parse_args()
+    Klass = get_llm_class(args.model_path)
+    if Klass is None:
+        e = Exception(f'{args.model_path} has not mapping to a model class')
+        exception(f'{e}, {format_exc()}')
+        raise e
+    se = ServerEnv()
+    se.engine = Klass(args.model_path)
+    se.engine.use_mps_if_prosible()
+    workdir = args.workdir or os.getcwd()
+    port = args.port
+    debug(f'{args=}')
+    webserver(init, workdir, port)
+
+if __name__ == '__main__':
+    main()

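A hedged sketch of exercising the new service: main() resolves the model class from model_path, stores the engine on ServerEnv, and init() registers the handler under the name 'embedding'. The URL below is a placeholder, since the actual route depends on how ahserver maps registered functions, and the port is whatever -p was given:

import requests  # not a dependency of this repo; used purely for illustration

resp = requests.post(
    'http://localhost:8080/api/embedding',  # hypothetical host/port/path
    data={'doc': 'hello world'},            # the handler reads params_kw.doc
)
print(resp.json())  # embedding() returns the vector as a plain list of floats
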
llmengine/qwen3.py
View File

@@ -19,7 +19,7 @@ class Qwen3LLM(T2TChatLLM):
         )
         if torch.backends.mps.is_available():
             device = torch.device("mps")
-            self.model = self.model.to(devoce)
+            self.model = self.model.to(device)
         self.model_id = model_id
 
     def build_kwargs(self, inputs, streamer):
@@ -45,7 +45,9 @@ class Qwen3LLM(T2TChatLLM):
 llm_register("Qwen/Qwen3", Qwen3LLM)
 
 if __name__ == '__main__':
-    q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
+    import sys
+    model_path = sys.argv[1]
+    q3 = Qwen3LLM(model_path)
     session = {}
     while True:
         print('input prompt')
@@ -54,10 +56,13 @@ if __name__ == '__main__':
         if p == 'q':
             break;
         for d in q3.stream_generate(session, p):
+            print(d)
+            """
             if not d['done']:
                 print(d['text'], end='', flush=True)
             else:
                 x = {k:v for k,v in d.items() if k != 'text'}
                 print(f'\n{x}\n')
+            """

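The __main__ block now prints each raw dict from stream_generate() and comments out the pretty-printer. Judging from the commented block, each yielded item carries a 'done' flag and a 'text' field, with extra summary keys on the final item; a sketch of consuming it accordingly (q3, session, p as in the diff):

for d in q3.stream_generate(session, p):
    if not d['done']:
        print(d['text'], end='', flush=True)   # incremental text chunks
    else:
        meta = {k: v for k, v in d.items() if k != 'text'}
        print(f'\n{meta}\n')                   # final stats on the last item
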
llmengine/qwen3embedding.py
View File

@@ -2,7 +2,7 @@
 # Requires sentence-transformers>=2.7.0
 from sentence_transformers import SentenceTransformer
-from llmengine.base_embedding import BaseEmbedding
+from llmengine.base_embedding import BaseEmbedding, llm_register
 
 class Qwen3Embedding(BaseEmbedding):
     def __init__(self, model_id, max_length=8096):
@@ -17,3 +17,4 @@ class Qwen3Embedding(BaseEmbedding):
         # )
         self.max_length = max_length
+llm_register('Qwen3-Embedding', Qwen3Embedding)

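With llm_register('Qwen3-Embedding', Qwen3Embedding) in place, any model path containing that key now resolves to this class. Assuming self.model is the SentenceTransformer built in __init__, the reworked BaseEmbedding.embedding() (earlier in this commit) returns a JSON-serializable list; a hedged sketch, with an illustrative model id:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B')  # illustrative id
vec = model.encode(['hello world'])[0]  # encode() returns one row per input
print(type(vec.tolist()), len(vec.tolist()))  # a plain list of floats
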
llmengine/server.py
View File

@@ -2,14 +2,11 @@ from traceback import format_exc
 import os
 import sys
 import argparse
-from llmengine.base_chat_llm import get_llm_class
-from llmengine.gemma3_it import Gemma3LLM
-from llmengine.qwen3 import Qwen3LLM
-from llmengine.medgemma3_it import MedgemmaLLM
-from llmengine.devstral import DevstralLLM
+from llmengine.base_embedding import get_llm_class
+from llmengine.qwen3embedding import Qwen3Embedding
 from appPublic.registerfunction import RegisterFunction
-from appPublic.log import debug
+from appPublic.log import debug, exception
 from ahserver.serverenv import ServerEnv
 from ahserver.globalEnv import stream_response
 from ahserver.webapp import webserver
@@ -20,7 +17,7 @@ def init():
     rf = RegisterFunction()
     rf.register('chat_completions', chat_completions)
 
-async def chat_completions(request, params_kw, *params, **kw):
+async def embedding(request, params_kw, *params, **kw):
     async def gor():
         se = ServerEnv()
         engine = se.chat_engine
@@ -47,12 +44,12 @@ def main():
     args = parser.parse_args()
     Klass = get_llm_class(args.model_path)
     if Klass is None:
-        e = Exception(f'{model_path} has not mapping to a model class')
+        e = Exception(f'{args.model_path} has not mapping to a model class')
         exception(f'{e}, {format_exc()}')
         raise e
     se = ServerEnv()
-    se.chat_engine = Klass(args.model_path)
-    se.chat_engine.use_mps_if_prosible()
+    se.engine = Klass(args.model_path)
+    se.engine.use_mps_if_prosible()
     workdir = args.workdir or os.getcwd()
     port = args.port
     webserver(init, workdir, port)

View File

@@ -1,3 +1,3 @@
 #!/usr/bin/bash
-CUDA_VISIBLE_DEVICES=2,3,4,5 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
+~/models/tsfm.env/bin/python -m llmengine.server ~/models/Qwen/Qwen3-0.6B