:x
Merge branch 'main' of git.kaiyuancloud.cn:yumoqing/llmengine
This commit is contained in:
commit
5f1b06d10f
54
llmengine/base_triple.py
Normal file
54
llmengine/base_triple.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import torch
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
model_pathMap = {}
|
||||||
|
|
||||||
|
|
||||||
|
def llm_register(model_key: str, Klass):
|
||||||
|
"""Register a triplet extractor class for a given model key."""
|
||||||
|
global model_pathMap
|
||||||
|
model_pathMap[model_key] = Klass
|
||||||
|
logger.debug(f"Registered {Klass.__name__} for model_key: {model_key}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_class(model_path: str):
|
||||||
|
"""Return the triplet extractor class for the given model path."""
|
||||||
|
for k, klass in model_pathMap.items():
|
||||||
|
if k in model_path:
|
||||||
|
return klass
|
||||||
|
logger.debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTripleExtractor(ABC):
|
||||||
|
"""Base class for triplet extraction."""
|
||||||
|
|
||||||
|
def __init__(self, model_path: str):
|
||||||
|
self.model_path = model_path
|
||||||
|
self.model_name = os.path.basename(model_path)
|
||||||
|
self.model = None
|
||||||
|
logger.debug(f"Initialized BaseTripleExtractor with model_path: {model_path}")
|
||||||
|
|
||||||
|
def use_mps_if_possible(self):
|
||||||
|
"""Select device (MPS, CUDA, or CPU)."""
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = torch.device("mps")
|
||||||
|
logger.debug("Using MPS device")
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
logger.debug("Using CUDA device")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
logger.debug("Using CPU device")
|
||||||
|
if self.model is not None:
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
return device
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def extract_triplets(self, text: str) -> list:
|
||||||
|
"""Extract triplets from text."""
|
||||||
|
pass
|
161
llmengine/mrebeltriple.py
Normal file
161
llmengine/mrebeltriple.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import re
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
import logging
|
||||||
|
from base_triple import BaseTripleExtractor, llm_register
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MRebelTripleExtractor(BaseTripleExtractor):
|
||||||
|
def __init__(self, model_path: str):
|
||||||
|
super().__init__(model_path)
|
||||||
|
try:
|
||||||
|
logger.debug(f"Loading tokenizer from {model_path}")
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
logger.debug(f"Loading model from {model_path}")
|
||||||
|
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
||||||
|
self.device = self.use_mps_if_possible()
|
||||||
|
self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
|
||||||
|
logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load mREBEL model: {str(e)}")
|
||||||
|
raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
|
||||||
|
|
||||||
|
self.gen_kwargs = {
|
||||||
|
"max_length": 512,
|
||||||
|
"min_length": 10,
|
||||||
|
"length_penalty": 0.5,
|
||||||
|
"num_beams": 3,
|
||||||
|
"num_return_sequences": 1,
|
||||||
|
"no_repeat_ngram_size": 2,
|
||||||
|
"early_stopping": True,
|
||||||
|
"decoder_start_token_id": self.triplet_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def split_document(self, text: str, max_chunk_size: int = 150) -> list:
|
||||||
|
"""Split document into semantic chunks."""
|
||||||
|
sentences = re.split(r'(?<=[。!?;\n])', text)
|
||||||
|
chunks = []
|
||||||
|
current_chunk = ""
|
||||||
|
for sentence in sentences:
|
||||||
|
if len(current_chunk) + len(sentence) <= max_chunk_size:
|
||||||
|
current_chunk += sentence
|
||||||
|
else:
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
current_chunk = sentence
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
logger.debug(f"Text split into: {len(chunks)} chunks")
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def extract_triplets_typed(self, text: str) -> list:
|
||||||
|
"""Parse mREBEL generated text for triplets."""
|
||||||
|
triplets = []
|
||||||
|
logger.debug(f"Raw generated text: {text}")
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
in_tag = False
|
||||||
|
buffer = ""
|
||||||
|
for char in text:
|
||||||
|
if char == '<':
|
||||||
|
in_tag = True
|
||||||
|
if buffer:
|
||||||
|
tokens.append(buffer.strip())
|
||||||
|
buffer = ""
|
||||||
|
buffer += char
|
||||||
|
elif char == '>':
|
||||||
|
in_tag = False
|
||||||
|
buffer += char
|
||||||
|
tokens.append(buffer.strip())
|
||||||
|
buffer = ""
|
||||||
|
else:
|
||||||
|
buffer += char
|
||||||
|
if buffer:
|
||||||
|
tokens.append(buffer.strip())
|
||||||
|
|
||||||
|
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
|
||||||
|
tokens = [t for t in tokens if t not in special_tokens and t]
|
||||||
|
logger.debug(f"Processed tokens: {tokens}")
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
if tokens[i] == "<triplet>" and i + 5 < len(tokens):
|
||||||
|
entity1 = tokens[i + 1]
|
||||||
|
type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else ""
|
||||||
|
entity2 = tokens[i + 3]
|
||||||
|
type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else ""
|
||||||
|
relation = tokens[i + 5]
|
||||||
|
|
||||||
|
if entity1 and type1 and entity2 and type2 and relation:
|
||||||
|
triplets.append({
|
||||||
|
'head': entity1.strip(),
|
||||||
|
'head_type': type1,
|
||||||
|
'type': relation.strip(),
|
||||||
|
'tail': entity2.strip(),
|
||||||
|
'tail_type': type2
|
||||||
|
})
|
||||||
|
logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
|
||||||
|
i += 6
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return triplets
|
||||||
|
|
||||||
|
async def extract_triplets(self, text: str) -> list:
|
||||||
|
"""Extract triplets from text and return unique triplets."""
|
||||||
|
try:
|
||||||
|
if not text:
|
||||||
|
raise ValueError("Text cannot be empty")
|
||||||
|
|
||||||
|
text_chunks = self.split_document(text, max_chunk_size=150)
|
||||||
|
logger.debug(f"Text split into {len(text_chunks)} chunks")
|
||||||
|
|
||||||
|
all_triplets = []
|
||||||
|
for i, chunk in enumerate(text_chunks):
|
||||||
|
logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
|
||||||
|
|
||||||
|
model_inputs = self.tokenizer(
|
||||||
|
chunk,
|
||||||
|
max_length=256,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt"
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
try:
|
||||||
|
generated_tokens = self.model.generate(
|
||||||
|
model_inputs["input_ids"],
|
||||||
|
attention_mask=model_inputs["attention_mask"],
|
||||||
|
**self.gen_kwargs,
|
||||||
|
)
|
||||||
|
decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
|
||||||
|
for idx, sentence in enumerate(decoded_preds):
|
||||||
|
logger.debug(f"Chunk {i + 1} generated text: {sentence}")
|
||||||
|
triplets = self.extract_triplets_typed(sentence)
|
||||||
|
if triplets:
|
||||||
|
logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
|
||||||
|
all_triplets.extend(triplets)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error processing chunk {i + 1}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
unique_triplets = []
|
||||||
|
seen = set()
|
||||||
|
for t in all_triplets:
|
||||||
|
identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower())
|
||||||
|
if identifier not in seen:
|
||||||
|
seen.add(identifier)
|
||||||
|
unique_triplets.append(t)
|
||||||
|
|
||||||
|
logger.info(f"Extracted {len(unique_triplets)} unique triplets")
|
||||||
|
return unique_triplets
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to extract triplets: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
|
||||||
|
llm_register("mrebel-large", MRebelTripleExtractor)
|
134
llmengine/triple.py
Normal file
134
llmengine/triple.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
from traceback import format_exc
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import yaml
|
||||||
|
from typing import List
|
||||||
|
from base_triple import get_llm_class
|
||||||
|
from mrebeltriple import MRebelTripleExtractor
|
||||||
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
from appPublic.log import debug, exception
|
||||||
|
from ahserver.serverenv import ServerEnv
|
||||||
|
from ahserver.globalEnv import stream_response
|
||||||
|
from ahserver.webapp import webserver
|
||||||
|
import aiohttp.web
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger = logging.getLogger('llmengine_triple')
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
log_file = '/share/wangmeihua/rag/logs/llmengine_triple.log'
|
||||||
|
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||||
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
for handler in (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()):
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml')
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load config {CONFIG_PATH}: {str(e)}")
|
||||||
|
raise RuntimeError(f"Failed to load config: {str(e)}")
|
||||||
|
|
||||||
|
helptext = """mREBEL Triplets API:
|
||||||
|
|
||||||
|
1. Triplets Endpoint:
|
||||||
|
path: /v1/triples
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
data: {
|
||||||
|
"text": "知识图谱是一个结构化的语义知识库。"
|
||||||
|
}
|
||||||
|
response: {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"head": "知识图谱",
|
||||||
|
"head_type": "Concept",
|
||||||
|
"type": "is_a",
|
||||||
|
"tail": "语义知识库",
|
||||||
|
"tail_type": "Concept"
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
2. Docs Endpoint:
|
||||||
|
path: /v1/docs
|
||||||
|
response: This help text
|
||||||
|
"""
|
||||||
|
|
||||||
|
def init():
|
||||||
|
rf = RegisterFunction()
|
||||||
|
rf.register('triples', triples)
|
||||||
|
rf.register('docs', docs)
|
||||||
|
|
||||||
|
async def docs(request, params_kw, *params, **kw):
|
||||||
|
return helptext
|
||||||
|
|
||||||
|
async def triples(request, params_kw, *params, **kw):
|
||||||
|
try:
|
||||||
|
debug(f'Received params_kw: {params_kw}')
|
||||||
|
# 显式解析请求数据
|
||||||
|
if not params_kw:
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
params_kw = data
|
||||||
|
debug(f'Parsed JSON data: {params_kw}')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to parse JSON: {str(e)}")
|
||||||
|
raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}")
|
||||||
|
|
||||||
|
se = ServerEnv()
|
||||||
|
engine = se.engine
|
||||||
|
if engine is None:
|
||||||
|
raise ValueError("Engine not initialized")
|
||||||
|
|
||||||
|
text = params_kw.get('text')
|
||||||
|
if not text:
|
||||||
|
e = ValueError("text cannot be empty")
|
||||||
|
exception(f'{e}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
triplets = await engine.extract_triplets(text)
|
||||||
|
debug(f'{triplets=}')
|
||||||
|
return {
|
||||||
|
"object": "list",
|
||||||
|
"data": triplets
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in triples endpoint: {str(e)}")
|
||||||
|
logger.debug(f"Traceback: {format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(prog="mREBEL Triplet Service")
|
||||||
|
parser.add_argument('-w', '--workdir', default=None)
|
||||||
|
parser.add_argument('-p', '--port', type=int, default=9991)
|
||||||
|
parser.add_argument('model_path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
Klass = get_llm_class(args.model_path)
|
||||||
|
if Klass is None:
|
||||||
|
e = Exception(f"{args.model_path} has no mapping to a model class")
|
||||||
|
exception(f'{e}, {format_exc()}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
se = ServerEnv()
|
||||||
|
se.engine = Klass(args.model_path)
|
||||||
|
workdir = args.workdir or os.getcwd()
|
||||||
|
port = args.port
|
||||||
|
debug(f'{args=}')
|
||||||
|
logger.info(f"Starting mREBEL Triplet Service on port {port}, model: {args.model_path}")
|
||||||
|
webserver(init, workdir, port)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start server: {str(e)}")
|
||||||
|
logger.debug(f"Traceback: {format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
50
test/triples/conf/config.json
Normal file
50
test/triples/conf/config.json
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
{
|
||||||
|
"filesroot":"$[workdir]$/files",
|
||||||
|
"logger":{
|
||||||
|
"name":"llmengine",
|
||||||
|
"levelname":"info",
|
||||||
|
"logfile":"$[workdir]$/logs/llmengine.log"
|
||||||
|
},
|
||||||
|
"website":{
|
||||||
|
"paths":[
|
||||||
|
["$[workdir]$/wwwroot",""]
|
||||||
|
],
|
||||||
|
"client_max_size":10000,
|
||||||
|
"host":"0.0.0.0",
|
||||||
|
"port":9991,
|
||||||
|
"coding":"utf-8",
|
||||||
|
"indexes":[
|
||||||
|
"index.html",
|
||||||
|
"index.ui"
|
||||||
|
],
|
||||||
|
"startswiths":[
|
||||||
|
{
|
||||||
|
"leading":"/idfile",
|
||||||
|
"registerfunction":"idfile"
|
||||||
|
},{
|
||||||
|
"leading": "/v1/triples",
|
||||||
|
"registerfunction": "triples"
|
||||||
|
},{
|
||||||
|
"leading": "/docs",
|
||||||
|
"registerfunction": "docs"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"processors":[
|
||||||
|
[".tmpl","tmpl"],
|
||||||
|
[".app","app"],
|
||||||
|
[".ui","bui"],
|
||||||
|
[".dspy","dspy"],
|
||||||
|
[".md","md"]
|
||||||
|
],
|
||||||
|
"rsakey_oops":{
|
||||||
|
"privatekey":"$[workdir]$/conf/rsa_private_key.pem",
|
||||||
|
"publickey":"$[workdir]$/conf/rsa_public_key.pem"
|
||||||
|
},
|
||||||
|
"session_max_time":3000,
|
||||||
|
"session_issue_time":2500,
|
||||||
|
"session_redis_notuse":{
|
||||||
|
"url":"redis://127.0.0.1:6379"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
0
test/triples/logs/llmengine.log
Normal file
0
test/triples/logs/llmengine.log
Normal file
3
test/triples/start.sh
Executable file
3
test/triples/start.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=7 /share/vllm-0.8.5/bin/python -m llmengine.triple -p 9991 /share/models/Babelscape/mrebel-large
|
3
test/triples/stop.sh
Executable file
3
test/triples/stop.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
pkill -f "llmengine.triple.*9991"
|
Loading…
Reference in New Issue
Block a user