Merge branch 'main' of git.kaiyuancloud.cn:yumoqing/llmengine
ymq1 2025-06-24 19:44:53 +08:00
commit 5f1b06d10f
7 changed files with 405 additions and 0 deletions

54 llmengine/base_triple.py Normal file

@@ -0,0 +1,54 @@
import torch
from abc import ABC, abstractmethod
import logging
import os
logger = logging.getLogger(__name__)
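# Maps a substring key (typically the model directory name) to its extractor class.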
model_pathMap = {}

def llm_register(model_key: str, Klass):
    """Register a triplet extractor class for a given model key."""
    global model_pathMap
    model_pathMap[model_key] = Klass
    logger.debug(f"Registered {Klass.__name__} for model_key: {model_key}")

def get_llm_class(model_path: str):
    """Return the triplet extractor class for the given model path."""
    for k, klass in model_pathMap.items():
        if k in model_path:
            return klass
    logger.debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
    return None
class BaseTripleExtractor(ABC):
    """Base class for triplet extraction."""

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.model_name = os.path.basename(model_path)
        self.model = None
        logger.debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")

    def use_mps_if_possible(self):
        """Select device (MPS, CUDA, or CPU)."""
        if torch.backends.mps.is_available():
            device = torch.device("mps")
            logger.debug("Using MPS device")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
            logger.debug("Using CUDA device")
        else:
            device = torch.device("cpu")
            logger.debug("Using CPU device")
        if self.model is not None:
            self.model = self.model.to(device)
        return device

    @abstractmethod
    async def extract_triplets(self, text: str) -> list:
        """Extract triplets from text."""
        pass
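
For reference, a minimal sketch of how a concrete extractor would plug into this registry; the DummyTripleExtractor below is hypothetical and not part of this commit:

from llmengine.base_triple import BaseTripleExtractor, llm_register, get_llm_class

class DummyTripleExtractor(BaseTripleExtractor):
    """Toy extractor used only to illustrate the registration flow."""

    async def extract_triplets(self, text: str) -> list:
        # A real subclass would run a model here; this just echoes the input.
        return [{'head': text[:4], 'head_type': 'demo', 'type': 'mentions',
                 'tail': text[-4:], 'tail_type': 'demo'}]

llm_register("dummy-model", DummyTripleExtractor)
# get_llm_class matches by substring, so any path containing the key resolves:
assert get_llm_class("/models/dummy-model-v1") is DummyTripleExtractor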

161 llmengine/mrebeltriple.py Normal file

@@ -0,0 +1,161 @@
import os
import re
import logging

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from llmengine.base_triple import BaseTripleExtractor, llm_register

logger = logging.getLogger(__name__)
class MRebelTripleExtractor(BaseTripleExtractor):
    def __init__(self, model_path: str):
        super().__init__(model_path)
        try:
            logger.debug(f"Loading tokenizer from {model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            logger.debug(f"Loading model from {model_path}")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
            self.device = self.use_mps_if_possible()
            self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
            logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
        except Exception as e:
            logger.error(f"Failed to load mREBEL model: {str(e)}")
            raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
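        # Beam-search generation settings; decoding starts from the <triplet>
        # token so the model emits mREBEL's triplet-tagged output format.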
        self.gen_kwargs = {
            "max_length": 512,
            "min_length": 10,
            "length_penalty": 0.5,
            "num_beams": 3,
            "num_return_sequences": 1,
            "no_repeat_ngram_size": 2,
            "early_stopping": True,
            "decoder_start_token_id": self.triplet_id,
        }
    def split_document(self, text: str, max_chunk_size: int = 150) -> list:
        """Split document into semantic chunks."""
        sentences = re.split(r'(?<=[。!?;\n])', text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= max_chunk_size:
                current_chunk += sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = sentence
        if current_chunk:
            chunks.append(current_chunk)
        logger.debug(f"Text split into: {len(chunks)} chunks")
        return chunks
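    # A quick hand-traced illustration of the chunking behavior:
    #   split_document("句子一。句子二!句子三?", max_chunk_size=5)
    # splits after 。!?;/newlines and greedily packs whole sentences per chunk:
    #   -> ['句子一。', '句子二!', '句子三?']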
    def extract_triplets_typed(self, text: str) -> list:
        """Parse mREBEL generated text for triplets."""
        triplets = []
        logger.debug(f"Raw generated text: {text}")
        tokens = []
        in_tag = False
        buffer = ""
        for char in text:
            if char == '<':
                in_tag = True
                if buffer:
                    tokens.append(buffer.strip())
                    buffer = ""
                buffer += char
            elif char == '>':
                in_tag = False
                buffer += char
                tokens.append(buffer.strip())
                buffer = ""
            else:
                buffer += char
        if buffer:
            tokens.append(buffer.strip())
        special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
        tokens = [t for t in tokens if t not in special_tokens and t]
        logger.debug(f"Processed tokens: {tokens}")
        i = 0
        while i < len(tokens):
            if tokens[i] == "<triplet>" and i + 5 < len(tokens):
                entity1 = tokens[i + 1]
                type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else ""
                entity2 = tokens[i + 3]
                type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else ""
                relation = tokens[i + 5]
                if entity1 and type1 and entity2 and type2 and relation:
                    triplets.append({
                        'head': entity1.strip(),
                        'head_type': type1,
                        'type': relation.strip(),
                        'tail': entity2.strip(),
                        'tail_type': type2
                    })
                    logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
                i += 6
            else:
                i += 1
        return triplets
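    # Worked example (illustrative; the exact decoded surface form depends on
    # the checkpoint and its language tokens): a decoded sequence such as
    #   "<s><triplet> 知识图谱 <concept> 语义知识库 <concept> instance of</s>"
    # tokenizes to ['<triplet>', '知识图谱', '<concept>', '语义知识库', '<concept>', 'instance of']
    # after special-token filtering, and yields:
    #   [{'head': '知识图谱', 'head_type': 'concept', 'type': 'instance of',
    #     'tail': '语义知识库', 'tail_type': 'concept'}]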
    async def extract_triplets(self, text: str) -> list:
        """Extract triplets from text and return unique triplets."""
        try:
            if not text:
                raise ValueError("Text cannot be empty")
            text_chunks = self.split_document(text, max_chunk_size=150)
            logger.debug(f"Text split into {len(text_chunks)} chunks")
            all_triplets = []
            for i, chunk in enumerate(text_chunks):
                logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
                model_inputs = self.tokenizer(
                    chunk,
                    max_length=256,
                    padding=True,
                    truncation=True,
                    return_tensors="pt"
                ).to(self.device)
                try:
                    generated_tokens = self.model.generate(
                        model_inputs["input_ids"],
                        attention_mask=model_inputs["attention_mask"],
                        **self.gen_kwargs,
                    )
                    decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
                    for idx, sentence in enumerate(decoded_preds):
                        logger.debug(f"Chunk {i + 1} generated text: {sentence}")
                        triplets = self.extract_triplets_typed(sentence)
                        if triplets:
                            logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
                            all_triplets.extend(triplets)
                except Exception as e:
                    logger.warning(f"Error processing chunk {i + 1}: {str(e)}")
                    continue
            unique_triplets = []
            seen = set()
            for t in all_triplets:
                identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
                if identifier not in seen:
                    seen.add(identifier)
                    unique_triplets.append(t)
            logger.info(f"Extracted {len(unique_triplets)} unique triplets")
            return unique_triplets
        except Exception as e:
            logger.error(f"Failed to extract triplets: {str(e)}")
            import traceback
            logger.debug(traceback.format_exc())
            return []

llm_register("mrebel-large", MRebelTripleExtractor)
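
A standalone usage sketch (the checkpoint path below is illustrative; point it at a local mrebel-large download):

import asyncio
from llmengine.mrebeltriple import MRebelTripleExtractor

extractor = MRebelTripleExtractor("/models/Babelscape/mrebel-large")  # hypothetical path
triplets = asyncio.run(extractor.extract_triplets("知识图谱是一个结构化的语义知识库。"))
print(triplets)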

134 llmengine/triple.py Normal file

@@ -0,0 +1,134 @@
from traceback import format_exc
import os
import argparse
import logging
import yaml
from typing import List

from llmengine.base_triple import get_llm_class
from llmengine.mrebeltriple import MRebelTripleExtractor
from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver
import aiohttp.web

# Configure logging
logger = logging.getLogger('llmengine_triple')
logger.setLevel(logging.DEBUG)
log_file = '/share/wangmeihua/rag/logs/llmengine_triple.log'
os.makedirs(os.path.dirname(log_file), exist_ok=True)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
for handler in (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# Load the configuration file
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
try:
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
except Exception as e:
    logger.error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
    raise RuntimeError(f"Failed to load config: {str(e)}")
helptext = """mREBEL Triplets API:

1. Triplets Endpoint:
   path: /v1/triples
   headers: {
       "Content-Type": "application/json"
   }
   data: {
       "text": "知识图谱是一个结构化的语义知识库。"
   }
   response: {
       "object": "list",
       "data": [
           {
               "head": "知识图谱",
               "head_type": "Concept",
               "type": "is_a",
               "tail": "语义知识库",
               "tail_type": "Concept"
           },
           ...
       ]
   }

2. Docs Endpoint:
   path: /docs
   response: This help text
"""
def init():
    rf = RegisterFunction()
    rf.register('triples', triples)
    rf.register('docs', docs)

async def docs(request, params_kw, *params, **kw):
    return helptext
async def triples(request, params_kw, *params, **kw):
    try:
        debug(f'Received params_kw: {params_kw}')
        # Explicitly parse the request payload if the framework passed nothing
        if not params_kw:
            try:
                data = await request.json()
                params_kw = data
                debug(f'Parsed JSON data: {params_kw}')
            except Exception as e:
                logger.error(f"Failed to parse JSON: {str(e)}")
                raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}")
        se = ServerEnv()
        engine = se.engine
        if engine is None:
            raise ValueError("Engine not initialized")
        text = params_kw.get('text')
        if not text:
            e = ValueError("text cannot be empty")
            exception(f'{e}')
            raise e
        triplets = await engine.extract_triplets(text)
        debug(f'{triplets=}')
        return {
            "object": "list",
            "data": triplets
        }
    except Exception as e:
        logger.error(f"Error in triples endpoint: {str(e)}")
        logger.debug(f"Traceback: {format_exc()}")
        raise
def main():
    parser = argparse.ArgumentParser(prog="mREBEL Triplet Service")
    parser.add_argument('-w', '--workdir', default=None)
    parser.add_argument('-p', '--port', type=int, default=9991)
    parser.add_argument('model_path')
    args = parser.parse_args()
    try:
        Klass = get_llm_class(args.model_path)
        if Klass is None:
            e = Exception(f"{args.model_path} has no mapping to a model class")
            exception(f'{e}, {format_exc()}')
            raise e
        se = ServerEnv()
        se.engine = Klass(args.model_path)
        workdir = args.workdir or os.getcwd()
        port = args.port
        debug(f'{args=}')
        logger.info(f"Starting mREBEL Triplet Service on port {port}, model: {args.model_path}")
        webserver(init, workdir, port)
    except Exception as e:
        logger.error(f"Failed to start server: {str(e)}")
        logger.debug(f"Traceback: {format_exc()}")
        raise

if __name__ == "__main__":
    main()


@@ -0,0 +1,50 @@
{
    "filesroot":"$[workdir]$/files",
    "logger":{
        "name":"llmengine",
        "levelname":"info",
        "logfile":"$[workdir]$/logs/llmengine.log"
    },
    "website":{
        "paths":[
            ["$[workdir]$/wwwroot",""]
        ],
        "client_max_size":10000,
        "host":"0.0.0.0",
        "port":9991,
        "coding":"utf-8",
        "indexes":[
            "index.html",
            "index.ui"
        ],
        "startswiths":[
            {
                "leading":"/idfile",
                "registerfunction":"idfile"
            },{
                "leading":"/v1/triples",
                "registerfunction":"triples"
            },{
                "leading":"/docs",
                "registerfunction":"docs"
            }
        ],
        "processors":[
            [".tmpl","tmpl"],
            [".app","app"],
            [".ui","bui"],
            [".dspy","dspy"],
            [".md","md"]
        ],
        "rsakey_oops":{
            "privatekey":"$[workdir]$/conf/rsa_private_key.pem",
            "publickey":"$[workdir]$/conf/rsa_public_key.pem"
        },
        "session_max_time":3000,
        "session_issue_time":2500,
        "session_redis_notuse":{
            "url":"redis://127.0.0.1:6379"
        }
    }
}

3 test/triples/start.sh Executable file

@@ -0,0 +1,3 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=7 /share/vllm-0.8.5/bin/python -m llmengine.triple -p 9991 /share/models/Babelscape/mrebel-large

3 test/triples/stop.sh Executable file

@@ -0,0 +1,3 @@
#!/bin/bash
pkill -f "llmengine.triple.*9991"