Add database service support: recall text chunks via database + knowledge graph, and list a user's documents across knowledge bases

This commit is contained in:
wangmeihua 2025-07-07 17:59:35 +08:00
parent 1c783461bd
commit 564bddfcde
6 changed files with 399 additions and 251 deletions


@@ -80,25 +80,43 @@ data: {
        "use_rerank": true
    }
    response:
    - Success: HTTP 200, {
        "status": "success",
        "results": [
            {
                "text": "<full text content>",
                "distance": 0.95,
                "source": "fused_query_with_triplets",
                "rerank_score": 0.92,  // only when use_rerank=true
                "metadata": {
                    "userid": "user1",
                    "document_id": "<uuid>",
                    "filename": "file.txt",
                    "file_path": "/path/to/file.txt",
                    "upload_time": "<iso_timestamp>",
                    "file_type": "txt"
                }
            },
            ...
        ],
        "timing": {
            "collection_load": <float>,       // collection load time
            "entity_extraction": <float>,     // entity extraction time
            "triplet_matching": <float>,      // triplet matching time
            "triplet_text_combine": <float>,  // triplet text assembly time
            "embedding_generation": <float>,  // embedding generation time
            "vector_search": <float>,         // vector search time
            "deduplication": <float>,         // deduplication time
            "reranking": <float>,             // reranking time, only when use_rerank=true
            "total_time": <float>             // total time
        },
        "collection_name": "ragdb" or "ragdb_textdb"
    }
    - Error: HTTP 400, {
        "status": "error",
        "message": "<error message>",
        "collection_name": "ragdb" or "ragdb_textdb"
    }
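For reference, a minimal client call for this fused-search endpoint could look as follows. The route, host and port below are assumptions (the path is defined earlier in this help text and is not part of this hunk; /v1/fusedsearchquery is inferred from the fused_search_query handler), and the field values are illustrative only:

    import requests  # any HTTP client works; requests is used here for brevity

    resp = requests.post(
        "http://localhost:8000/v1/fusedsearchquery",  # host, port and path are assumed
        headers={"Content-Type": "application/json"},
        json={
            "query": "苹果公司的创始人是谁?",
            "userid": "user1",
            "db_type": "textdb",
            "knowledge_base_ids": ["kb123"],
            "limit": 5,
            "offset": 0,
            "use_rerank": True,
        },
    )
    body = resp.json()
    if resp.status_code == 200 and body.get("status") == "success":
        for item in body["results"]:
            # rerank_score is present only when use_rerank=true
            print(item.get("rerank_score", item["distance"]), item["text"][:80])
        print("timing:", body["timing"])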
6. Search Query Endpoint:
    path: /v1/searchquery
    method: POST
@@ -113,45 +131,83 @@ data: {
        "use_rerank": true
    }
    response:
    - Success: HTTP 200, {
        "status": "success",
        "results": [
            {
                "text": "<full text content>",
                "distance": 0.95,
                "source": "vector_query",
                "rerank_score": 0.92,  // only when use_rerank=true
                "metadata": {
                    "userid": "user1",
                    "document_id": "<uuid>",
                    "filename": "file.txt",
                    "file_path": "/path/to/file.txt",
                    "upload_time": "<iso_timestamp>",
                    "file_type": "txt"
                }
            },
            ...
        ],
        "timing": {
            "collection_load": <float>,       // collection load time
            "embedding_generation": <float>,  // embedding generation time
            "vector_search": <float>,         // vector search time
            "deduplication": <float>,         // deduplication time
            "reranking": <float>,             // reranking time, only when use_rerank=true
            "total_time": <float>             // total time
        },
        "collection_name": "ragdb" or "ragdb_textdb"
    }
    - Error: HTTP 400, {
        "status": "error",
        "message": "<error message>",
        "collection_name": "ragdb" or "ragdb_textdb"
    }
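A matching client sketch for the documented /v1/searchquery endpoint (host and port assumed, field values illustrative):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/searchquery",  # host and port assumed
        headers={"Content-Type": "application/json"},
        json={"query": "向量数据库是什么?", "userid": "user1", "db_type": "textdb",
              "knowledge_base_ids": ["kb123"], "limit": 5, "offset": 0, "use_rerank": True},
    )
    body = resp.json()
    print(body.get("collection_name"), len(body.get("results", [])))
    print(body.get("timing", {}))  # per-step timings as documented above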
7. List User Files Endpoint:
    path: /v1/listuserfiles
    method: POST
    headers: {"Content-Type": "application/json"}
    data: {
        "userid": "user1",
        "db_type": "textdb"  // optional; defaults to the ragdb collection if omitted
    }
    response:
    - Success: HTTP 200, {
        "status": "success",
        "files_by_knowledge_base": {
            "kb123": [
                {
                    "document_id": "<uuid>",
                    "filename": "file1.txt",
                    "file_path": "/path/to/file1.txt",
                    "upload_time": "<iso_timestamp>",
                    "file_type": "txt",
                    "knowledge_base_id": "kb123"
                },
                ...
            ],
            "kb456": [
                {
                    "document_id": "<uuid>",
                    "filename": "file2.pdf",
                    "file_path": "/path/to/file2.pdf",
                    "upload_time": "<iso_timestamp>",
                    "file_type": "pdf",
                    "knowledge_base_id": "kb456"
                },
                ...
            ]
        },
        "collection_name": "ragdb" or "ragdb_textdb"
    }
    - Error: HTTP 400, {
        "status": "error",
        "message": "<error message>",
        "collection_name": "ragdb" or "ragdb_textdb"
    }
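And for /v1/listuserfiles, which now groups files by knowledge base (host and port assumed):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/listuserfiles",  # host and port assumed
        headers={"Content-Type": "application/json"},
        json={"userid": "user1", "db_type": "textdb"},  # db_type is optional
    )
    for kb_id, files in resp.json().get("files_by_knowledge_base", {}).items():
        print(kb_id, [f["filename"] for f in files])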
8. Connection Endpoint (for compatibility):
    path: /v1/connection
    method: POST
@@ -378,7 +434,13 @@ async def fused_search_query(request, params_kw, *params, **kw):
            "use_rerank": use_rerank
        })
        debug(f'{result=}')
        response = {
            "status": "success",
            "results": result.get("results", []),
            "timing": result.get("timing", {}),
            "collection_name": collection_name
        }
        return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
    except Exception as e:
        error(f'融合搜索失败: {str(e)}')
        return web.json_response({
@@ -396,7 +458,7 @@ async def search_query(request, params_kw, *params, **kw):
    db_type = params_kw.get('db_type', '')
    knowledge_base_ids = params_kw.get('knowledge_base_ids')
    limit = params_kw.get('limit')
    offset = params_kw.get('offset', 0)
    use_rerank = params_kw.get('use_rerank', True)
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    try:
@@ -417,7 +479,13 @@ async def search_query(request, params_kw, *params, **kw):
            "use_rerank": use_rerank
        })
        debug(f'{result=}')
        response = {
            "status": "success",
            "results": result.get("results", []),
            "timing": result.get("timing", {}),
            "collection_name": collection_name
        }
        return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
    except Exception as e:
        error(f'纯向量搜索失败: {str(e)}')
        return web.json_response({
@@ -431,23 +499,33 @@ async def list_user_files(request, params_kw, *params, **kw):
    se = ServerEnv()
    engine = se.engine
    userid = params_kw.get('userid')
    db_type = params_kw.get('db_type', '')
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    try:
        if not userid:
            debug(f'userid 未提供')
            return web.json_response({
                "status": "error",
                "message": "userid 未提供",
                "collection_name": collection_name
            }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
        result = await engine.handle_connection("list_user_files", {
            "userid": userid,
            "db_type": db_type
        })
        debug(f'{result=}')
        response = {
            "status": "success",
            "files_by_knowledge_base": result,
            "collection_name": collection_name
        }
        return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
    except Exception as e:
        error(f'列出用户文件失败: {str(e)}')
        return web.json_response({
            "status": "error",
            "message": str(e),
            "collection_name": collection_name
        }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)

async def handle_connection(request, params_kw, *params, **kw):


@@ -45,6 +45,7 @@ def init():
    rf = RegisterFunction()
    rf.register('entities', entities)
    rf.register('docs', docs)
    debug("注册路由: entities, docs")

async def docs(request, params_kw, *params, **kw):
    return helptext

@@ -53,12 +54,11 @@ async def entities(request, params_kw, *params, **kw):
    debug(f'{params_kw.query=}')
    se = ServerEnv()
    engine = se.engine
    query = params_kw.query
    if query is None:
        e = exception(f'query is None')
        raise e
    entities = await engine.extract_entities(query)
    debug(f'{entities=}, type(entities)')
    return {
        "object": "list",


@@ -1,19 +1,23 @@
import os
import re
from py2neo import Graph, Node, Relationship
from typing import Set, List, Dict, Tuple
from appPublic.jsonConfig import getConfig
from appPublic.log import debug, error, info, exception

class KnowledgeGraph:
    def __init__(self, triples: List[Dict], document_id: str, knowledge_base_id: str, userid: str):
        self.triples = triples
        self.document_id = document_id
        self.knowledge_base_id = knowledge_base_id
        self.userid = userid
        config = getConfig()
        self.neo4j_uri = config['neo4j']['uri']
        self.neo4j_user = config['neo4j']['user']
        self.neo4j_password = config['neo4j']['password']
        self.g = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
        info(f"开始构建知识图谱document_id: {self.document_id}, knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid}, 三元组数量: {len(triples)}")

    def _normalize_label(self, entity_type: str) -> str:
        """规范化实体类型为 Neo4j 标签"""
@@ -25,28 +29,29 @@ class KnowledgeGraph:
        return label or 'Entity'

    def _clean_relation(self, relation: str) -> Tuple[str, str]:
        """清洗关系,返回 (rel_type, rel_name),确保 rel_type 合法"""
        relation = relation.strip()
        if not relation:
            return 'RELATED_TO', '相关'

        cleaned_relation = re.sub(r'[^\w\s]', '', relation).strip()
        if not cleaned_relation:
            return 'RELATED_TO', '相关'

        if 'instance of' in relation.lower():
            return 'INSTANCE_OF', '实例'
        elif 'subclass of' in relation.lower():
            return 'SUBCLASS_OF', '子类'
        elif 'part of' in relation.lower():
            return 'PART_OF', '部分'

        rel_type = re.sub(r'\s+', '_', cleaned_relation).upper()
        if rel_type and rel_type[0].isdigit():
            rel_type = f'REL_{rel_type}'
        if not re.match(r'^[A-Za-z][A-Za-z0-9_]*$', rel_type):
            debug(f"非法关系类型 '{rel_type}',替换为 'RELATED_TO'")
            return 'RELATED_TO', relation
        return rel_type, relation

    def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]:
        """从三元组列表中读取节点和关系"""
@@ -57,14 +62,14 @@
        try:
            for triple in self.triples:
                if not all(key in triple for key in ['head', 'head_type', 'type', 'tail', 'tail_type']):
                    debug(f"无效三元组: {triple}")
                    continue
                head, relation, tail, head_type, tail_type = (
                    triple['head'], triple['type'], triple['tail'], triple['head_type'], triple['tail_type']
                )
                head_label = self._normalize_label(head_type)
                tail_label = self._normalize_label(tail_type)
                debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")

                if head_label not in nodes_by_label:
                    nodes_by_label[head_label] = set()
@@ -92,33 +97,44 @@
                    'tail_type': tail_type
                })

            info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())}")
            info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())}")
            return nodes_by_label, relations_by_type, triples
        except Exception as e:
            error(f"读取三元组失败: {str(e)}")
            raise RuntimeError(f"读取三元组失败: {str(e)}")

    def create_node(self, label: str, nodes: Set[str]):
        """创建节点,包含 document_id、knowledge_base_id 和 userid 属性"""
        count = 0
        for node_name in nodes:
            query = (
                f"MATCH (n:{label} {{name: $name, document_id: $doc_id, "
                f"knowledge_base_id: $kb_id, userid: $userid}}) RETURN n"
            )
            try:
                if self.g.run(query, name=node_name, doc_id=self.document_id,
                              kb_id=self.knowledge_base_id, userid=self.userid).data():
                    continue
                node = Node(
                    label,
                    name=node_name,
                    document_id=self.document_id,
                    knowledge_base_id=self.knowledge_base_id,
                    userid=self.userid
                )
                self.g.create(node)
                count += 1
                debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id}, "
                      f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
            except Exception as e:
                error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
        info(f"创建 {label} 节点: {count}/{len(nodes)}")
        return count

    def create_relationship(self, rel_type: str, relations: List[Dict]):
        """创建关系,包含 document_id、knowledge_base_id 和 userid 属性"""
        count = 0
        total = len(relations)
        seen_edges = set()
@@ -132,17 +148,23 @@ class KnowledgeGraph:
            seen_edges.add(edge_key)

            query = (
                f"MATCH (p:{head_label} {{name: $head, document_id: $doc_id, "
                f"knowledge_base_id: $kb_id, userid: $userid}}), "
                f"(q:{tail_label} {{name: $tail, document_id: $doc_id, "
                f"knowledge_base_id: $kb_id, userid: $userid}}) "
                f"CREATE (p)-[r:{rel_type} {{name: $rel_name, document_id: $doc_id, "
                f"knowledge_base_id: $kb_id, userid: $userid}}]->(q)"
            )
            try:
                self.g.run(query, head=head, tail=tail, rel_name=rel_name,
                           doc_id=self.document_id, kb_id=self.knowledge_base_id,
                           userid=self.userid)
                count += 1
                debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id}, "
                      f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
            except Exception as e:
                error(f"创建关系失败: {query}, 错误: {str(e)}")
        info(f"创建 {rel_type} 关系: {count}/{total}")
        return count

    def create_graphnodes(self):
@@ -151,7 +173,7 @@ class KnowledgeGraph:
        total = 0
        for label, nodes in nodes_by_label.items():
            total += self.create_node(label, nodes)
        info(f"总计创建节点: {total}")
        return total

    def create_graphrels(self):
@@ -160,15 +182,16 @@ class KnowledgeGraph:
        total = 0
        for rel_type, relations in relations_by_type.items():
            total += self.create_relationship(rel_type, relations)
        info(f"总计创建关系: {total}")
        return total

    def export_data(self):
        """导出节点到文件,包含 document_id、knowledge_base_id 和 userid"""
        nodes_by_label, _, _ = self.read_nodes()
        os.makedirs('dict', exist_ok=True)
        for label, nodes in nodes_by_label.items():
            with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f:
                f.write('\n'.join(f"{name}\t{self.document_id}\t{self.knowledge_base_id}\t{self.userid}"
                                  for name in sorted(nodes)))
            info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)}")
        return
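A minimal construction sketch under the new signature; it assumes a config file with a neo4j section (uri/user/password) as read in __init__, and the llmengine.kgc import path shown in the Milvus connector below. The triple content is illustrative:

    from llmengine.kgc import KnowledgeGraph

    triples = [{"head": "苹果公司", "head_type": "ORG", "type": "founded by",
                "tail": "乔布斯", "tail_type": "PER"}]
    kg = KnowledgeGraph(triples, document_id="<uuid>", knowledge_base_id="kb123", userid="user1")
    kg.create_graphnodes()  # nodes now carry document_id / knowledge_base_id / userid
    kg.create_graphrels()   # relationships carry the same scoping properties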


@@ -1,11 +1,9 @@
# Requires ltp>=0.2.0
from ltp import LTP
from typing import List
import asyncio
from appPublic.log import debug, info, error
from llmengine.base_entity import BaseLtp, ltp_register

class LtpEntity(BaseLtp):
    def __init__(self, model_id):
@@ -14,7 +12,7 @@ class LtpEntity(BaseLtp):
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]

    async def extract_entities(self, query: str) -> List[str]:
        """
        从查询文本中抽取实体包括
        - LTP NER 识别的实体所有类型
@@ -26,7 +24,18 @@ class LtpEntity(BaseLtp):
            if not query:
                raise ValueError("查询文本不能为空")

            # 定义同步 pipeline 函数,正确传递 tasks 参数
            def sync_pipeline(query, tasks):
                return self.ltp.pipeline([query], tasks=tasks)

            # 使用 run_in_executor 运行同步 pipeline
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                None,
                lambda: sync_pipeline(query, ["cws", "pos", "ner"])
            )

            # 解析结果
            words = result.cws[0]
            pos_list = result.pos[0]
            ner = result.ner[0]
@@ -34,7 +43,7 @@ class LtpEntity(BaseLtp):
            entities = []
            subword_set = set()

            debug(f"NER 结果: {ner}")
            for entity_type, entity, start, end in ner:
                entities.append(entity)
@@ -49,13 +58,13 @@ class LtpEntity(BaseLtp):
                        if combined:
                            entities.append(combined)
                            subword_set.update(combined_words)
                            debug(f"合并连续名词: {combined}, 子词: {combined_words}")
                        combined = ""
                        combined_words = []
                else:
                    combined = ""
                    combined_words = []

            debug(f"连续名词子词集合: {subword_set}")
            for word, pos in zip(words, pos_list):
                if pos == 'n' and word not in subword_set:
@@ -66,11 +75,11 @@ class LtpEntity(BaseLtp):
                    entities.append(word)

            unique_entities = list(dict.fromkeys(entities))
            info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
            return unique_entities
        except Exception as e:
            error(f"实体抽取失败: {str(e)}")
            raise  # 抛出异常以便调试,而不是返回空列表

ltp_register('LTP', LtpEntity)
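Because extract_entities is now a coroutine, callers must await it; a minimal sketch (the module path and model directory are assumptions):

    import asyncio
    from llmengine.ltp_entity import LtpEntity  # module path assumed

    async def main():
        ltp = LtpEntity("/models/LTP/base")     # model directory is illustrative
        entities = await ltp.extract_entities("苹果公司的创始人是谁?")
        print(entities)

    asyncio.run(main())

run_in_executor keeps the blocking LTP pipeline off the event loop; inside a running coroutine, asyncio.get_running_loop() is the more current spelling of asyncio.get_event_loop().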


@@ -16,6 +16,7 @@ from llmengine.kgc import KnowledgeGraph
import numpy as np
from py2neo import Graph
from scipy.spatial.distance import cosine
import time

# 嵌入缓存
EMBED_CACHE = {}
@@ -182,7 +183,7 @@ class MilvusConnection:
                if not userid:
                    return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name,
                            "document_id": "", "status_code": 400}
                return await self._list_user_files(userid)
            else:
                return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name,
                        "document_id": "", "status_code": 400}
@@ -814,20 +815,24 @@
            error(f"实体识别服务调用失败: {str(e)}\n{exception()}")
            return []

    async def _match_triplets(self, query: str, query_entities: List[str], userid: str, knowledge_base_id: str) -> List[Dict]:
        """匹配查询实体与 Neo4j 中的三元组"""
        start_time = time.time()  # 记录开始时间
        matched_triplets = []
        ENTITY_SIMILARITY_THRESHOLD = 0.8
        try:
            graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
            debug(f"已连接到 Neo4j: {self.neo4j_uri}")
            neo4j_connect_time = time.time() - start_time
            debug(f"Neo4j 连接耗时: {neo4j_connect_time:.3f}")

            matched_names = set()
            entity_match_start = time.time()
            for entity in query_entities:
                normalized_entity = entity.lower().strip()
                query = """
                MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
                WHERE toLower(n.name) CONTAINS $entity
                   OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7
                RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim
@@ -835,24 +840,27 @@
                LIMIT 100
                """
                try:
                    results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, entity=normalized_entity).data()
                    for record in results:
                        matched_names.add(record['n.name'])
                        debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})")
                except Exception as e:
                    debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}")
                    continue
            entity_match_time = time.time() - entity_match_start
            debug(f"实体匹配耗时: {entity_match_time:.3f}")

            triplets = []
            if matched_names:
                triplet_query_start = time.time()
                query = """
                MATCH (h {userid: $userid, knowledge_base_id: $knowledge_base_id})-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->(t {userid: $userid, knowledge_base_id: $knowledge_base_id})
                WHERE h.name IN $matched_names OR t.name IN $matched_names
                RETURN h.name AS head, r.name AS type, t.name AS tail
                LIMIT 100
                """
                try:
                    results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, matched_names=list(matched_names)).data()
                    seen = set()
                    for record in results:
                        head, type_, tail = record['head'], record['type'], record['tail']
@@ -866,22 +874,28 @@
                            'head_type': '',
                            'tail_type': ''
                        })
                    debug(f"从 Neo4j 加载三元组: knowledge_base_id={knowledge_base_id}, 数量={len(triplets)}")
                except Exception as e:
                    error(f"检索三元组失败: knowledge_base_id={knowledge_base_id}, 错误: {str(e)}\n{exception()}")
                    return []
                triplet_query_time = time.time() - triplet_query_start
                debug(f"Neo4j 三元组查询耗时: {triplet_query_time:.3f}")

            if not triplets:
                debug(f"知识库 knowledge_base_id={knowledge_base_id} 无匹配三元组")
                return []

            embedding_start = time.time()
            texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets]
            embeddings = await self._get_embeddings(texts_to_embed)
            entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)}
            head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)}
            tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)}
            debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails")
            embedding_time = time.time() - embedding_start
            debug(f"嵌入向量生成耗时: {embedding_time:.3f}")

            similarity_start = time.time()
            for entity in query_entities:
                entity_vec = entity_vectors[entity]
                for d_triplet in triplets:
@@ -894,6 +908,8 @@
                        matched_triplets.append(d_triplet)
                        debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
                              f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
            similarity_time = time.time() - similarity_start
            debug(f"相似度计算耗时: {similarity_time:.3f}")

            unique_matched = []
            seen = set()
@@ -903,6 +919,8 @@
                    seen.add(identifier)
                    unique_matched.append(t)

            total_time = time.time() - start_time
            debug(f"_match_triplets 总耗时: {total_time:.3f}")
            info(f"找到 {len(unique_matched)} 个匹配的三元组")
            return unique_matched
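The per-step timing added throughout this file repeats the time.time() bookkeeping by hand; a small context manager would express the same pattern with less boilerplate. This is only a sketch of an alternative, not part of the commit:

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(stats: dict, key: str):
        start = time.time()
        try:
            yield
        finally:
            stats[key] = time.time() - start

    # usage: with timed(timing_stats, "vector_search"): results = collection.search(...)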
@@ -957,9 +975,11 @@
        return results

    async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
                            offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
        """融合搜索,将查询与所有三元组拼接后向量化搜索"""
        start_time = time.time()  # 记录开始时间
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
        timing_stats = {}  # 记录各步骤耗时
        try:
            info(
                f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
@@ -975,46 +995,38 @@
            if not utility.has_collection(collection_name):
                debug(f"集合 {collection_name} 不存在")
                return {"results": [], "timing": timing_stats}

            try:
                collection = Collection(collection_name)
                collection.load()
                debug(f"加载集合: {collection_name}")
                timing_stats["collection_load"] = time.time() - start_time
                debug(f"集合加载耗时: {timing_stats['collection_load']:.3f}")
            except Exception as e:
                error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
                return {"results": [], "timing": timing_stats}

            entity_extract_start = time.time()
            query_entities = await self._extract_entities(query)
            timing_stats["entity_extraction"] = time.time() - entity_extract_start
            debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")

            all_triplets = []
            triplet_match_start = time.time()
            for kb_id in knowledge_base_ids:
                debug(f"处理知识库: {kb_id}")
                matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id)
                debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)}")
                all_triplets.extend(matched_triplets)
            timing_stats["triplet_matching"] = time.time() - triplet_match_start
            debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")

            if not all_triplets:
                debug("未找到任何匹配的三元组")
                return {"results": [], "timing": timing_stats}

            triplet_text_start = time.time()
            triplet_texts = []
            for triplet in all_triplets:
                head = triplet.get('head', '')
@@ -1029,13 +1041,18 @@
                combined_text += " [三元组] " + "; ".join(triplet_texts)
                debug(
                    f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
            timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
            debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}")

            embedding_start = time.time()
            embeddings = await self._get_embeddings([combined_text])
            query_vector = embeddings[0]
            debug(f"拼接文本向量维度: {len(query_vector)}")
            timing_stats["embedding_generation"] = time.time() - embedding_start
            debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f}")

            search_start = time.time()
            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
            kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
            expr = f"userid == '{userid}' and ({kb_expr})"
            debug(f"搜索表达式: {expr}")
@@ -1053,7 +1070,9 @@
                )
            except Exception as e:
                error(f"向量搜索失败: {str(e)}\n{exception()}")
                return {"results": [], "timing": timing_stats}
            timing_stats["vector_search"] = time.time() - search_start
            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")

            search_results = []
            for hits in results:
@@ -1078,31 +1097,40 @@
            unique_results = []
            seen_texts = set()
            dedup_start = time.time()
            for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
                if result['text'] not in seen_texts:
                    unique_results.append(result)
                    seen_texts.add(result['text'])
            timing_stats["deduplication"] = time.time() - dedup_start
            debug(f"去重耗时: {timing_stats['deduplication']:.3f}")
            info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")

            if use_rerank and unique_results:
                rerank_start = time.time()
                debug("开始重排序")
                unique_results = await self._rerank_results(combined_text, unique_results, limit)
                unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
                timing_stats["reranking"] = time.time() - rerank_start
                debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
                debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
            else:
                unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]

            timing_stats["total_time"] = time.time() - start_time
            info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
            return {"results": unique_results[:limit], "timing": timing_stats}
        except Exception as e:
            error(f"融合搜索失败: {str(e)}\n{exception()}")
            return {"results": [], "timing": timing_stats}

    async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5,
                            offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
        """纯向量搜索,基于查询文本在指定知识库中搜索相关文本块"""
        start_time = time.time()  # 记录开始时间
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
        timing_stats = {}  # 记录各步骤耗时
        try:
            info(
                f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
@@ -1133,22 +1161,27 @@
            if not utility.has_collection(collection_name):
                debug(f"集合 {collection_name} 不存在")
                return {"results": [], "timing": timing_stats}

            try:
                collection = Collection(collection_name)
                collection.load()
                debug(f"加载集合: {collection_name}")
                timing_stats["collection_load"] = time.time() - start_time
                debug(f"集合加载耗时: {timing_stats['collection_load']:.3f}")
            except Exception as e:
                error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
                return {"results": [], "timing": timing_stats}

            embedding_start = time.time()
            embeddings = await self._get_embeddings([query])
            query_vector = embeddings[0]
            debug(f"查询向量维度: {len(query_vector)}")
            timing_stats["embedding_generation"] = time.time() - embedding_start
            debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f}")

            search_start = time.time()
            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
            kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
            expr = f"userid == '{userid}' and ({kb_id_expr})"
            debug(f"搜索表达式: {expr}")
@@ -1166,7 +1199,9 @@
                )
            except Exception as e:
                error(f"搜索失败: {str(e)}\n{exception()}")
                return {"results": [], "timing": timing_stats}
            timing_stats["vector_search"] = time.time() - search_start
            debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")

            search_results = []
            for hits in results:
@@ -1189,96 +1224,100 @@
                    debug(
                        f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}")

            dedup_start = time.time()
            unique_results = []
            seen_texts = set()
            for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
                if result['text'] not in seen_texts:
                    unique_results.append(result)
                    seen_texts.add(result['text'])
            timing_stats["deduplication"] = time.time() - dedup_start
            debug(f"去重耗时: {timing_stats['deduplication']:.3f}")
            info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")

            if use_rerank and unique_results:
                rerank_start = time.time()
                debug("开始重排序")
                unique_results = await self._rerank_results(query, unique_results, limit)
                unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
                timing_stats["reranking"] = time.time() - rerank_start
                debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
                debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
            else:
                unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]

            timing_stats["total_time"] = time.time() - start_time
            info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
            return {"results": unique_results[:limit], "timing": timing_stats}
        except Exception as e:
            error(f"纯向量搜索失败: {str(e)}\n{exception()}")
            return {"results": [], "timing": timing_stats}

    async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]:
        """列出用户的所有知识库及其文件,按 knowledge_base_id 分组"""
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
        try:
            info(f"列出用户文件: userid={userid}, db_type={db_type}")
            if not userid:
                raise ValueError("userid 不能为空")
            if "_" in userid or (db_type and "_" in db_type):
                raise ValueError("userid 和 db_type 不能包含下划线")
            if (db_type and len(db_type) > 100) or len(userid) > 100:
                raise ValueError("userid 或 db_type 的长度超出限制")

            if not utility.has_collection(collection_name):
                debug(f"集合 {collection_name} 不存在")
                return {}

            try:
                collection = Collection(collection_name)
                collection.load()
                debug(f"加载集合: {collection_name}")
            except Exception as e:
                error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
                return {}

            expr = f"userid == '{userid}'"
            debug(f"查询表达式: {expr}")
            try:
                results = collection.query(
                    expr=expr,
                    output_fields=["document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"],
                    limit=1000
                )
            except Exception as e:
                error(f"查询用户文件失败: {str(e)}\n{exception()}")
                return {}

            files_by_kb = {}
            seen_document_ids = set()
            for result in results:
                document_id = result.get("document_id")
                kb_id = result.get("knowledge_base_id")
                if document_id not in seen_document_ids:
                    seen_document_ids.add(document_id)
                    file_info = {
                        "document_id": document_id,
                        "filename": result.get("filename"),
                        "file_path": result.get("file_path"),
                        "upload_time": result.get("upload_time"),
                        "file_type": result.get("file_type"),
                        "knowledge_base_id": kb_id
                    }
                    if kb_id not in files_by_kb:
                        files_by_kb[kb_id] = []
                    files_by_kb[kb_id].append(file_info)
                    debug(f"找到文件: document_id={document_id}, filename={result.get('filename')}, knowledge_base_id={kb_id}")

            info(f"找到 {len(seen_document_ids)} 个文件userid={userid}, 知识库数量={len(files_by_kb)}")
            return files_by_kb
        except Exception as e:
            error(f"列出用户文件失败: {str(e)}\n{exception()}")
            return {}

connection_register('Milvus', MilvusConnection)
info("MilvusConnection registered")


@@ -1,25 +1,25 @@
import os
import torch
import re
import traceback
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from appPublic.log import debug, error, warning, info
from base_triple import BaseTripleExtractor, llm_register

class MRebelTripleExtractor(BaseTripleExtractor):
    def __init__(self, model_path: str):
        super().__init__(model_path)
        try:
            debug(f"Loading tokenizer from {model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            debug(f"Loading model from {model_path}")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
            self.device = self.use_mps_if_possible()
            self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
            debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
        except Exception as e:
            error(f"Failed to load mREBEL model: {str(e)}")
            raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")

        self.gen_kwargs = {
@@ -47,13 +47,13 @@ class MRebelTripleExtractor(BaseTripleExtractor):
                current_chunk = sentence
        if current_chunk:
            chunks.append(current_chunk)
        debug(f"Text split into: {len(chunks)} chunks")
        return chunks

    def extract_triplets_typed(self, text: str) -> list:
        """Parse mREBEL generated text for triplets."""
        triplets = []
        debug(f"Raw generated text: {text}")

        tokens = []
        in_tag = False
@@ -77,7 +77,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
        special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
        tokens = [t for t in tokens if t not in special_tokens and t]
        debug(f"Processed tokens: {tokens}")

        i = 0
        while i < len(tokens):
@@ -96,7 +96,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
                    'tail': entity2.strip(),
                    'tail_type': type2
                })
                debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
                i += 6
            else:
                i += 1
@@ -110,11 +110,11 @@ class MRebelTripleExtractor(BaseTripleExtractor):
                raise ValueError("Text cannot be empty")

            text_chunks = self.split_document(text, max_chunk_size=150)
            debug(f"Text split into {len(text_chunks)} chunks")

            all_triplets = []
            for i, chunk in enumerate(text_chunks):
                debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")

                model_inputs = self.tokenizer(
                    chunk,
@@ -128,17 +128,17 @@ class MRebelTripleExtractor(BaseTripleExtractor):
                    generated_tokens = self.model.generate(
                        model_inputs["input_ids"],
                        attention_mask=model_inputs["attention_mask"],
                        **self.gen_kwargs
                    )
                    decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

                    for idx, sentence in enumerate(decoded_preds):
                        debug(f"Chunk {i + 1} generated text: {sentence}")
                        triplets = self.extract_triplets_typed(sentence)
                        if triplets:
                            debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
                            all_triplets.extend(triplets)
                except Exception as e:
                    warning(f"Error processing chunk {i + 1}: {str(e)}")
                    continue

            unique_triplets = []
@@ -149,13 +149,12 @@ class MRebelTripleExtractor(BaseTripleExtractor):
                    seen.add(identifier)
                    unique_triplets.append(t)

            info(f"Extracted {len(unique_triplets)} unique triplets")
            return unique_triplets
        except Exception as e:
            error(f"Failed to extract triplets: {str(e)}")
            debug(f"Traceback: {traceback.format_exc()}")
            return []

llm_register("mrebel-large", MRebelTripleExtractor)