bugfix
This commit is contained in:
parent
4b859e235b
commit
fa20715d63
@ -17,13 +17,14 @@ def get_llm_class(model_path):
|
|||||||
for k,klass in model_pathMap.items():
|
for k,klass in model_pathMap.items():
|
||||||
if len(model_path.split(k)) > 1:
|
if len(model_path.split(k)) > 1:
|
||||||
return klass
|
return klass
|
||||||
|
print(f'{model_pathMap=}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class BaseChatLLM:
|
class BaseChatLLM:
|
||||||
def use_mps_if_prosible(self):
|
def use_mps_if_prosible(self):
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
self.model = self.model.to(devoce)
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
def get_session_key(self):
|
def get_session_key(self):
|
||||||
return self.model_id + ':messages'
|
return self.model_id + ':messages'
|
||||||
|
@ -1,9 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
model_pathMap = {
|
||||||
|
}
|
||||||
|
def llm_register(model_key, Klass):
|
||||||
|
global model_pathMap
|
||||||
|
model_pathMap[model_key] = Klass
|
||||||
|
|
||||||
|
def get_llm_class(model_path):
|
||||||
|
for k,klass in model_pathMap.items():
|
||||||
|
if len(model_path.split(k)) > 1:
|
||||||
|
return klass
|
||||||
|
print(f'{model_pathMap=}')
|
||||||
|
return None
|
||||||
|
|
||||||
class BaseEmbedding:
|
class BaseEmbedding:
|
||||||
|
|
||||||
|
def use_mps_if_prosible(self):
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = torch.device("mps")
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
def embedding(self, doc):
|
def embedding(self, doc):
|
||||||
es = self.model.encode([doc])
|
es = self.model.encode([doc])[0]
|
||||||
return es[0]
|
return es.tolist()
|
||||||
|
|
||||||
def similarity(self, qvector, dcovectors):
|
def similarity(self, qvector, dcovectors):
|
||||||
s = self.model.similarity([qvector], docvectors)
|
s = self.model.similarity([qvector], docvectors)
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
classs BaseReranker:
|
classs BaseReranker:
|
||||||
|
|
||||||
|
def use_mps_if_prosible(self):
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = torch.device("mps")
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
def process_input(self, pairs):
|
def process_input(self, pairs):
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
pairs, padding=False, truncation='longest_first',
|
pairs, padding=False, truncation='longest_first',
|
||||||
|
@ -37,7 +37,7 @@ async def main():
|
|||||||
}
|
}
|
||||||
i = 0
|
i = 0
|
||||||
buffer = ''
|
buffer = ''
|
||||||
reco = hc('POST', args.url, headers=headers, data=json.dumps(d), timeout=3600)
|
reco = hc('POST', args.url, headers=headers, data=json.dumps(d))
|
||||||
async for chunk in liner(reco):
|
async for chunk in liner(reco):
|
||||||
chunk = chunk[6:]
|
chunk = chunk[6:]
|
||||||
if chunk != '[DONE]':
|
if chunk != '[DONE]':
|
||||||
|
51
llmengine/embedding.py
Normal file
51
llmengine/embedding.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
from traceback import format_exc
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from llmengine.qwen3embedding import *
|
||||||
|
from llmengine.base_embedding import get_llm_class
|
||||||
|
|
||||||
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from appPublic.log import debug, exception
|
||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.globalEnv import stream_response
|
||||||
|
from ahserver.webapp import webserver
|
||||||
|
|
||||||
|
from aiohttp_session import get_session
|
||||||
|
|
||||||
|
def init():
|
||||||
|
rf = RegisterFunction()
|
||||||
|
rf.register('embedding', embedding)
|
||||||
|
|
||||||
|
async def embedding(request, params_kw, *params, **kw):
|
||||||
|
debug(f'{params_kw.doc=}')
|
||||||
|
se = ServerEnv()
|
||||||
|
engine = se.engine
|
||||||
|
f = awaitify(engine.embedding)
|
||||||
|
arr = await f(params_kw.doc)
|
||||||
|
debug(f'{arr=}, type(arr)')
|
||||||
|
return arr
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(prog="Embedding")
|
||||||
|
parser.add_argument('-w', '--workdir')
|
||||||
|
parser.add_argument('-p', '--port')
|
||||||
|
parser.add_argument('model_path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
Klass = get_llm_class(args.model_path)
|
||||||
|
if Klass is None:
|
||||||
|
e = Exception(f'{args.model_path} has not mapping to a model class')
|
||||||
|
exception(f'{e}, {format_exc()}')
|
||||||
|
raise e
|
||||||
|
se = ServerEnv()
|
||||||
|
se.engine = Klass(args.model_path)
|
||||||
|
se.engine.use_mps_if_prosible()
|
||||||
|
workdir = args.workdir or os.getcwd()
|
||||||
|
port = args.port
|
||||||
|
debug(f'{args=}')
|
||||||
|
webserver(init, workdir, port)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
@ -19,7 +19,7 @@ class Qwen3LLM(T2TChatLLM):
|
|||||||
)
|
)
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
self.model = self.model.to(devoce)
|
self.model = self.model.to(device)
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
def build_kwargs(self, inputs, streamer):
|
||||||
@ -45,7 +45,9 @@ class Qwen3LLM(T2TChatLLM):
|
|||||||
llm_register("Qwen/Qwen3", Qwen3LLM)
|
llm_register("Qwen/Qwen3", Qwen3LLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
q3 = Qwen3LLM('/share/models/Qwen/Qwen3-32B')
|
import sys
|
||||||
|
model_path = sys.argv[1]
|
||||||
|
q3 = Qwen3LLM(model_path)
|
||||||
session = {}
|
session = {}
|
||||||
while True:
|
while True:
|
||||||
print('input prompt')
|
print('input prompt')
|
||||||
@ -54,10 +56,13 @@ if __name__ == '__main__':
|
|||||||
if p == 'q':
|
if p == 'q':
|
||||||
break;
|
break;
|
||||||
for d in q3.stream_generate(session, p):
|
for d in q3.stream_generate(session, p):
|
||||||
|
print(d)
|
||||||
|
"""
|
||||||
if not d['done']:
|
if not d['done']:
|
||||||
print(d['text'], end='', flush=True)
|
print(d['text'], end='', flush=True)
|
||||||
else:
|
else:
|
||||||
x = {k:v for k,v in d.items() if k != 'text'}
|
x = {k:v for k,v in d.items() if k != 'text'}
|
||||||
print(f'\n{x}\n')
|
print(f'\n{x}\n')
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# Requires sentence-transformers>=2.7.0
|
# Requires sentence-transformers>=2.7.0
|
||||||
|
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from llmengine.base_embedding import BaseEmbedding
|
from llmengine.base_embedding import BaseEmbedding, llm_register
|
||||||
|
|
||||||
class Qwen3Embedding(BaseEmbedding):
|
class Qwen3Embedding(BaseEmbedding):
|
||||||
def __init__(self, model_id, max_length=8096):
|
def __init__(self, model_id, max_length=8096):
|
||||||
@ -17,3 +17,4 @@ class Qwen3Embedding(BaseEmbedding):
|
|||||||
# )
|
# )
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|
||||||
|
llm_register('Qwen3-Embedding', Qwen3Embedding)
|
||||||
|
@ -2,14 +2,11 @@ from traceback import format_exc
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
from llmengine.base_chat_llm import get_llm_class
|
from llmengine.base_embedding import get_llm_class
|
||||||
from llmengine.gemma3_it import Gemma3LLM
|
from llmengine.qwen3embedding import Qwen3Embedding
|
||||||
from llmengine.qwen3 import Qwen3LLM
|
|
||||||
from llmengine.medgemma3_it import MedgemmaLLM
|
|
||||||
from llmengine.devstral import DevstralLLM
|
|
||||||
|
|
||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
from appPublic.log import debug
|
from appPublic.log import debug, exception
|
||||||
from ahserver.serverenv import ServerEnv
|
from ahserver.serverenv import ServerEnv
|
||||||
from ahserver.globalEnv import stream_response
|
from ahserver.globalEnv import stream_response
|
||||||
from ahserver.webapp import webserver
|
from ahserver.webapp import webserver
|
||||||
@ -20,7 +17,7 @@ def init():
|
|||||||
rf = RegisterFunction()
|
rf = RegisterFunction()
|
||||||
rf.register('chat_completions', chat_completions)
|
rf.register('chat_completions', chat_completions)
|
||||||
|
|
||||||
async def chat_completions(request, params_kw, *params, **kw):
|
async def embedding(request, params_kw, *params, **kw):
|
||||||
async def gor():
|
async def gor():
|
||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
engine = se.chat_engine
|
engine = se.chat_engine
|
||||||
@ -47,12 +44,12 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
Klass = get_llm_class(args.model_path)
|
Klass = get_llm_class(args.model_path)
|
||||||
if Klass is None:
|
if Klass is None:
|
||||||
e = Exception(f'{model_path} has not mapping to a model class')
|
e = Exception(f'{args.model_path} has not mapping to a model class')
|
||||||
exception(f'{e}, {format_exc()}')
|
exception(f'{e}, {format_exc()}')
|
||||||
raise e
|
raise e
|
||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
se.chat_engine = Klass(args.model_path)
|
se.engine = Klass(args.model_path)
|
||||||
se.chat_engine.use_mps_if_prosible()
|
se.engine.use_mps_if_prosible()
|
||||||
workdir = args.workdir or os.getcwd()
|
workdir = args.workdir or os.getcwd()
|
||||||
port = args.port
|
port = args.port
|
||||||
webserver(init, workdir, port)
|
webserver(init, workdir, port)
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
#!/usr/bin/bash
|
#!/usr/bin/bash
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=2,3,4,5 /share/vllm-0.8.5/bin/python -m llmengine.qwen3
|
~/models/tsfm.env/bin/python -m llmengine.server ~/models/Qwen/Qwen3-0.6B
|
||||||
|
Loading…
Reference in New Issue
Block a user