This commit is contained in:
ymq1 2025-07-02 15:13:59 +08:00
parent 95060e3285
commit 9f18a84283
5 changed files with 68 additions and 36 deletions

View File

@ -1,18 +1,15 @@
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import logging from appPublic.log import debug
import os import os
logger = logging.getLogger(__name__)
model_pathMap = {} model_pathMap = {}
def llm_register(model_key: str, Klass): def llm_register(model_key: str, Klass):
"""Register a triplet extractor class for a given model key.""" """Register a triplet extractor class for a given model key."""
global model_pathMap global model_pathMap
model_pathMap[model_key] = Klass model_pathMap[model_key] = Klass
logger.debug(f"Registered {Klass.__name__} for model_key: {model_key}") debug(f"Registered {Klass.__name__} for model_key: {model_key}")
def get_llm_class(model_path: str): def get_llm_class(model_path: str):
@ -20,7 +17,7 @@ def get_llm_class(model_path: str):
for k, klass in model_pathMap.items(): for k, klass in model_pathMap.items():
if k in model_path: if k in model_path:
return klass return klass
logger.debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}") debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
return None return None
@ -31,19 +28,19 @@ class BaseTripleExtractor(ABC):
self.model_path = model_path self.model_path = model_path
self.model_name = os.path.basename(model_path) self.model_name = os.path.basename(model_path)
self.model = None self.model = None
logger.debug(f"Initialized BaseTripleExtractor with model_path: {model_path}") debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")
def use_mps_if_possible(self): def use_mps_if_possible(self):
"""Select device (MPS, CUDA, or CPU).""" """Select device (MPS, CUDA, or CPU)."""
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
device = torch.device("mps") device = torch.device("mps")
logger.debug("Using MPS device") debug("Using MPS device")
elif torch.cuda.is_available(): elif torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
logger.debug("Using CUDA device") debug("Using CUDA device")
else: else:
device = torch.device("cpu") device = torch.device("cpu")
logger.debug("Using CPU device") debug("Using CPU device")
if self.model is not None: if self.model is not None:
self.model = self.model.to(device) self.model = self.model.to(device)
return device return device
@ -51,4 +48,4 @@ class BaseTripleExtractor(ABC):
@abstractmethod @abstractmethod
async def extract_triplets(self, text: str) -> list: async def extract_triplets(self, text: str) -> list:
"""Extract triplets from text.""" """Extract triplets from text."""
pass pass

View File

@ -7,30 +7,25 @@ from typing import List
from base_triple import get_llm_class from base_triple import get_llm_class
from mrebeltriple import MRebelTripleExtractor from mrebeltriple import MRebelTripleExtractor
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception from appPublic.log import debug, exception, error, info
from appPublic.jsonConfig import getConfig
from ahserver.serverenv import ServerEnv from ahserver.serverenv import ServerEnv
from ahserver.globalEnv import stream_response from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver from ahserver.webapp import webserver
import aiohttp.web import aiohttp.web
# 配置日志 # 配置日志
logger = logging.getLogger('llmengine_triple')
logger.setLevel(logging.DEBUG)
log_file = '/share/wangmeihua/rag/logs/llmengine_triple.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()):
handler.setFormatter(formatter)
logger.addHandler(handler)
# 加载配置文件 # 加载配置文件
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try: def load_milvus_config():
with open(CONFIG_PATH, 'r', encoding='utf-8') as f: config = getConfig()
config = yaml.safe_load(f) milvus_config = config.mivus_config_path
except Exception as e: try:
logger.error(f"Failed to load config {CONFIG_PATH}: {str(e)}") with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
raise RuntimeError(f"Failed to load config: {str(e)}") milvus_config = yaml.safe_load(f)
except Exception as e:
error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
raise RuntimeError(f"Failed to load config: {str(e)}")
helptext = """mREBEL Triplets API: helptext = """mREBEL Triplets API:
@ -79,7 +74,7 @@ async def triples(request, params_kw, *params, **kw):
params_kw = data params_kw = data
debug(f'Parsed JSON data: {params_kw}') debug(f'Parsed JSON data: {params_kw}')
except Exception as e: except Exception as e:
logger.error(f"Failed to parse JSON: {str(e)}") error(f"Failed to parse JSON: {str(e)}")
raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}") raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}")
se = ServerEnv() se = ServerEnv()
@ -100,8 +95,8 @@ async def triples(request, params_kw, *params, **kw):
"data": triplets "data": triplets
} }
except Exception as e: except Exception as e:
logger.error(f"Error in triples endpoint: {str(e)}") error(f"Error in triples endpoint: {str(e)}")
logger.debug(f"Traceback: {format_exc()}") debug(f"Traceback: {format_exc()}")
raise raise
def main(): def main():
@ -122,13 +117,11 @@ def main():
se.engine = Klass(args.model_path) se.engine = Klass(args.model_path)
workdir = args.workdir or os.getcwd() workdir = args.workdir or os.getcwd()
port = args.port port = args.port
debug(f'{args=}')
logger.info(f"Starting mREBEL Triplet Service on port {port}, model: {args.model_path}")
webserver(init, workdir, port) webserver(init, workdir, port)
except Exception as e: except Exception as e:
logger.error(f"Failed to start server: {str(e)}") error(f"Failed to start server: {str(e)}")
logger.debug(f"Traceback: {format_exc()}") debug(f"Traceback: {format_exc()}")
raise raise
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -0,0 +1,14 @@
[Unit]
Wants=systemd-networkd.service
[Service]
Type=forking
WorkingDirectory=/share/run/connection
ExecStart=/share/run/connection/start.sh
ExecStop=/share/run/connection/stop.sh
StandardOutput=append:/var/log/connection/connection.log
StandardError=append:/var/log/connection/connection.log
SyslogIdentifier=/share/run/connection
[Install]
WantedBy=multi-user.target

View File

@ -0,0 +1,14 @@
[Unit]
Wants=systemd-networkd.service
[Service]
WorkingDirectory=/share/run/entities
ExecStart=/share/run/entities/start.sh
ExecStop=/share/run/entities/stop.sh
StandardOutput=append:/var/log/entities/entities.log
StandardError=append:/var/log/entities/entities.log
SyslogIdentifier=entities
[Install]
WantedBy=multi-user.target

View File

@ -0,0 +1,14 @@
[Unit]
Wants=systemd-networkd.service
[Service]
Type=forking
WorkingDirectory=/share/run/triples
ExecStart=/share/run/triples/start.sh
ExecStop=/share/run/triples/stop.sh
StandardOutput=append:/var/log/triples/triples.log
StandardError=append:/var/log/triples/triples.log
SyslogIdentifier=/share/run/triples
[Install]
WantedBy=multi-user.target