From 07b36f86922e7dda5b354417e3d1409a92843641 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Tue, 24 Jun 2025 19:34:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E4=B8=89=E5=85=83?= =?UTF-8?q?=E7=BB=84=E6=9C=8D=E5=8A=A1=E5=8C=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llmengine/base_triple.py | 54 +++++++++++ llmengine/mrebeltriple.py | 161 ++++++++++++++++++++++++++++++++ llmengine/triple.py | 134 ++++++++++++++++++++++++++ test/triples/conf/config.json | 50 ++++++++++ test/triples/logs/llmengine.log | 0 test/triples/start.sh | 3 + test/triples/stop.sh | 3 + 7 files changed, 405 insertions(+) create mode 100644 llmengine/base_triple.py create mode 100644 llmengine/mrebeltriple.py create mode 100644 llmengine/triple.py create mode 100644 test/triples/conf/config.json create mode 100644 test/triples/logs/llmengine.log create mode 100755 test/triples/start.sh create mode 100755 test/triples/stop.sh diff --git a/llmengine/base_triple.py b/llmengine/base_triple.py new file mode 100644 index 0000000..d7e39af --- /dev/null +++ b/llmengine/base_triple.py @@ -0,0 +1,54 @@ +import torch +from abc import ABC, abstractmethod +import logging +import os + +logger = logging.getLogger(__name__) + +model_pathMap = {} + + +def llm_register(model_key: str, Klass): + """Register a triplet extractor class for a given model key.""" + global model_pathMap + model_pathMap[model_key] = Klass + logger.debug(f"Registered {Klass.__name__} for model_key: {model_key}") + + +def get_llm_class(model_path: str): + """Return the triplet extractor class for the given model path.""" + for k, klass in model_pathMap.items(): + if k in model_path: + return klass + logger.debug(f"No class found for model_path: {model_path}, model_pathMap: {model_pathMap}") + return None + + +class BaseTripleExtractor(ABC): + """Base class for triplet extraction.""" + + def __init__(self, model_path: str): + self.model_path = model_path + self.model_name = os.path.basename(model_path) + self.model = None + logger.debug(f"Initialized BaseTripleExtractor with model_path: {model_path}") + + def use_mps_if_possible(self): + """Select device (MPS, CUDA, or CPU).""" + if torch.backends.mps.is_available(): + device = torch.device("mps") + logger.debug("Using MPS device") + elif torch.cuda.is_available(): + device = torch.device("cuda") + logger.debug("Using CUDA device") + else: + device = torch.device("cpu") + logger.debug("Using CPU device") + if self.model is not None: + self.model = self.model.to(device) + return device + + @abstractmethod + async def extract_triplets(self, text: str) -> list: + """Extract triplets from text.""" + pass \ No newline at end of file diff --git a/llmengine/mrebeltriple.py b/llmengine/mrebeltriple.py new file mode 100644 index 0000000..6703a42 --- /dev/null +++ b/llmengine/mrebeltriple.py @@ -0,0 +1,161 @@ +import os +import torch +import re +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +import logging +from base_triple import BaseTripleExtractor, llm_register + +logger = logging.getLogger(__name__) + +class MRebelTripleExtractor(BaseTripleExtractor): + def __init__(self, model_path: str): + super().__init__(model_path) + try: + logger.debug(f"Loading tokenizer from {model_path}") + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + logger.debug(f"Loading model from {model_path}") + self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path) + self.device = self.use_mps_if_possible() + self.triplet_id = self.tokenizer.convert_tokens_to_ids("") + logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}") + except Exception as e: + logger.error(f"Failed to load mREBEL model: {str(e)}") + raise RuntimeError(f"Failed to load mREBEL model: {str(e)}") + + self.gen_kwargs = { + "max_length": 512, + "min_length": 10, + "length_penalty": 0.5, + "num_beams": 3, + "num_return_sequences": 1, + "no_repeat_ngram_size": 2, + "early_stopping": True, + "decoder_start_token_id": self.triplet_id, + } + + def split_document(self, text: str, max_chunk_size: int = 150) -> list: + """Split document into semantic chunks.""" + sentences = re.split(r'(?<=[。!?;\n])', text) + chunks = [] + current_chunk = "" + for sentence in sentences: + if len(current_chunk) + len(sentence) <= max_chunk_size: + current_chunk += sentence + else: + if current_chunk: + chunks.append(current_chunk) + current_chunk = sentence + if current_chunk: + chunks.append(current_chunk) + logger.debug(f"Text split into: {len(chunks)} chunks") + return chunks + + def extract_triplets_typed(self, text: str) -> list: + """Parse mREBEL generated text for triplets.""" + triplets = [] + logger.debug(f"Raw generated text: {text}") + + tokens = [] + in_tag = False + buffer = "" + for char in text: + if char == '<': + in_tag = True + if buffer: + tokens.append(buffer.strip()) + buffer = "" + buffer += char + elif char == '>': + in_tag = False + buffer += char + tokens.append(buffer.strip()) + buffer = "" + else: + buffer += char + if buffer: + tokens.append(buffer.strip()) + + special_tokens = ["", "", "", "tp_XX", "__en__", "__zh__", "zh_CN"] + tokens = [t for t in tokens if t not in special_tokens and t] + logger.debug(f"Processed tokens: {tokens}") + + i = 0 + while i < len(tokens): + if tokens[i] == "" and i + 5 < len(tokens): + entity1 = tokens[i + 1] + type1 = tokens[i + 2][1:-1] if tokens[i + 2].startswith("<") and tokens[i + 2].endswith(">") else "" + entity2 = tokens[i + 3] + type2 = tokens[i + 4][1:-1] if tokens[i + 4].startswith("<") and tokens[i + 4].endswith(">") else "" + relation = tokens[i + 5] + + if entity1 and type1 and entity2 and type2 and relation: + triplets.append({ + 'head': entity1.strip(), + 'head_type': type1, + 'type': relation.strip(), + 'tail': entity2.strip(), + 'tail_type': type2 + }) + logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})") + i += 6 + else: + i += 1 + + return triplets + + async def extract_triplets(self, text: str) -> list: + """Extract triplets from text and return unique triplets.""" + try: + if not text: + raise ValueError("Text cannot be empty") + + text_chunks = self.split_document(text, max_chunk_size=150) + logger.debug(f"Text split into {len(text_chunks)} chunks") + + all_triplets = [] + for i, chunk in enumerate(text_chunks): + logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...") + + model_inputs = self.tokenizer( + chunk, + max_length=256, + padding=True, + truncation=True, + return_tensors="pt" + ).to(self.device) + + try: + generated_tokens = self.model.generate( + model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], + **self.gen_kwargs, + ) + decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False) + for idx, sentence in enumerate(decoded_preds): + logger.debug(f"Chunk {i + 1} generated text: {sentence}") + triplets = self.extract_triplets_typed(sentence) + if triplets: + logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets") + all_triplets.extend(triplets) + except Exception as e: + logger.warning(f"Error processing chunk {i + 1}: {str(e)}") + continue + + unique_triplets = [] + seen = set() + for t in all_triplets: + identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower()) + if identifier not in seen: + seen.add(identifier) + unique_triplets.append(t) + + logger.info(f"Extracted {len(unique_triplets)} unique triplets") + return unique_triplets + + except Exception as e: + logger.error(f"Failed to extract triplets: {str(e)}") + import traceback + logger.debug(traceback.format_exc()) + return [] + +llm_register("mrebel-large", MRebelTripleExtractor) \ No newline at end of file diff --git a/llmengine/triple.py b/llmengine/triple.py new file mode 100644 index 0000000..a40e9eb --- /dev/null +++ b/llmengine/triple.py @@ -0,0 +1,134 @@ +from traceback import format_exc +import os +import argparse +import logging +import yaml +from typing import List +from base_triple import get_llm_class +from mrebeltriple import MRebelTripleExtractor +from appPublic.registerfunction import RegisterFunction +from appPublic.log import debug, exception +from ahserver.serverenv import ServerEnv +from ahserver.globalEnv import stream_response +from ahserver.webapp import webserver +import aiohttp.web + +# 配置日志 +logger = logging.getLogger('llmengine_triple') +logger.setLevel(logging.DEBUG) +log_file = '/share/wangmeihua/rag/logs/llmengine_triple.log' +os.makedirs(os.path.dirname(log_file), exist_ok=True) +formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') +for handler in (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler()): + handler.setFormatter(formatter) + logger.addHandler(handler) + +# 加载配置文件 +CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') +try: + with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) +except Exception as e: + logger.error(f"Failed to load config {CONFIG_PATH}: {str(e)}") + raise RuntimeError(f"Failed to load config: {str(e)}") + +helptext = """mREBEL Triplets API: + +1. Triplets Endpoint: +path: /v1/triples +headers: { + "Content-Type": "application/json" +} +data: { + "text": "知识图谱是一个结构化的语义知识库。" +} +response: { + "object": "list", + "data": [ + { + "head": "知识图谱", + "head_type": "Concept", + "type": "is_a", + "tail": "语义知识库", + "tail_type": "Concept" + }, + ... + ] +} + +2. Docs Endpoint: +path: /v1/docs +response: This help text +""" + +def init(): + rf = RegisterFunction() + rf.register('triples', triples) + rf.register('docs', docs) + +async def docs(request, params_kw, *params, **kw): + return helptext + +async def triples(request, params_kw, *params, **kw): + try: + debug(f'Received params_kw: {params_kw}') + # 显式解析请求数据 + if not params_kw: + try: + data = await request.json() + params_kw = data + debug(f'Parsed JSON data: {params_kw}') + except Exception as e: + logger.error(f"Failed to parse JSON: {str(e)}") + raise aiohttp.web.HTTPBadRequest(reason=f"Invalid JSON: {str(e)}") + + se = ServerEnv() + engine = se.engine + if engine is None: + raise ValueError("Engine not initialized") + + text = params_kw.get('text') + if not text: + e = ValueError("text cannot be empty") + exception(f'{e}') + raise e + + triplets = await engine.extract_triplets(text) + debug(f'{triplets=}') + return { + "object": "list", + "data": triplets + } + except Exception as e: + logger.error(f"Error in triples endpoint: {str(e)}") + logger.debug(f"Traceback: {format_exc()}") + raise + +def main(): + parser = argparse.ArgumentParser(prog="mREBEL Triplet Service") + parser.add_argument('-w', '--workdir', default=None) + parser.add_argument('-p', '--port', type=int, default=9991) + parser.add_argument('model_path') + args = parser.parse_args() + + try: + Klass = get_llm_class(args.model_path) + if Klass is None: + e = Exception(f"{args.model_path} has no mapping to a model class") + exception(f'{e}, {format_exc()}') + raise e + + se = ServerEnv() + se.engine = Klass(args.model_path) + workdir = args.workdir or os.getcwd() + port = args.port + debug(f'{args=}') + logger.info(f"Starting mREBEL Triplet Service on port {port}, model: {args.model_path}") + webserver(init, workdir, port) + except Exception as e: + logger.error(f"Failed to start server: {str(e)}") + logger.debug(f"Traceback: {format_exc()}") + raise + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/triples/conf/config.json b/test/triples/conf/config.json new file mode 100644 index 0000000..e4276ec --- /dev/null +++ b/test/triples/conf/config.json @@ -0,0 +1,50 @@ +{ + "filesroot":"$[workdir]$/files", + "logger":{ + "name":"llmengine", + "levelname":"info", + "logfile":"$[workdir]$/logs/llmengine.log" + }, + "website":{ + "paths":[ + ["$[workdir]$/wwwroot",""] + ], + "client_max_size":10000, + "host":"0.0.0.0", + "port":9991, + "coding":"utf-8", + "indexes":[ + "index.html", + "index.ui" + ], + "startswiths":[ + { + "leading":"/idfile", + "registerfunction":"idfile" + },{ + "leading": "/v1/triples", + "registerfunction": "triples" + },{ + "leading": "/docs", + "registerfunction": "docs" + } + ], + "processors":[ + [".tmpl","tmpl"], + [".app","app"], + [".ui","bui"], + [".dspy","dspy"], + [".md","md"] + ], + "rsakey_oops":{ + "privatekey":"$[workdir]$/conf/rsa_private_key.pem", + "publickey":"$[workdir]$/conf/rsa_public_key.pem" + }, + "session_max_time":3000, + "session_issue_time":2500, + "session_redis_notuse":{ + "url":"redis://127.0.0.1:6379" + } + } +} + diff --git a/test/triples/logs/llmengine.log b/test/triples/logs/llmengine.log new file mode 100644 index 0000000..e69de29 diff --git a/test/triples/start.sh b/test/triples/start.sh new file mode 100755 index 0000000..0abcc35 --- /dev/null +++ b/test/triples/start.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES=7 /share/vllm-0.8.5/bin/python -m llmengine.triple -p 9991 /share/models/Babelscape/mrebel-large \ No newline at end of file diff --git a/test/triples/stop.sh b/test/triples/stop.sh new file mode 100755 index 0000000..d48070a --- /dev/null +++ b/test/triples/stop.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +pkill -f "llmengine.triple.*9991" \ No newline at end of file