Merge branch 'main' of https://git.kaiyuancloud.cn/yumoqing/llmengine
This commit is contained in:
commit
002624288e
@ -96,6 +96,7 @@ class BaseChatLLM:
|
|||||||
return generate_kwargs
|
return generate_kwargs
|
||||||
|
|
||||||
def _messages2inputs(self, messages):
|
def _messages2inputs(self, messages):
|
||||||
|
debug(f'{messages=}')
|
||||||
return self.processor.apply_chat_template(
|
return self.processor.apply_chat_template(
|
||||||
messages, add_generation_prompt=True,
|
messages, add_generation_prompt=True,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
|
@ -1,18 +1,15 @@
|
|||||||
import torch
|
import torch
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import logging
|
from appPublic.log import debug
|
||||||
import os
|
import os
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
model_pathMap = {}
|
model_pathMap = {}
|
||||||
|
|
||||||
|
|
||||||
def llm_register(model_key: str, Klass):
|
def llm_register(model_key: str, Klass):
|
||||||
"""Register a triplet extractor class for a given model key."""
|
"""Register a triplet extractor class for a given model key."""
|
||||||
global model_pathMap
|
global model_pathMap
|
||||||
model_pathMap[model_key] = Klass
|
model_pathMap[model_key] = Klass
|
||||||
logger.debug(f"Registered {Klass.__name__} for model_key: {model_key}")
|
debug(f"Registered {Klass.__name__} for model_key: {model_key}")
|
||||||
|
|
||||||
|
|
||||||
def get_llm_class(model_path: str):
|
def get_llm_class(model_path: str):
|
||||||
@ -20,7 +17,7 @@ def get_llm_class(model_path: str):
|
|||||||
for k, klass in model_pathMap.items():
|
for k, klass in model_pathMap.items():
|
||||||
if k in model_path:
|
if k in model_path:
|
||||||
return klass
|
return klass
|
||||||
logger.debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
|
debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -31,19 +28,19 @@ class BaseTripleExtractor(ABC):
|
|||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.model_name = os.path.basename(model_path)
|
self.model_name = os.path.basename(model_path)
|
||||||
self.model = None
|
self.model = None
|
||||||
logger.debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")
|
debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")
|
||||||
|
|
||||||
def use_mps_if_possible(self):
|
def use_mps_if_possible(self):
|
||||||
"""Select device (MPS, CUDA, or CPU)."""
|
"""Select device (MPS, CUDA, or CPU)."""
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
logger.debug("Using MPS device")
|
debug("Using MPS device")
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
logger.debug("Using CUDA device")
|
debug("Using CUDA device")
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
logger.debug("Using CPU device")
|
debug("Using CPU device")
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
return device
|
return device
|
||||||
@ -51,4 +48,4 @@ class BaseTripleExtractor(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def extract_triplets(self, text: str) -> list:
|
async def extract_triplets(self, text: str) -> list:
|
||||||
"""Extract triplets from text."""
|
"""Extract triplets from text."""
|
||||||
pass
|
pass
|
||||||
|
@ -4,10 +4,10 @@
|
|||||||
from appPublic.worker import awaitify
|
from appPublic.worker import awaitify
|
||||||
from appPublic.log import debug
|
from appPublic.log import debug
|
||||||
from ahserver.serverenv import get_serverenv
|
from ahserver.serverenv import get_serverenv
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
from llmengine.base_chat_llm import BaseChatLLM, llm_register
|
from llmengine.base_chat_llm import BaseChatLLM, llm_register
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
class Qwen3LLM(BaseChatLLM):
|
class Qwen3LLM(BaseChatLLM):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
@ -17,9 +17,6 @@ class Qwen3LLM(BaseChatLLM):
|
|||||||
torch_dtype="auto",
|
torch_dtype="auto",
|
||||||
device_map="auto"
|
device_map="auto"
|
||||||
)
|
)
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
device = torch.device("mps")
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
def build_kwargs(self, inputs, streamer):
|
||||||
@ -33,7 +30,7 @@ class Qwen3LLM(BaseChatLLM):
|
|||||||
return generate_kwargs
|
return generate_kwargs
|
||||||
|
|
||||||
def _messages2inputs(self, messages):
|
def _messages2inputs(self, messages):
|
||||||
debug(f'{messages=}')
|
debug(f'-----------{messages=}-----------')
|
||||||
text = self.tokenizer.apply_chat_template(
|
text = self.tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
@ -43,26 +40,3 @@ class Qwen3LLM(BaseChatLLM):
|
|||||||
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
||||||
|
|
||||||
llm_register("Qwen/Qwen3", Qwen3LLM)
|
llm_register("Qwen/Qwen3", Qwen3LLM)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import sys
|
|
||||||
model_path = sys.argv[1]
|
|
||||||
q3 = Qwen3LLM(model_path)
|
|
||||||
session = {}
|
|
||||||
while True:
|
|
||||||
print('input prompt')
|
|
||||||
p = input()
|
|
||||||
if p:
|
|
||||||
if p == 'q':
|
|
||||||
break;
|
|
||||||
for d in q3.stream_generate(session, p):
|
|
||||||
print(d)
|
|
||||||
"""
|
|
||||||
if not d['done']:
|
|
||||||
print(d['text'], end='', flush=True)
|
|
||||||
else:
|
|
||||||
x = {k:v for k,v in d.items() if k != 'text'}
|
|
||||||
print(f'\n{x}\n')
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,30 +7,25 @@ from typing import List
|
|||||||
from base_triple import get_llm_class
|
from base_triple import get_llm_class
|
||||||
from mrebeltriple import MRebelTripleExtractor
|
from mrebeltriple import MRebelTripleExtractor
|
||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
from appPublic.log import debug, exception
|
from appPublic.log import debug, exception, error, info
|
||||||
|
from appPublic.jsonConfig import getConfig
|
||||||
from ahserver.serverenv import ServerEnv
|
from ahserver.serverenv import ServerEnv
|
||||||
from ahserver.globalEnv import stream_response
|
from ahserver.globalEnv import stream_response
|
||||||
from ahserver.webapp import webserver
|
from ahserver.webapp import webserver
|
||||||
import aiohttp.web
|
import aiohttp.web
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = logging.getLogger('llmengine_triple')
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
log_file = '/share/wangmeihua/rag/logs/llmengine_triple.log'
|
|
||||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
|
||||||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
for handler in (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()):
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
# 加载配置文件
|
# 加载配置文件
|
||||||
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
|
||||||
try:
|
def load_milvus_config():
|
||||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
config = getConfig()
|
||||||
config = yaml.safe_load(f)
|
milvus_config = config.mivus_config_path
|
||||||
except Exception as e:
|
try:
|
||||||
logger.error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
raise RuntimeError(f"Failed to load config: {str(e)}")
|
milvus_config = yaml.safe_load(f)
|
||||||
|
except Exception as e:
|
||||||
|
error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
|
||||||
|
raise RuntimeError(f"Failed to load config: {str(e)}")
|
||||||
|
|
||||||
helptext = """mREBEL Triplets API:
|
helptext = """mREBEL Triplets API:
|
||||||
|
|
||||||
@ -79,7 +74,7 @@ async def triples(request, params_kw, *params, **kw):
|
|||||||
params_kw = data
|
params_kw = data
|
||||||
debug(f'Parsed JSON data: {params_kw}')
|
debug(f'Parsed JSON data: {params_kw}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to parse JSON: {str(e)}")
|
error(f"Failed to parse JSON: {str(e)}")
|
||||||
raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}")
|
raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}")
|
||||||
|
|
||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
@ -100,8 +95,8 @@ async def triples(request, params_kw, *params, **kw):
|
|||||||
"data": triplets
|
"data": triplets
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in triples endpoint: {str(e)}")
|
error(f"Error in triples endpoint: {str(e)}")
|
||||||
logger.debug(f"Traceback: {format_exc()}")
|
debug(f"Traceback: {format_exc()}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -122,13 +117,11 @@ def main():
|
|||||||
se.engine = Klass(args.model_path)
|
se.engine = Klass(args.model_path)
|
||||||
workdir = args.workdir or os.getcwd()
|
workdir = args.workdir or os.getcwd()
|
||||||
port = args.port
|
port = args.port
|
||||||
debug(f'{args=}')
|
|
||||||
logger.info(f"Starting mREBEL Triplet Service on port {port}, model: {args.model_path}")
|
|
||||||
webserver(init, workdir, port)
|
webserver(init, workdir, port)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start server: {str(e)}")
|
error(f"Failed to start server: {str(e)}")
|
||||||
logger.debug(f"Traceback: {format_exc()}")
|
debug(f"Traceback: {format_exc()}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
14
test/connection/connection.service
Normal file
14
test/connection/connection.service
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[Unit]
|
||||||
|
Wants=systemd-networkd.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=forking
|
||||||
|
WorkingDirectory=/share/run/connection
|
||||||
|
ExecStart=/share/run/connection/start.sh
|
||||||
|
ExecStop=/share/run/connection/stop.sh
|
||||||
|
StandardOutput=append:/var/log/connection/connection.log
|
||||||
|
StandardError=append:/var/log/connection/connection.log
|
||||||
|
SyslogIdentifier=/share/run/connection
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
14
test/entities/entities.service
Normal file
14
test/entities/entities.service
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[Unit]
|
||||||
|
Wants=systemd-networkd.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
WorkingDirectory=/share/run/entities
|
||||||
|
ExecStart=/share/run/entities/start.sh
|
||||||
|
ExecStop=/share/run/entities/stop.sh
|
||||||
|
StandardOutput=append:/var/log/entities/entities.log
|
||||||
|
StandardError=append:/var/log/entities/entities.log
|
||||||
|
SyslogIdentifier=entities
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
|
14
test/triples/triples.service
Normal file
14
test/triples/triples.service
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[Unit]
|
||||||
|
Wants=systemd-networkd.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=forking
|
||||||
|
WorkingDirectory=/share/run/triples
|
||||||
|
ExecStart=/share/run/triples/start.sh
|
||||||
|
ExecStop=/share/run/triples/stop.sh
|
||||||
|
StandardOutput=append:/var/log/triples/triples.log
|
||||||
|
StandardError=append:/var/log/triples/triples.log
|
||||||
|
SyslogIdentifier=/share/run/triples
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
Loading…
Reference in New Issue
Block a user