wangmeihua 2025-07-03 18:40:04 +08:00
commit 002624288e
7 changed files with 71 additions and 64 deletions

View File

@@ -96,6 +96,7 @@ class BaseChatLLM:
 		return generate_kwargs
 
 	def _messages2inputs(self, messages):
+		debug(f'{messages=}')
 		return self.processor.apply_chat_template(
 			messages, add_generation_prompt=True,
 			tokenize=True,
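For context, `apply_chat_template` consumes the standard role/content message list used by transformers chat models. A minimal sketch of the input `_messages2inputs` expects (roles and wording are illustrative, not taken from this repo):

    # Hypothetical input; real callers may add system or tool messages.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize this paragraph."},
    ]
    # inputs = llm._messages2inputs(messages)  # token tensors, ready for model.generate()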

View File

@@ -1,18 +1,15 @@
 import torch
 from abc import ABC, abstractmethod
-import logging
+from appPublic.log import debug
 import os
 
-logger = logging.getLogger(__name__)
-
 model_pathMap = {}
 
 def llm_register(model_key: str, Klass):
 	"""Register a triplet extractor class for a given model key."""
 	global model_pathMap
 	model_pathMap[model_key] = Klass
-	logger.debug(f"Registered {Klass.__name__} for model_key: {model_key}")
+	debug(f"Registered {Klass.__name__} for model_key: {model_key}")
 
 def get_llm_class(model_path: str):
@@ -20,7 +17,7 @@ def get_llm_class(model_path: str):
 	for k, klass in model_pathMap.items():
 		if k in model_path:
 			return klass
-	logger.debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
+	debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
 	return None
 
@@ -31,19 +28,19 @@ class BaseTripleExtractor(ABC):
 		self.model_path = model_path
 		self.model_name = os.path.basename(model_path)
 		self.model = None
-		logger.debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")
+		debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")
 
 	def use_mps_if_possible(self):
 		"""Select device (MPS, CUDA, or CPU)."""
 		if torch.backends.mps.is_available():
 			device = torch.device("mps")
-			logger.debug("Using MPS device")
+			debug("Using MPS device")
 		elif torch.cuda.is_available():
 			device = torch.device("cuda")
-			logger.debug("Using CUDA device")
+			debug("Using CUDA device")
 		else:
 			device = torch.device("cpu")
-			logger.debug("Using CPU device")
+			debug("Using CPU device")
 		if self.model is not None:
 			self.model = self.model.to(device)
 		return device
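End to end, the registry works roughly like this (a sketch; the key string and model path are illustrative, and `MRebelTripleExtractor` is registered in its own module):

    from base_triple import llm_register, get_llm_class
    from mrebeltriple import MRebelTripleExtractor

    # Hypothetical key; get_llm_class matches by substring of the model path.
    llm_register("mrebel", MRebelTripleExtractor)

    Klass = get_llm_class("/models/Babelscape/mrebel-large")  # "mrebel" is in the path
    if Klass is not None:
        extractor = Klass("/models/Babelscape/mrebel-large")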

View File

@@ -4,10 +4,10 @@
 from appPublic.worker import awaitify
 from appPublic.log import debug
 from ahserver.serverenv import get_serverenv
-from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 import torch
 from llmengine.base_chat_llm import BaseChatLLM, llm_register
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 class Qwen3LLM(BaseChatLLM):
 	def __init__(self, model_id):
@@ -17,9 +17,6 @@ class Qwen3LLM(BaseChatLLM):
 			torch_dtype="auto",
 			device_map="auto"
 		)
-		if torch.backends.mps.is_available():
-			device = torch.device("mps")
-			self.model = self.model.to(device)
 		self.model_id = model_id
 
 	def build_kwargs(self, inputs, streamer):
@@ -33,7 +30,7 @@ class Qwen3LLM(BaseChatLLM):
 		return generate_kwargs
 
 	def _messages2inputs(self, messages):
-		debug(f'{messages=}')
+		debug(f'-----------{messages=}-----------')
 		text = self.tokenizer.apply_chat_template(
 			messages,
 			tokenize=False,
@@ -43,26 +40,3 @@ class Qwen3LLM(BaseChatLLM):
 		return self.tokenizer([text], return_tensors="pt").to(self.model.device)
 
 llm_register("Qwen/Qwen3", Qwen3LLM)
-
-if __name__ == '__main__':
-	import sys
-	model_path = sys.argv[1]
-	q3 = Qwen3LLM(model_path)
-	session = {}
-	while True:
-		print('input prompt')
-		p = input()
-		if p:
-			if p == 'q':
-				break;
-			for d in q3.stream_generate(session, p):
-				print(d)
-				"""
-				if not d['done']:
-					print(d['text'], end='', flush=True)
-				else:
-					x = {k:v for k,v in d.items() if k != 'text'}
-					print(f'\n{x}\n')
-				"""

View File

@@ -7,29 +7,24 @@ from typing import List
 from base_triple import get_llm_class
 from mrebeltriple import MRebelTripleExtractor
 from appPublic.registerfunction import RegisterFunction
-from appPublic.log import debug, exception
+from appPublic.log import debug, exception, error, info
+from appPublic.jsonConfig import getConfig
 from ahserver.serverenv import ServerEnv
 from ahserver.globalEnv import stream_response
 from ahserver.webapp import webserver
 import aiohttp.web
 
 # Configure logging
-logger = logging.getLogger('llmengine_triple')
-logger.setLevel(logging.DEBUG)
-log_file = '/share/wangmeihua/rag/logs/llmengine_triple.log'
-os.makedirs(os.path.dirname(log_file), exist_ok=True)
-formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-for handler in (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()):
-	handler.setFormatter(formatter)
-	logger.addHandler(handler)
 
 # Load the configuration file
-CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
-try:
-	with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
-		config = yaml.safe_load(f)
-except Exception as e:
-	logger.error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
-	raise RuntimeError(f"Failed to load config: {str(e)}")
+def load_milvus_config():
+	config = getConfig()
+	milvus_config = config.mivus_config_path
+	try:
+		with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
+			milvus_config = yaml.safe_load(f)
+	except Exception as e:
+		error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
+		raise RuntimeError(f"Failed to load config: {str(e)}")
 
 helptext = """mREBEL Triplets API:
@@ -79,7 +74,7 @@ async def triples(request, params_kw, *params, **kw):
 			params_kw = data
 			debug(f'Parsed JSON data: {params_kw}')
 		except Exception as e:
-			logger.error(f"Failed to parse JSON: {str(e)}")
+			error(f"Failed to parse JSON: {str(e)}")
 			raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}")
 
 	se = ServerEnv()
@@ -100,8 +95,8 @@ async def triples(request, params_kw, *params, **kw):
 			"data": triplets
 		}
 	except Exception as e:
-		logger.error(f"Error in triples endpoint: {str(e)}")
-		logger.debug(f"Traceback: {format_exc()}")
+		error(f"Error in triples endpoint: {str(e)}")
+		debug(f"Traceback: {format_exc()}")
 		raise
 
 def main():
@@ -122,12 +117,10 @@ def main():
 		se.engine = Klass(args.model_path)
 		workdir = args.workdir or os.getcwd()
 		port = args.port
-		debug(f'{args=}')
-		logger.info(f"Starting mREBEL Triplet Service on port {port}, model: {args.model_path}")
 		webserver(init, workdir, port)
 	except Exception as e:
-		logger.error(f"Failed to start server: {str(e)}")
-		logger.debug(f"Traceback: {format_exc()}")
+		error(f"Failed to start server: {str(e)}")
+		debug(f"Traceback: {format_exc()}")
 		raise
 
 if __name__ == "__main__":
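Once the service is up, the endpoint can be exercised with a small client (a sketch: the port, route, and payload key are assumptions; the real route and parameters are documented in `helptext`):

    # Hypothetical client call; route and payload keys are illustrative.
    import asyncio
    import aiohttp

    async def main():
        async with aiohttp.ClientSession() as sess:
            async with sess.post('http://localhost:8888/v1/triples',
                                 json={'text': 'Paris is the capital of France.'}) as resp:
                print(await resp.json())  # expected shape: {..., "data": [triplets]}

    asyncio.run(main())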

View File

@@ -0,0 +1,14 @@
+[Unit]
+Wants=systemd-networkd.service
+
+[Service]
+Type=forking
+WorkingDirectory=/share/run/connection
+ExecStart=/share/run/connection/start.sh
+ExecStop=/share/run/connection/stop.sh
+StandardOutput=append:/var/log/connection/connection.log
+StandardError=append:/var/log/connection/connection.log
+SyslogIdentifier=/share/run/connection
+
+[Install]
+WantedBy=multi-user.target

View File

@@ -0,0 +1,14 @@
+[Unit]
+Wants=systemd-networkd.service
+
+[Service]
+WorkingDirectory=/share/run/entities
+ExecStart=/share/run/entities/start.sh
+ExecStop=/share/run/entities/stop.sh
+StandardOutput=append:/var/log/entities/entities.log
+StandardError=append:/var/log/entities/entities.log
+SyslogIdentifier=entities
+
+[Install]
+WantedBy=multi-user.target

View File

@@ -0,0 +1,14 @@
+[Unit]
+Wants=systemd-networkd.service
+
+[Service]
+Type=forking
+WorkingDirectory=/share/run/triples
+ExecStart=/share/run/triples/start.sh
+ExecStop=/share/run/triples/stop.sh
+StandardOutput=append:/var/log/triples/triples.log
+StandardError=append:/var/log/triples/triples.log
+SyslogIdentifier=/share/run/triples
+
+[Install]
+WantedBy=multi-user.target