diff --git a/llmengine/connection.py b/llmengine/connection.py index fbe9eb6..53f5ed1 100644 --- a/llmengine/connection.py +++ b/llmengine/connection.py @@ -80,25 +80,43 @@ data: { "use_rerank": true } response: -- Success: HTTP 200, [ - { - "text": "<完整文本内容>", - "distance": 0.95, - "source": "fused_query_with_triplets", - "rerank_score": 0.92, // 若 use_rerank=true - "metadata": { - "userid": "user1", - "document_id": "", - "filename": "file.txt", - "file_path": "/path/to/file.txt", - "upload_time": "", - "file_type": "txt" - } +- Success: HTTP 200, { + "status": "success", + "results": [ + { + "text": "<完整文本内容>", + "distance": 0.95, + "source": "fused_query_with_triplets", + "rerank_score": 0.92, // 若 use_rerank=true + "metadata": { + "userid": "user1", + "document_id": "", + "filename": "file.txt", + "file_path": "/path/to/file.txt", + "upload_time": "", + "file_type": "txt" + } + }, + ... + ], + "timing": { + "collection_load": , // 集合加载耗时(秒) + "entity_extraction": , // 实体提取耗时(秒) + "triplet_matching": , // 三元组匹配耗时(秒) + "triplet_text_combine": , // 拼接三元组文本耗时(秒) + "embedding_generation": , // 嵌入向量生成耗时(秒) + "vector_search": , // 向量搜索耗时(秒) + "deduplication": , // 去重耗时(秒) + "reranking": , // 重排序耗时(秒,若 use_rerank=true) + "total_time": // 总耗时(秒) }, - ... -] -- Error: HTTP 400, {"status": "error", "message": "", "collection_name": ""} - + "collection_name": "ragdb" or "ragdb_textdb" +} +- Error: HTTP 400, { + "status": "error", + "message": "", + "collection_name": "ragdb" or "ragdb_textdb" +} 6. Search Query Endpoint: path: /v1/searchquery method: POST @@ -113,45 +131,83 @@ data: { "use_rerank": true } response: -- Success: HTTP 200, [ - { - "text": "<完整文本内容>", - "distance": 0.95, - "source": "vector_query", - "rerank_score": 0.92, // 若 use_rerank=true - "metadata": { - "userid": "user1", - "document_id": "", - "filename": "file.txt", - "file_path": "/path/to/file.txt", - "upload_time": "", - "file_type": "txt" - } +- Success: HTTP 200, { + "status": "success", + "results": [ + { + "text": "<完整文本内容>", + "distance": 0.95, + "source": "vector_query", + "rerank_score": 0.92, // 若 use_rerank=true + "metadata": { + "userid": "user1", + "document_id": "", + "filename": "file.txt", + "file_path": "/path/to/file.txt", + "upload_time": "", + "file_type": "txt" + } + }, + ... + ], + "timing": { + "collection_load": , // 集合加载耗时(秒) + "embedding_generation": , // 嵌入向量生成耗时(秒) + "vector_search": , // 向量搜索耗时(秒) + "deduplication": , // 去重耗时(秒) + "reranking": , // 重排序耗时(秒,若 use_rerank=true) + "total_time": // 总耗时(秒) }, - ... -] -- Error: HTTP 400, {"status": "error", "message": ""} + "collection_name": "ragdb" or "ragdb_textdb" +} +- Error: HTTP 400, { + "status": "error", + "message": "", + "collection_name": "ragdb" or "ragdb_textdb" +} 7. List User Files Endpoint: path: /v1/listuserfiles method: POST headers: {"Content-Type": "application/json"} data: { - "userid": "testuser2" + "userid": "user1", + "db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb } response: -- Success: HTTP 200, [ - { - "filename": "file.txt", - "file_path": "/path/to/file.txt", - "db_type": "textdb", - "upload_time": "", - "file_type": "txt" +- Success: HTTP 200, { + "status": "success", + "files_by_knowledge_base": { + "kb123": [ + { + "document_id": "", + "filename": "file1.txt", + "file_path": "/path/to/file1.txt", + "upload_time": "", + "file_type": "txt", + "knowledge_base_id": "kb123" + }, + ... + ], + "kb456": [ + { + "document_id": "", + "filename": "file2.pdf", + "file_path": "/path/to/file2.pdf", + "upload_time": "", + "file_type": "pdf", + "knowledge_base_id": "kb456" + }, + ... + ] }, - ... -] -- Error: HTTP 400, {"status": "error", "message": ""} - + "collection_name": "ragdb" or "ragdb_textdb" +} +- Error: HTTP 400, { + "status": "error", + "message": "", + "collection_name": "ragdb" or "ragdb_textdb" +} 8. Connection Endpoint (for compatibility): path: /v1/connection method: POST @@ -378,7 +434,13 @@ async def fused_search_query(request, params_kw, *params, **kw): "use_rerank": use_rerank }) debug(f'{result=}') - return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) + response = { + "status": "success", + "results": result.get("results", []), + "timing": result.get("timing", {}), + "collection_name": collection_name + } + return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: error(f'融合搜索失败: {str(e)}') return web.json_response({ @@ -396,7 +458,7 @@ async def search_query(request, params_kw, *params, **kw): db_type = params_kw.get('db_type', '') knowledge_base_ids = params_kw.get('knowledge_base_ids') limit = params_kw.get('limit') - offset = params_kw.get('offset') + offset = params_kw.get('offset', 0) use_rerank = params_kw.get('use_rerank', True) collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: @@ -417,7 +479,13 @@ async def search_query(request, params_kw, *params, **kw): "use_rerank": use_rerank }) debug(f'{result=}') - return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) + response = { + "status": "success", + "results": result.get("results", []), + "timing": result.get("timing", {}), + "collection_name": collection_name + } + return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: error(f'纯向量搜索失败: {str(e)}') return web.json_response({ @@ -431,23 +499,33 @@ async def list_user_files(request, params_kw, *params, **kw): se = ServerEnv() engine = se.engine userid = params_kw.get('userid') - if not userid: - debug(f'userid 未提供') - return web.json_response({ - "status": "error", - "message": "userid 参数未提供" - }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + db_type = params_kw.get('db_type', '') + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: + if not userid: + debug(f'userid 未提供') + return web.json_response({ + "status": "error", + "message": "userid 未提供", + "collection_name": collection_name + }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) result = await engine.handle_connection("list_user_files", { - "userid": userid + "userid": userid, + "db_type": db_type }) debug(f'{result=}') - return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) + response = { + "status": "success", + "files_by_knowledge_base": result, + "collection_name": collection_name + } + return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: - error(f'查询用户文件列表失败: {str(e)}') + error(f'列出用户文件失败: {str(e)}') return web.json_response({ "status": "error", - "message": str(e) + "message": str(e), + "collection_name": collection_name }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) async def handle_connection(request, params_kw, *params, **kw): diff --git a/llmengine/entity.py b/llmengine/entity.py index 865286d..b1f1205 100644 --- a/llmengine/entity.py +++ b/llmengine/entity.py @@ -45,6 +45,7 @@ def init(): rf = RegisterFunction() rf.register('entities', entities) rf.register('docs', docs) + debug("注册路由: entities, docs") async def docs(request, params_kw, *params, **kw): return helptext @@ -53,12 +54,11 @@ async def entities(request, params_kw, *params, **kw): debug(f'{params_kw.query=}') se = ServerEnv() engine = se.engine - f = awaitify(engine.extract_entities) query = params_kw.query if query is None: e = exception(f'query is None') raise e - entities = await f(query) + entities = await engine.extract_entities(query) debug(f'{entities=}, type(entities)') return { "object": "list", diff --git a/llmengine/kgc.py b/llmengine/kgc.py index 35b6ede..56edefa 100644 --- a/llmengine/kgc.py +++ b/llmengine/kgc.py @@ -1,19 +1,23 @@ -import logging import os import re from py2neo import Graph, Node, Relationship from typing import Set, List, Dict, Tuple +from appPublic.jsonConfig import getConfig +from appPublic.log import debug, error, info, exception -# 配置日志 -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) class KnowledgeGraph: - def __init__(self, triples: List[Dict], document_id: str): + def __init__(self, triples: List[Dict], document_id: str, knowledge_base_id: str, userid: str): self.triples = triples self.document_id = document_id - self.g = Graph("bolt://10.18.34.18:7687", auth=('neo4j', '261229..wmh')) - logger.info(f"开始构建知识图谱,document_id: {self.document_id}, 三元组数量: {len(triples)}") + self.knowledge_base_id = knowledge_base_id + self.userid = userid + config = getConfig() + self.neo4j_uri = config['neo4j']['uri'] + self.neo4j_user = config['neo4j']['user'] + self.neo4j_password = config['neo4j']['password'] + self.g = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password)) + info(f"开始构建知识图谱,document_id: {self.document_id}, knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid}, 三元组数量: {len(triples)}") def _normalize_label(self, entity_type: str) -> str: """规范化实体类型为 Neo4j 标签""" @@ -25,28 +29,29 @@ class KnowledgeGraph: return label or 'Entity' def _clean_relation(self, relation: str) -> Tuple[str, str]: - """清洗关系,返回 (rel_type, rel_name)""" + """清洗关系,返回 (rel_type, rel_name),确保 rel_type 合法""" relation = relation.strip() if not relation: return 'RELATED_TO', '相关' - if relation.startswith('<') and relation.endswith('>'): - cleaned_relation = relation[1:-1] - rel_name = cleaned_relation - rel_type = re.sub(r'[^\w\s]', '', cleaned_relation).replace(' ', '_').upper() - else: - rel_name = relation - rel_type = re.sub(r'[^\w\s]', '', relation).replace(' ', '_').upper() - if 'instance of' in relation.lower(): - rel_type = 'INSTANCE_OF' - rel_name = '实例' - elif 'subclass of' in relation.lower(): - rel_type = 'SUBCLASS_OF' - rel_name = '子类' - elif 'part of' in relation.lower(): - rel_type = 'PART_OF' - rel_name = '部分' - logger.debug(f"处理关系: {relation} -> {rel_type} ({rel_name})") - return rel_type, rel_name + + cleaned_relation = re.sub(r'[^\w\s]', '', relation).strip() + if not cleaned_relation: + return 'RELATED_TO', '相关' + + if 'instance of' in relation.lower(): + return 'INSTANCE_OF', '实例' + elif 'subclass of' in relation.lower(): + return 'SUBCLASS_OF', '子类' + elif 'part of' in relation.lower(): + return 'PART_OF', '部分' + + rel_type = re.sub(r'\s+', '_', cleaned_relation).upper() + if rel_type and rel_type[0].isdigit(): + rel_type = f'REL_{rel_type}' + if not re.match(r'^[A-Za-z][A-Za-z0-9_]*$', rel_type): + debug(f"非法关系类型 '{rel_type}',替换为 'RELATED_TO'") + return 'RELATED_TO', relation + return rel_type, relation def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]: """从三元组列表中读取节点和关系""" @@ -57,14 +62,14 @@ class KnowledgeGraph: try: for triple in self.triples: if not all(key in triple for key in ['head', 'head_type', 'type', 'tail', 'tail_type']): - logger.warning(f"无效三元组: {triple}") + debug(f"无效三元组: {triple}") continue head, relation, tail, head_type, tail_type = ( triple['head'], triple['type'], triple['tail'], triple['head_type'], triple['tail_type'] ) head_label = self._normalize_label(head_type) tail_label = self._normalize_label(tail_type) - logger.debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}") + debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}") if head_label not in nodes_by_label: nodes_by_label[head_label] = set() @@ -92,33 +97,44 @@ class KnowledgeGraph: 'tail_type': tail_type }) - logger.info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个") - logger.info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条") + info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个") + info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条") return nodes_by_label, relations_by_type, triples except Exception as e: - logger.error(f"读取三元组失败: {str(e)}") + error(f"读取三元组失败: {str(e)}") raise RuntimeError(f"读取三元组失败: {str(e)}") def create_node(self, label: str, nodes: Set[str]): - """创建节点,包含 document_id 属性""" + """创建节点,包含 document_id、knowledge_base_id 和 userid 属性""" count = 0 for node_name in nodes: - query = f"MATCH (n:{label} {{name: '{node_name}', document_id: '{self.document_id}'}}) RETURN n" + query = ( + f"MATCH (n:{label} {{name: $name, document_id: $doc_id, " + f"knowledge_base_id: $kb_id, userid: $userid}}) RETURN n" + ) try: - if self.g.run(query).data(): + if self.g.run(query, name=node_name, doc_id=self.document_id, + kb_id=self.knowledge_base_id, userid=self.userid).data(): continue - node = Node(label, name=node_name, document_id=self.document_id) + node = Node( + label, + name=node_name, + document_id=self.document_id, + knowledge_base_id=self.knowledge_base_id, + userid=self.userid + ) self.g.create(node) count += 1 - logger.debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id})") + debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id}, " + f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})") except Exception as e: - logger.error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}") - logger.info(f"创建 {label} 节点: {count}/{len(nodes)} 个") + error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}") + info(f"创建 {label} 节点: {count}/{len(nodes)} 个") return count def create_relationship(self, rel_type: str, relations: List[Dict]): - """创建关系""" + """创建关系,包含 document_id、knowledge_base_id 和 userid 属性""" count = 0 total = len(relations) seen_edges = set() @@ -132,17 +148,23 @@ class KnowledgeGraph: seen_edges.add(edge_key) query = ( - f"MATCH (p:{head_label} {{name: '{head}', document_id: '{self.document_id}'}}), " - f"(q:{tail_label} {{name: '{tail}', document_id: '{self.document_id}'}}) " - f"CREATE (p)-[r:{rel_type} {{name: '{rel_name}'}}]->(q)" + f"MATCH (p:{head_label} {{name: $head, document_id: $doc_id, " + f"knowledge_base_id: $kb_id, userid: $userid}}), " + f"(q:{tail_label} {{name: $tail, document_id: $doc_id, " + f"knowledge_base_id: $kb_id, userid: $userid}}) " + f"CREATE (p)-[r:{rel_type} {{name: $rel_name, document_id: $doc_id, " + f"knowledge_base_id: $kb_id, userid: $userid}}]->(q)" ) try: - self.g.run(query) + self.g.run(query, head=head, tail=tail, rel_name=rel_name, + doc_id=self.document_id, kb_id=self.knowledge_base_id, + userid=self.userid) count += 1 - logger.debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id})") + debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id}, " + f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})") except Exception as e: - logger.error(f"创建关系失败: {query}, 错误: {str(e)}") - logger.info(f"创建 {rel_type} 关系: {count}/{total} 条") + error(f"创建关系失败: {query}, 错误: {str(e)}") + info(f"创建 {rel_type} 关系: {count}/{total} 条") return count def create_graphnodes(self): @@ -151,7 +173,7 @@ class KnowledgeGraph: total = 0 for label, nodes in nodes_by_label.items(): total += self.create_node(label, nodes) - logger.info(f"总计创建节点: {total} 个") + info(f"总计创建节点: {total} 个") return total def create_graphrels(self): @@ -160,15 +182,16 @@ class KnowledgeGraph: total = 0 for rel_type, relations in relations_by_type.items(): total += self.create_relationship(rel_type, relations) - logger.info(f"总计创建关系: {total} 条") + info(f"总计创建关系: {total} 条") return total def export_data(self): - """导出节点到文件,包含 document_id""" + """导出节点到文件,包含 document_id、knowledge_base_id 和 userid""" nodes_by_label, _, _ = self.read_nodes() os.makedirs('dict', exist_ok=True) for label, nodes in nodes_by_label.items(): with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f: - f.write('\n'.join(f"{name}\t{self.document_id}" for name in sorted(nodes))) - logger.info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个") + f.write('\n'.join(f"{name}\t{self.document_id}\t{self.knowledge_base_id}\t{self.userid}" + for name in sorted(nodes))) + info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个") return \ No newline at end of file diff --git a/llmengine/ltpentity.py b/llmengine/ltpentity.py index 300c048..337e79c 100644 --- a/llmengine/ltpentity.py +++ b/llmengine/ltpentity.py @@ -1,11 +1,9 @@ -# Requires ltp>=0.2.0 - from ltp import LTP from typing import List -import logging +from appPublic.log import debug, info, error +from appPublic.worker import awaitify from llmengine.base_entity import BaseLtp, ltp_register - -logger = logging.getLogger(__name__) +import asyncio class LtpEntity(BaseLtp): def __init__(self, model_id): @@ -14,7 +12,7 @@ class LtpEntity(BaseLtp): self.model_id = model_id self.model_name = model_id.split('/')[-1] - def extract_entities(self, query: str) -> List[str]: + async def extract_entities(self, query: str) -> List[str]: """ 从查询文本中抽取实体,包括: - LTP NER 识别的实体(所有类型)。 @@ -26,7 +24,18 @@ class LtpEntity(BaseLtp): if not query: raise ValueError("查询文本不能为空") - result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"]) + # 定义同步 pipeline 函数,正确传递 tasks 参数 + def sync_pipeline(query, tasks): + return self.ltp.pipeline([query], tasks=tasks) + + # 使用 run_in_executor 运行同步 pipeline + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, + lambda: sync_pipeline(query, ["cws", "pos", "ner"]) + ) + + # 解析结果 words = result.cws[0] pos_list = result.pos[0] ner = result.ner[0] @@ -34,7 +43,7 @@ class LtpEntity(BaseLtp): entities = [] subword_set = set() - logger.debug(f"NER 结果: {ner}") + debug(f"NER 结果: {ner}") for entity_type, entity, start, end in ner: entities.append(entity) @@ -49,13 +58,13 @@ class LtpEntity(BaseLtp): if combined: entities.append(combined) subword_set.update(combined_words) - logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}") + debug(f"合并连续名词: {combined}, 子词: {combined_words}") combined = "" combined_words = [] else: combined = "" combined_words = [] - logger.debug(f"连续名词子词集合: {subword_set}") + debug(f"连续名词子词集合: {subword_set}") for word, pos in zip(words, pos_list): if pos == 'n' and word not in subword_set: @@ -66,11 +75,11 @@ class LtpEntity(BaseLtp): entities.append(word) unique_entities = list(dict.fromkeys(entities)) - logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}") + info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}") return unique_entities except Exception as e: - logger.error(f"实体抽取失败: {str(e)}") - return [] + error(f"实体抽取失败: {str(e)}") + raise # 抛出异常以便调试,而不是返回空列表 ltp_register('LTP', LtpEntity) \ No newline at end of file diff --git a/llmengine/milvus_connection.py b/llmengine/milvus_connection.py index 5ba3303..a0ecef4 100644 --- a/llmengine/milvus_connection.py +++ b/llmengine/milvus_connection.py @@ -16,6 +16,7 @@ from llmengine.kgc import KnowledgeGraph import numpy as np from py2neo import Graph from scipy.spatial.distance import cosine +import time # 嵌入缓存 EMBED_CACHE = {} @@ -182,7 +183,7 @@ class MilvusConnection: if not userid: return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name, "document_id": "", "status_code": 400} - return await self.list_user_files(userid) + return await self._list_user_files(userid) else: return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name, "document_id": "", "status_code": 400} @@ -814,20 +815,24 @@ class MilvusConnection: error(f"实体识别服务调用失败: {str(e)}\n{exception()}") return [] - async def _match_triplets(self, query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]: + async def _match_triplets(self, query: str, query_entities: List[str], userid: str, knowledge_base_id: str) -> List[Dict]: """匹配查询实体与 Neo4j 中的三元组""" + start_time = time.time() # 记录开始时间 matched_triplets = [] ENTITY_SIMILARITY_THRESHOLD = 0.8 try: graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password)) debug(f"已连接到 Neo4j: {self.neo4j_uri}") + neo4j_connect_time = time.time() - start_time + debug(f"Neo4j 连接耗时: {neo4j_connect_time:.3f} 秒") matched_names = set() + entity_match_start = time.time() for entity in query_entities: normalized_entity = entity.lower().strip() query = """ - MATCH (n {document_id: $document_id}) + MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id}) WHERE toLower(n.name) CONTAINS $entity OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7 RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim @@ -835,24 +840,27 @@ class MilvusConnection: LIMIT 100 """ try: - results = graph.run(query, document_id=document_id, entity=normalized_entity).data() + results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, entity=normalized_entity).data() for record in results: matched_names.add(record['n.name']) debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})") except Exception as e: debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}") continue + entity_match_time = time.time() - entity_match_start + debug(f"实体匹配耗时: {entity_match_time:.3f} 秒") triplets = [] if matched_names: + triplet_query_start = time.time() query = """ - MATCH (h {document_id: $document_id})-[r]->(t {document_id: $document_id}) + MATCH (h {userid: $userid, knowledge_base_id: $knowledge_base_id})-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->(t {userid: $userid, knowledge_base_id: $knowledge_base_id}) WHERE h.name IN $matched_names OR t.name IN $matched_names RETURN h.name AS head, r.name AS type, t.name AS tail LIMIT 100 """ try: - results = graph.run(query, document_id=document_id, matched_names=list(matched_names)).data() + results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, matched_names=list(matched_names)).data() seen = set() for record in results: head, type_, tail = record['head'], record['type'], record['tail'] @@ -866,22 +874,28 @@ class MilvusConnection: 'head_type': '', 'tail_type': '' }) - debug(f"从 Neo4j 加载三元组: document_id={document_id}, 数量={len(triplets)}") + debug(f"从 Neo4j 加载三元组: knowledge_base_id={knowledge_base_id}, 数量={len(triplets)}") except Exception as e: - error(f"检索三元组失败: document_id={document_id}, 错误: {str(e)}\n{exception()}") + error(f"检索三元组失败: knowledge_base_id={knowledge_base_id}, 错误: {str(e)}\n{exception()}") return [] + triplet_query_time = time.time() - triplet_query_start + debug(f"Neo4j 三元组查询耗时: {triplet_query_time:.3f} 秒") if not triplets: - debug(f"文档 document_id={document_id} 无匹配三元组") + debug(f"知识库 knowledge_base_id={knowledge_base_id} 无匹配三元组") return [] + embedding_start = time.time() texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets] embeddings = await self._get_embeddings(texts_to_embed) entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)} head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)} tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)} debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)") + embedding_time = time.time() - embedding_start + debug(f"嵌入向量生成耗时: {embedding_time:.3f} 秒") + similarity_start = time.time() for entity in query_entities: entity_vec = entity_vectors[entity] for d_triplet in triplets: @@ -894,6 +908,8 @@ class MilvusConnection: matched_triplets.append(d_triplet) debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} " f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})") + similarity_time = time.time() - similarity_start + debug(f"相似度计算耗时: {similarity_time:.3f} 秒") unique_matched = [] seen = set() @@ -903,6 +919,8 @@ class MilvusConnection: seen.add(identifier) unique_matched.append(t) + total_time = time.time() - start_time + debug(f"_match_triplets 总耗时: {total_time:.3f} 秒") info(f"找到 {len(unique_matched)} 个匹配的三元组") return unique_matched @@ -957,9 +975,11 @@ class MilvusConnection: return results async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5, - offset: int = 0, use_rerank: bool = True) -> List[Dict]: + offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]: """融合搜索,将查询与所有三元组拼接后向量化搜索""" + start_time = time.time() # 记录开始时间 collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + timing_stats = {} # 记录各步骤耗时 try: info( f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}") @@ -975,46 +995,38 @@ class MilvusConnection: if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") - return [] + return {"results": [], "timing": timing_stats} try: collection = Collection(collection_name) collection.load() debug(f"加载集合: {collection_name}") + timing_stats["collection_load"] = time.time() - start_time + debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒") except Exception as e: error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") - return [] + return {"results": [], "timing": timing_stats} + entity_extract_start = time.time() query_entities = await self._extract_entities(query) - debug(f"提取实体: {query_entities}") + timing_stats["entity_extraction"] = time.time() - entity_extract_start + debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒") - documents = [] all_triplets = [] + triplet_match_start = time.time() for kb_id in knowledge_base_ids: debug(f"处理知识库: {kb_id}") + matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id) + debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)} 条") + all_triplets.extend(matched_triplets) + timing_stats["triplet_matching"] = time.time() - triplet_match_start + debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒") - results = collection.query( - expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'", - output_fields=["document_id", "filename", "knowledge_base_id"], - limit=100 # 查询足够多的文档以支持后续过滤 - ) - if not results: - debug(f"未找到 userid {userid} 和 knowledge_base_id {kb_id} 对应的文档") - continue - documents.extend(results) - - for doc in results: - document_id = doc["document_id"] - matched_triplets = await self._match_triplets(query, query_entities, userid, document_id) - debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)} 条") - all_triplets.extend(matched_triplets) - - if not documents: - debug("未找到任何有效文档") - return [] - - info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}") + if not all_triplets: + debug("未找到任何匹配的三元组") + return {"results": [], "timing": timing_stats} + triplet_text_start = time.time() triplet_texts = [] for triplet in all_triplets: head = triplet.get('head', '') @@ -1029,13 +1041,18 @@ class MilvusConnection: combined_text += " [三元组] " + "; ".join(triplet_texts) debug( f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") + timing_stats["triplet_text_combine"] = time.time() - triplet_text_start + debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒") + embedding_start = time.time() embeddings = await self._get_embeddings([combined_text]) query_vector = embeddings[0] debug(f"拼接文本向量维度: {len(query_vector)}") + timing_stats["embedding_generation"] = time.time() - embedding_start + debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒") + search_start = time.time() search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) expr = f"userid == '{userid}' and ({kb_expr})" debug(f"搜索表达式: {expr}") @@ -1053,7 +1070,9 @@ class MilvusConnection: ) except Exception as e: error(f"向量搜索失败: {str(e)}\n{exception()}") - return [] + return {"results": [], "timing": timing_stats} + timing_stats["vector_search"] = time.time() - search_start + debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒") search_results = [] for hits in results: @@ -1078,31 +1097,40 @@ class MilvusConnection: unique_results = [] seen_texts = set() + dedup_start = time.time() for result in sorted(search_results, key=lambda x: x['distance'], reverse=True): if result['text'] not in seen_texts: unique_results.append(result) seen_texts.add(result['text']) + timing_stats["deduplication"] = time.time() - dedup_start + debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒") info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") if use_rerank and unique_results: + rerank_start = time.time() debug("开始重排序") - unique_results = await self._rerank_results(combined_text, unique_results, limit) # 使用传入的 limit + unique_results = await self._rerank_results(combined_text, unique_results, limit) unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) + timing_stats["reranking"] = time.time() - rerank_start + debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒") debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") else: unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results] - info(f"融合搜索完成,返回 {len(unique_results)} 条结果") - return unique_results[:limit] + timing_stats["total_time"] = time.time() - start_time + info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") + return {"results": unique_results[:limit], "timing": timing_stats} except Exception as e: error(f"融合搜索失败: {str(e)}\n{exception()}") - return [] + return {"results": [], "timing": timing_stats} async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5, - offset: int = 0, use_rerank: bool = True) -> List[Dict]: + offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]: """纯向量搜索,基于查询文本在指定知识库中搜索相关文本块""" + start_time = time.time() # 记录开始时间 collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + timing_stats = {} # 记录各步骤耗时 try: info( f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}") @@ -1133,22 +1161,27 @@ class MilvusConnection: if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") - return [] + return {"results": [], "timing": timing_stats} try: collection = Collection(collection_name) collection.load() debug(f"加载集合: {collection_name}") + timing_stats["collection_load"] = time.time() - start_time + debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒") except Exception as e: error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") - raise RuntimeError(f"加载集合失败: {str(e)}") + return {"results": [], "timing": timing_stats} + embedding_start = time.time() embeddings = await self._get_embeddings([query]) query_vector = embeddings[0] debug(f"查询向量维度: {len(query_vector)}") + timing_stats["embedding_generation"] = time.time() - embedding_start + debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒") + search_start = time.time() search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) expr = f"userid == '{userid}' and ({kb_id_expr})" debug(f"搜索表达式: {expr}") @@ -1166,7 +1199,9 @@ class MilvusConnection: ) except Exception as e: error(f"搜索失败: {str(e)}\n{exception()}") - raise RuntimeError(f"搜索失败: {str(e)}") + return {"results": [], "timing": timing_stats} + timing_stats["vector_search"] = time.time() - search_start + debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒") search_results = [] for hits in results: @@ -1189,96 +1224,100 @@ class MilvusConnection: debug( f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}") + dedup_start = time.time() unique_results = [] seen_texts = set() for result in sorted(search_results, key=lambda x: x['distance'], reverse=True): if result['text'] not in seen_texts: unique_results.append(result) seen_texts.add(result['text']) + timing_stats["deduplication"] = time.time() - dedup_start + debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒") info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") if use_rerank and unique_results: + rerank_start = time.time() debug("开始重排序") unique_results = await self._rerank_results(query, unique_results, limit) unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) + timing_stats["reranking"] = time.time() - rerank_start + debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒") debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") else: unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results] - info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果") - return unique_results[:limit] + timing_stats["total_time"] = time.time() - start_time + info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒") + return {"results": unique_results[:limit], "timing": timing_stats} except Exception as e: error(f"纯向量搜索失败: {str(e)}\n{exception()}") - return [] + return {"results": [], "timing": timing_stats} - async def list_user_files(self, userid: str) -> List[Dict]: - """根据 userid 返回用户的所有文件列表,从所有 ragdb_ 开头的集合中查询""" + async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]: + """列出用户的所有知识库及其文件,按 knowledge_base_id 分组""" + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - info(f"开始查询用户文件列表: userid={userid}") + info(f"列出用户文件: userid={userid}, db_type={db_type}") if not userid: raise ValueError("userid 不能为空") - if "_" in userid: - raise ValueError("userid 不能包含下划线") - if len(userid) > 100: - raise ValueError("userid 长度超出限制") + if "_" in userid or (db_type and "_" in db_type): + raise ValueError("userid 和 db_type 不能包含下划线") + if (db_type and len(db_type) > 100) or len(userid) > 100: + raise ValueError("userid 或 db_type 的长度超出限制") - collections = utility.list_collections() - collections = [c for c in collections if c.startswith("ragdb")] - if not collections: - debug("未找到任何 ragdb 开头的集合") - return [] - debug(f"找到集合: {collections}") + if not utility.has_collection(collection_name): + debug(f"集合 {collection_name} 不存在") + return {} - file_list = [] - seen_files = set() - for collection_name in collections: - db_type = collection_name.replace("ragdb_", "") if collection_name != "ragdb" else "" - debug(f"处理集合: {collection_name}, db_type={db_type}") + try: + collection = Collection(collection_name) + collection.load() + debug(f"加载集合: {collection_name}") + except Exception as e: + error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") + return {} - try: - collection = Collection(collection_name) - collection.load() - debug(f"加载集合: {collection_name}") - except Exception as e: - error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") - continue + expr = f"userid == '{userid}'" + debug(f"查询表达式: {expr}") - try: - results = collection.query( - expr=f"userid == '{userid}'", - output_fields=["filename", "file_path", "upload_time", "file_type"], - limit=1000 - ) - debug(f"集合 {collection_name} 查询到 {len(results)} 个文本块") - except Exception as e: - error(f"查询集合 {collection_name} 失败: userid={userid}, 错误: {str(e)}\n{exception()}") - continue + try: + results = collection.query( + expr=expr, + output_fields=["document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"], + limit=1000 + ) + except Exception as e: + error(f"查询用户文件失败: {str(e)}\n{exception()}") + return {} - for result in results: - filename = result.get("filename") - file_path = result.get("file_path") - upload_time = result.get("upload_time") - file_type = result.get("file_type") - if (filename, file_path) not in seen_files: - seen_files.add((filename, file_path)) - file_list.append({ - "filename": filename, - "file_path": file_path, - "db_type": db_type, - "upload_time": upload_time, - "file_type": file_type - }) - debug( - f"文件: filename={filename}, file_path={file_path}, db_type={db_type}, upload_time={upload_time}, file_type={file_type}") + files_by_kb = {} + seen_document_ids = set() + for result in results: + document_id = result.get("document_id") + kb_id = result.get("knowledge_base_id") + if document_id not in seen_document_ids: + seen_document_ids.add(document_id) + file_info = { + "document_id": document_id, + "filename": result.get("filename"), + "file_path": result.get("file_path"), + "upload_time": result.get("upload_time"), + "file_type": result.get("file_type"), + "knowledge_base_id": kb_id + } + if kb_id not in files_by_kb: + files_by_kb[kb_id] = [] + files_by_kb[kb_id].append(file_info) + debug(f"找到文件: document_id={document_id}, filename={result.get('filename')}, knowledge_base_id={kb_id}") - info(f"返回 {len(file_list)} 个文件") - return sorted(file_list, key=lambda x: x["upload_time"], reverse=True) + info(f"找到 {len(seen_document_ids)} 个文件,userid={userid}, 知识库数量={len(files_by_kb)}") + return files_by_kb except Exception as e: - error(f"查询用户文件列表失败: userid={userid}, 错误: {str(e)}\n{exception()}") - return [] + error(f"列出用户文件失败: {str(e)}\n{exception()}") + return {} connection_register('Milvus', MilvusConnection) info("MilvusConnection registered") \ No newline at end of file diff --git a/llmengine/mrebeltriple.py b/llmengine/mrebeltriple.py index 6703a42..94a3d33 100644 --- a/llmengine/mrebeltriple.py +++ b/llmengine/mrebeltriple.py @@ -1,25 +1,25 @@ import os import torch import re +import traceback from transformers import AutoModelForSeq2SeqLM, AutoTokenizer -import logging +from appPublic.log import debug, error, warning, info +from appPublic.worker import awaitify from base_triple import BaseTripleExtractor, llm_register -logger = logging.getLogger(__name__) - class MRebelTripleExtractor(BaseTripleExtractor): def __init__(self, model_path: str): super().__init__(model_path) try: - logger.debug(f"Loading tokenizer from {model_path}") + debug(f"Loading tokenizer from {model_path}") self.tokenizer = AutoTokenizer.from_pretrained(model_path) - logger.debug(f"Loading model from {model_path}") + debug(f"Loading model from {model_path}") self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path) self.device = self.use_mps_if_possible() self.triplet_id = self.tokenizer.convert_tokens_to_ids("") - logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}") + debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}") except Exception as e: - logger.error(f"Failed to load mREBEL model: {str(e)}") + error(f"Failed to load mREBEL model: {str(e)}") raise RuntimeError(f"Failed to load mREBEL model: {str(e)}") self.gen_kwargs = { @@ -47,13 +47,13 @@ class MRebelTripleExtractor(BaseTripleExtractor): current_chunk = sentence if current_chunk: chunks.append(current_chunk) - logger.debug(f"Text split into: {len(chunks)} chunks") + debug(f"Text split into: {len(chunks)} chunks") return chunks def extract_triplets_typed(self, text: str) -> list: """Parse mREBEL generated text for triplets.""" triplets = [] - logger.debug(f"Raw generated text: {text}") + debug(f"Raw generated text: {text}") tokens = [] in_tag = False @@ -77,7 +77,7 @@ class MRebelTripleExtractor(BaseTripleExtractor): special_tokens = ["", "", "", "tp_XX", "__en__", "__zh__", "zh_CN"] tokens = [t for t in tokens if t not in special_tokens and t] - logger.debug(f"Processed tokens: {tokens}") + debug(f"Processed tokens: {tokens}") i = 0 while i < len(tokens): @@ -96,7 +96,7 @@ class MRebelTripleExtractor(BaseTripleExtractor): 'tail': entity2.strip(), 'tail_type': type2 }) - logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})") + debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})") i += 6 else: i += 1 @@ -110,11 +110,11 @@ class MRebelTripleExtractor(BaseTripleExtractor): raise ValueError("Text cannot be empty") text_chunks = self.split_document(text, max_chunk_size=150) - logger.debug(f"Text split into {len(text_chunks)} chunks") + debug(f"Text split into {len(text_chunks)} chunks") all_triplets = [] for i, chunk in enumerate(text_chunks): - logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...") + debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...") model_inputs = self.tokenizer( chunk, @@ -128,17 +128,17 @@ class MRebelTripleExtractor(BaseTripleExtractor): generated_tokens = self.model.generate( model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], - **self.gen_kwargs, + **self.gen_kwargs ) decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False) for idx, sentence in enumerate(decoded_preds): - logger.debug(f"Chunk {i + 1} generated text: {sentence}") + debug(f"Chunk {i + 1} generated text: {sentence}") triplets = self.extract_triplets_typed(sentence) if triplets: - logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets") + debug(f"Chunk {i + 1} extracted {len(triplets)} triplets") all_triplets.extend(triplets) except Exception as e: - logger.warning(f"Error processing chunk {i + 1}: {str(e)}") + warning(f"Error processing chunk {i + 1}: {str(e)}") continue unique_triplets = [] @@ -149,13 +149,12 @@ class MRebelTripleExtractor(BaseTripleExtractor): seen.add(identifier) unique_triplets.append(t) - logger.info(f"Extracted {len(unique_triplets)} unique triplets") + info(f"Extracted {len(unique_triplets)} unique triplets") return unique_triplets except Exception as e: - logger.error(f"Failed to extract triplets: {str(e)}") - import traceback - logger.debug(traceback.format_exc()) + error(f"Failed to extract triplets: {str(e)}") + debug(f"Traceback: {traceback.format_exc()}") return [] llm_register("mrebel-large", MRebelTripleExtractor) \ No newline at end of file