:x
Merge branch 'main' of git.kaiyuancloud.cn:yumoqing/llmengine
This commit is contained in:
commit
5f1b06d10f
54
llmengine/base_triple.py
Normal file
54
llmengine/base_triple.py
Normal file
@ -0,0 +1,54 @@
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
model_pathMap = {}
|
||||
|
||||
|
||||
def llm_register(model_key: str, Klass):
|
||||
"""Register a triplet extractor class for a given model key."""
|
||||
global model_pathMap
|
||||
model_pathMap[model_key] = Klass
|
||||
logger.debug(f"Registered {Klass.__name__} for model_key: {model_key}")
|
||||
|
||||
|
||||
def get_llm_class(model_path: str):
|
||||
"""Return the triplet extractor class for the given model path."""
|
||||
for k, klass in model_pathMap.items():
|
||||
if k in model_path:
|
||||
return klass
|
||||
logger.debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
|
||||
return None
|
||||
|
||||
|
||||
class BaseTripleExtractor(ABC):
|
||||
"""Base class for triplet extraction."""
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
self.model_path = model_path
|
||||
self.model_name = os.path.basename(model_path)
|
||||
self.model = None
|
||||
logger.debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")
|
||||
|
||||
def use_mps_if_possible(self):
|
||||
"""Select device (MPS, CUDA, or CPU)."""
|
||||
if torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
logger.debug("Using MPS device")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
logger.debug("Using CUDA device")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.debug("Using CPU device")
|
||||
if self.model is not None:
|
||||
self.model = self.model.to(device)
|
||||
return device
|
||||
|
||||
@abstractmethod
|
||||
async def extract_triplets(self, text: str) -> list:
|
||||
"""Extract triplets from text."""
|
||||
pass
|
161
llmengine/mrebeltriple.py
Normal file
161
llmengine/mrebeltriple.py
Normal file
@ -0,0 +1,161 @@
|
||||
import os
|
||||
import torch
|
||||
import re
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
import logging
|
||||
from base_triple import BaseTripleExtractor, llm_register
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
def __init__(self, model_path: str):
|
||||
super().__init__(model_path)
|
||||
try:
|
||||
logger.debug(f"Loading tokenizer from {model_path}")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
logger.debug(f"Loading model from {model_path}")
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
||||
self.device = self.use_mps_if_possible()
|
||||
self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
|
||||
logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load mREBEL model: {str(e)}")
|
||||
raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
|
||||
|
||||
self.gen_kwargs = {
|
||||
"max_length": 512,
|
||||
"min_length": 10,
|
||||
"length_penalty": 0.5,
|
||||
"num_beams": 3,
|
||||
"num_return_sequences": 1,
|
||||
"no_repeat_ngram_size": 2,
|
||||
"early_stopping": True,
|
||||
"decoder_start_token_id": self.triplet_id,
|
||||
}
|
||||
|
||||
def split_document(self, text: str, max_chunk_size: int = 150) -> list:
|
||||
"""Split document into semantic chunks."""
|
||||
sentences = re.split(r'(?<=[。!?;\n])', text)
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
for sentence in sentences:
|
||||
if len(current_chunk) + len(sentence) <= max_chunk_size:
|
||||
current_chunk += sentence
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = sentence
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
logger.debug(f"Text split into: {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
def extract_triplets_typed(self, text: str) -> list:
|
||||
"""Parse mREBEL generated text for triplets."""
|
||||
triplets = []
|
||||
logger.debug(f"Raw generated text: {text}")
|
||||
|
||||
tokens = []
|
||||
in_tag = False
|
||||
buffer = ""
|
||||
for char in text:
|
||||
if char == '<':
|
||||
in_tag = True
|
||||
if buffer:
|
||||
tokens.append(buffer.strip())
|
||||
buffer = ""
|
||||
buffer += char
|
||||
elif char == '>':
|
||||
in_tag = False
|
||||
buffer += char
|
||||
tokens.append(buffer.strip())
|
||||
buffer = ""
|
||||
else:
|
||||
buffer += char
|
||||
if buffer:
|
||||
tokens.append(buffer.strip())
|
||||
|
||||
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
|
||||
tokens = [t for t in tokens if t not in special_tokens and t]
|
||||
logger.debug(f"Processed tokens: {tokens}")
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
if tokens[i] == "<triplet>" and i + 5 < len(tokens):
|
||||
entity1 = tokens[i + 1]
|
||||
type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else ""
|
||||
entity2 = tokens[i + 3]
|
||||
type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else ""
|
||||
relation = tokens[i + 5]
|
||||
|
||||
if entity1 and type1 and entity2 and type2 and relation:
|
||||
triplets.append({
|
||||
'head': entity1.strip(),
|
||||
'head_type': type1,
|
||||
'type': relation.strip(),
|
||||
'tail': entity2.strip(),
|
||||
'tail_type': type2
|
||||
})
|
||||
logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
|
||||
i += 6
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return triplets
|
||||
|
||||
async def extract_triplets(self, text: str) -> list:
|
||||
"""Extract triplets from text and return unique triplets."""
|
||||
try:
|
||||
if not text:
|
||||
raise ValueError("Text cannot be empty")
|
||||
|
||||
text_chunks = self.split_document(text, max_chunk_size=150)
|
||||
logger.debug(f"Text split into {len(text_chunks)} chunks")
|
||||
|
||||
all_triplets = []
|
||||
for i, chunk in enumerate(text_chunks):
|
||||
logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
|
||||
|
||||
model_inputs = self.tokenizer(
|
||||
chunk,
|
||||
max_length=256,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
try:
|
||||
generated_tokens = self.model.generate(
|
||||
model_inputs["input_ids"],
|
||||
attention_mask=model_inputs["attention_mask"],
|
||||
**self.gen_kwargs,
|
||||
)
|
||||
decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
|
||||
for idx, sentence in enumerate(decoded_preds):
|
||||
logger.debug(f"Chunk {i + 1} generated text: {sentence}")
|
||||
triplets = self.extract_triplets_typed(sentence)
|
||||
if triplets:
|
||||
logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
|
||||
all_triplets.extend(triplets)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing chunk {i + 1}: {str(e)}")
|
||||
continue
|
||||
|
||||
unique_triplets = []
|
||||
seen = set()
|
||||
for t in all_triplets:
|
||||
identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
|
||||
if identifier not in seen:
|
||||
seen.add(identifier)
|
||||
unique_triplets.append(t)
|
||||
|
||||
logger.info(f"Extracted {len(unique_triplets)} unique triplets")
|
||||
return unique_triplets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract triplets: {str(e)}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return []
|
||||
|
||||
llm_register("mrebel-large", MRebelTripleExtractor)
|
134
llmengine/triple.py
Normal file
134
llmengine/triple.py
Normal file
@ -0,0 +1,134 @@
|
||||
from traceback import format_exc
|
||||
import os
|
||||
import argparse
|
||||
import logging
|
||||
import yaml
|
||||
from typing import List
|
||||
from base_triple import get_llm_class
|
||||
from mrebeltriple import MRebelTripleExtractor
|
||||
from appPublic.registerfunction import RegisterFunction
|
||||
from appPublic.log import debug, exception
|
||||
from ahserver.serverenv import ServerEnv
|
||||
from ahserver.globalEnv import stream_response
|
||||
from ahserver.webapp import webserver
|
||||
import aiohttp.web
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger('llmengine_triple')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
log_file = '/share/wangmeihua/rag/logs/llmengine_triple.log'
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||
for handler in (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()):
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# 加载配置文件
|
||||
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||
try:
|
||||
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
|
||||
raise RuntimeError(f"Failed to load config: {str(e)}")
|
||||
|
||||
helptext = """mREBEL Triplets API:
|
||||
|
||||
1. Triplets Endpoint:
|
||||
path: /v1/triples
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data: {
|
||||
"text": "知识图谱是一个结构化的语义知识库。"
|
||||
}
|
||||
response: {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"head": "知识图谱",
|
||||
"head_type": "Concept",
|
||||
"type": "is_a",
|
||||
"tail": "语义知识库",
|
||||
"tail_type": "Concept"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
2. Docs Endpoint:
|
||||
path: /v1/docs
|
||||
response: This help text
|
||||
"""
|
||||
|
||||
def init():
|
||||
rf = RegisterFunction()
|
||||
rf.register('triples', triples)
|
||||
rf.register('docs', docs)
|
||||
|
||||
async def docs(request, params_kw, *params, **kw):
|
||||
return helptext
|
||||
|
||||
async def triples(request, params_kw, *params, **kw):
|
||||
try:
|
||||
debug(f'Received params_kw: {params_kw}')
|
||||
# 显式解析请求数据
|
||||
if not params_kw:
|
||||
try:
|
||||
data = await request.json()
|
||||
params_kw = data
|
||||
debug(f'Parsed JSON data: {params_kw}')
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse JSON: {str(e)}")
|
||||
raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}")
|
||||
|
||||
se = ServerEnv()
|
||||
engine = se.engine
|
||||
if engine is None:
|
||||
raise ValueError("Engine not initialized")
|
||||
|
||||
text = params_kw.get('text')
|
||||
if not text:
|
||||
e = ValueError("text cannot be empty")
|
||||
exception(f'{e}')
|
||||
raise e
|
||||
|
||||
triplets = await engine.extract_triplets(text)
|
||||
debug(f'{triplets=}')
|
||||
return {
|
||||
"object": "list",
|
||||
"data": triplets
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in triples endpoint: {str(e)}")
|
||||
logger.debug(f"Traceback: {format_exc()}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="mREBEL Triplet Service")
|
||||
parser.add_argument('-w', '--workdir', default=None)
|
||||
parser.add_argument('-p', '--port', type=int, default=9991)
|
||||
parser.add_argument('model_path')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
Klass = get_llm_class(args.model_path)
|
||||
if Klass is None:
|
||||
e = Exception(f"{args.model_path} has no mapping to a model class")
|
||||
exception(f'{e}, {format_exc()}')
|
||||
raise e
|
||||
|
||||
se = ServerEnv()
|
||||
se.engine = Klass(args.model_path)
|
||||
workdir = args.workdir or os.getcwd()
|
||||
port = args.port
|
||||
debug(f'{args=}')
|
||||
logger.info(f"Starting mREBEL Triplet Service on port {port}, model: {args.model_path}")
|
||||
webserver(init, workdir, port)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start server: {str(e)}")
|
||||
logger.debug(f"Traceback: {format_exc()}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
50
test/triples/conf/config.json
Normal file
50
test/triples/conf/config.json
Normal file
@ -0,0 +1,50 @@
|
||||
{
|
||||
"filesroot":"$[workdir]$/files",
|
||||
"logger":{
|
||||
"name":"llmengine",
|
||||
"levelname":"info",
|
||||
"logfile":"$[workdir]$/logs/llmengine.log"
|
||||
},
|
||||
"website":{
|
||||
"paths":[
|
||||
["$[workdir]$/wwwroot",""]
|
||||
],
|
||||
"client_max_size":10000,
|
||||
"host":"0.0.0.0",
|
||||
"port":9991,
|
||||
"coding":"utf-8",
|
||||
"indexes":[
|
||||
"index.html",
|
||||
"index.ui"
|
||||
],
|
||||
"startswiths":[
|
||||
{
|
||||
"leading":"/idfile",
|
||||
"registerfunction":"idfile"
|
||||
},{
|
||||
"leading": "/v1/triples",
|
||||
"registerfunction": "triples"
|
||||
},{
|
||||
"leading": "/docs",
|
||||
"registerfunction": "docs"
|
||||
}
|
||||
],
|
||||
"processors":[
|
||||
[".tmpl","tmpl"],
|
||||
[".app","app"],
|
||||
[".ui","bui"],
|
||||
[".dspy","dspy"],
|
||||
[".md","md"]
|
||||
],
|
||||
"rsakey_oops":{
|
||||
"privatekey":"$[workdir]$/conf/rsa_private_key.pem",
|
||||
"publickey":"$[workdir]$/conf/rsa_public_key.pem"
|
||||
},
|
||||
"session_max_time":3000,
|
||||
"session_issue_time":2500,
|
||||
"session_redis_notuse":{
|
||||
"url":"redis://127.0.0.1:6379"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
0
test/triples/logs/llmengine.log
Normal file
0
test/triples/logs/llmengine.log
Normal file
3
test/triples/start.sh
Executable file
3
test/triples/start.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=7 /share/vllm-0.8.5/bin/python -m llmengine.triple -p 9991 /share/models/Babelscape/mrebel-large
|
3
test/triples/stop.sh
Executable file
3
test/triples/stop.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
pkill -f "llmengine.triple.*9991"
|
Loading…
Reference in New Issue
Block a user