增加数据库服务化基于数据库+知识图谱召回文本块、列出用户不同知识库中的文档功能

This commit is contained in:
wangmeihua 2025-07-07 17:59:35 +08:00
parent 1c783461bd
commit 564bddfcde
6 changed files with 399 additions and 251 deletions

View File

@ -80,25 +80,43 @@ data: {
"use_rerank": true
}
response:
- Success: HTTP 200, [
{
"text": "<完整文本内容>",
"distance": 0.95,
"source": "fused_query_with_triplets",
"rerank_score": 0.92, // use_rerank=true
"metadata": {
"userid": "user1",
"document_id": "<uuid>",
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
}
- Success: HTTP 200, {
"status": "success",
"results": [
{
"text": "<完整文本内容>",
"distance": 0.95,
"source": "fused_query_with_triplets",
"rerank_score": 0.92, // use_rerank=true
"metadata": {
"userid": "user1",
"document_id": "<uuid>",
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
}
},
...
],
"timing": {
"collection_load": <float>, // 集合加载耗时
"entity_extraction": <float>, // 实体提取耗时
"triplet_matching": <float>, // 三元组匹配耗时
"triplet_text_combine": <float>, // 拼接三元组文本耗时
"embedding_generation": <float>, // 嵌入向量生成耗时
"vector_search": <float>, // 向量搜索耗时
"deduplication": <float>, // 去重耗时
"reranking": <float>, // 重排序耗时 use_rerank=true
"total_time": <float> // 总耗时
},
...
]
- Error: HTTP 400, {"status": "error", "message": "<error message>", "collection_name": "<collection_name>"}
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
6. Search Query Endpoint:
path: /v1/searchquery
method: POST
@ -113,45 +131,83 @@ data: {
"use_rerank": true
}
response:
- Success: HTTP 200, [
{
"text": "<完整文本内容>",
"distance": 0.95,
"source": "vector_query",
"rerank_score": 0.92, // use_rerank=true
"metadata": {
"userid": "user1",
"document_id": "<uuid>",
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
}
- Success: HTTP 200, {
"status": "success",
"results": [
{
"text": "<完整文本内容>",
"distance": 0.95,
"source": "vector_query",
"rerank_score": 0.92, // use_rerank=true
"metadata": {
"userid": "user1",
"document_id": "<uuid>",
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
}
},
...
],
"timing": {
"collection_load": <float>, // 集合加载耗时
"embedding_generation": <float>, // 嵌入向量生成耗时
"vector_search": <float>, // 向量搜索耗时
"deduplication": <float>, // 去重耗时
"reranking": <float>, // 重排序耗时 use_rerank=true
"total_time": <float> // 总耗时
},
...
]
- Error: HTTP 400, {"status": "error", "message": "<error message>"}
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
7. List User Files Endpoint:
path: /v1/listuserfiles
method: POST
headers: {"Content-Type": "application/json"}
data: {
"userid": "testuser2"
"userid": "user1",
"db_type": "textdb" // 可选若不提供则使用默认集合 ragdb
}
response:
- Success: HTTP 200, [
{
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"db_type": "textdb",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
- Success: HTTP 200, {
"status": "success",
"files_by_knowledge_base": {
"kb123": [
{
"document_id": "<uuid>",
"filename": "file1.txt",
"file_path": "/path/to/file1.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt",
"knowledge_base_id": "kb123"
},
...
],
"kb456": [
{
"document_id": "<uuid>",
"filename": "file2.pdf",
"file_path": "/path/to/file2.pdf",
"upload_time": "<iso_timestamp>",
"file_type": "pdf",
"knowledge_base_id": "kb456"
},
...
]
},
...
]
- Error: HTTP 400, {"status": "error", "message": "<error message>"}
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
8. Connection Endpoint (for compatibility):
path: /v1/connection
method: POST
@ -378,7 +434,13 @@ async def fused_search_query(request, params_kw, *params, **kw):
"use_rerank": use_rerank
})
debug(f'{result=}')
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
response = {
"status": "success",
"results": result.get("results", []),
"timing": result.get("timing", {}),
"collection_name": collection_name
}
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
except Exception as e:
error(f'融合搜索失败: {str(e)}')
return web.json_response({
@ -396,7 +458,7 @@ async def search_query(request, params_kw, *params, **kw):
db_type = params_kw.get('db_type', '')
knowledge_base_ids = params_kw.get('knowledge_base_ids')
limit = params_kw.get('limit')
offset = params_kw.get('offset')
offset = params_kw.get('offset', 0)
use_rerank = params_kw.get('use_rerank', True)
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
@ -417,7 +479,13 @@ async def search_query(request, params_kw, *params, **kw):
"use_rerank": use_rerank
})
debug(f'{result=}')
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
response = {
"status": "success",
"results": result.get("results", []),
"timing": result.get("timing", {}),
"collection_name": collection_name
}
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
except Exception as e:
error(f'纯向量搜索失败: {str(e)}')
return web.json_response({
@ -431,23 +499,33 @@ async def list_user_files(request, params_kw, *params, **kw):
se = ServerEnv()
engine = se.engine
userid = params_kw.get('userid')
if not userid:
debug(f'userid 未提供')
return web.json_response({
"status": "error",
"message": "userid 参数未提供"
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
if not userid:
debug(f'userid 未提供')
return web.json_response({
"status": "error",
"message": "userid 未提供",
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
result = await engine.handle_connection("list_user_files", {
"userid": userid
"userid": userid,
"db_type": db_type
})
debug(f'{result=}')
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
response = {
"status": "success",
"files_by_knowledge_base": result,
"collection_name": collection_name
}
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
except Exception as e:
error(f'查询用户文件列表失败: {str(e)}')
error(f'列出用户文件失败: {str(e)}')
return web.json_response({
"status": "error",
"message": str(e)
"message": str(e),
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def handle_connection(request, params_kw, *params, **kw):

View File

@ -45,6 +45,7 @@ def init():
rf = RegisterFunction()
rf.register('entities', entities)
rf.register('docs', docs)
debug("注册路由: entities, docs")
async def docs(request, params_kw, *params, **kw):
return helptext
@ -53,12 +54,11 @@ async def entities(request, params_kw, *params, **kw):
debug(f'{params_kw.query=}')
se = ServerEnv()
engine = se.engine
f = awaitify(engine.extract_entities)
query = params_kw.query
if query is None:
e = exception(f'query is None')
raise e
entities = await f(query)
entities = await engine.extract_entities(query)
debug(f'{entities=}, type(entities)')
return {
"object": "list",

View File

@ -1,19 +1,23 @@
import logging
import os
import re
from py2neo import Graph, Node, Relationship
from typing import Set, List, Dict, Tuple
from appPublic.jsonConfig import getConfig
from appPublic.log import debug, error, info, exception
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class KnowledgeGraph:
def __init__(self, triples: List[Dict], document_id: str):
def __init__(self, triples: List[Dict], document_id: str, knowledge_base_id: str, userid: str):
self.triples = triples
self.document_id = document_id
self.g = Graph("bolt://10.18.34.18:7687", auth=('neo4j', '261229..wmh'))
logger.info(f"开始构建知识图谱document_id: {self.document_id}, 三元组数量: {len(triples)}")
self.knowledge_base_id = knowledge_base_id
self.userid = userid
config = getConfig()
self.neo4j_uri = config['neo4j']['uri']
self.neo4j_user = config['neo4j']['user']
self.neo4j_password = config['neo4j']['password']
self.g = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
info(f"开始构建知识图谱document_id: {self.document_id}, knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid}, 三元组数量: {len(triples)}")
def _normalize_label(self, entity_type: str) -> str:
"""规范化实体类型为 Neo4j 标签"""
@ -25,28 +29,29 @@ class KnowledgeGraph:
return label or 'Entity'
def _clean_relation(self, relation: str) -> Tuple[str, str]:
"""清洗关系,返回 (rel_type, rel_name)"""
"""清洗关系,返回 (rel_type, rel_name),确保 rel_type 合法"""
relation = relation.strip()
if not relation:
return 'RELATED_TO', '相关'
if relation.startswith('<') and relation.endswith('>'):
cleaned_relation = relation[1:-1]
rel_name = cleaned_relation
rel_type = re.sub(r'[^\w\s]', '', cleaned_relation).replace(' ', '_').upper()
else:
rel_name = relation
rel_type = re.sub(r'[^\w\s]', '', relation).replace(' ', '_').upper()
if 'instance of' in relation.lower():
rel_type = 'INSTANCE_OF'
rel_name = '实例'
elif 'subclass of' in relation.lower():
rel_type = 'SUBCLASS_OF'
rel_name = '子类'
elif 'part of' in relation.lower():
rel_type = 'PART_OF'
rel_name = '部分'
logger.debug(f"处理关系: {relation} -> {rel_type} ({rel_name})")
return rel_type, rel_name
cleaned_relation = re.sub(r'[^\w\s]', '', relation).strip()
if not cleaned_relation:
return 'RELATED_TO', '相关'
if 'instance of' in relation.lower():
return 'INSTANCE_OF', '实例'
elif 'subclass of' in relation.lower():
return 'SUBCLASS_OF', '子类'
elif 'part of' in relation.lower():
return 'PART_OF', '部分'
rel_type = re.sub(r'\s+', '_', cleaned_relation).upper()
if rel_type and rel_type[0].isdigit():
rel_type = f'REL_{rel_type}'
if not re.match(r'^[A-Za-z][A-Za-z0-9_]*$', rel_type):
debug(f"非法关系类型 '{rel_type}',替换为 'RELATED_TO'")
return 'RELATED_TO', relation
return rel_type, relation
def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]:
"""从三元组列表中读取节点和关系"""
@ -57,14 +62,14 @@ class KnowledgeGraph:
try:
for triple in self.triples:
if not all(key in triple for key in ['head', 'head_type', 'type', 'tail', 'tail_type']):
logger.warning(f"无效三元组: {triple}")
debug(f"无效三元组: {triple}")
continue
head, relation, tail, head_type, tail_type = (
triple['head'], triple['type'], triple['tail'], triple['head_type'], triple['tail_type']
)
head_label = self._normalize_label(head_type)
tail_label = self._normalize_label(tail_type)
logger.debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")
debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")
if head_label not in nodes_by_label:
nodes_by_label[head_label] = set()
@ -92,33 +97,44 @@ class KnowledgeGraph:
'tail_type': tail_type
})
logger.info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())}")
logger.info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())}")
info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())}")
info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())}")
return nodes_by_label, relations_by_type, triples
except Exception as e:
logger.error(f"读取三元组失败: {str(e)}")
error(f"读取三元组失败: {str(e)}")
raise RuntimeError(f"读取三元组失败: {str(e)}")
def create_node(self, label: str, nodes: Set[str]):
"""创建节点,包含 document_id 属性"""
"""创建节点,包含 document_id、knowledge_base_id 和 userid 属性"""
count = 0
for node_name in nodes:
query = f"MATCH (n:{label} {{name: '{node_name}', document_id: '{self.document_id}'}}) RETURN n"
query = (
f"MATCH (n:{label} {{name: $name, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}) RETURN n"
)
try:
if self.g.run(query).data():
if self.g.run(query, name=node_name, doc_id=self.document_id,
kb_id=self.knowledge_base_id, userid=self.userid).data():
continue
node = Node(label, name=node_name, document_id=self.document_id)
node = Node(
label,
name=node_name,
document_id=self.document_id,
knowledge_base_id=self.knowledge_base_id,
userid=self.userid
)
self.g.create(node)
count += 1
logger.debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id})")
debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id}, "
f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
except Exception as e:
logger.error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
logger.info(f"创建 {label} 节点: {count}/{len(nodes)}")
error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
info(f"创建 {label} 节点: {count}/{len(nodes)}")
return count
def create_relationship(self, rel_type: str, relations: List[Dict]):
"""创建关系"""
"""创建关系,包含 document_id、knowledge_base_id 和 userid 属性"""
count = 0
total = len(relations)
seen_edges = set()
@ -132,17 +148,23 @@ class KnowledgeGraph:
seen_edges.add(edge_key)
query = (
f"MATCH (p:{head_label} {{name: '{head}', document_id: '{self.document_id}'}}), "
f"(q:{tail_label} {{name: '{tail}', document_id: '{self.document_id}'}}) "
f"CREATE (p)-[r:{rel_type} {{name: '{rel_name}'}}]->(q)"
f"MATCH (p:{head_label} {{name: $head, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}), "
f"(q:{tail_label} {{name: $tail, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}) "
f"CREATE (p)-[r:{rel_type} {{name: $rel_name, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}]->(q)"
)
try:
self.g.run(query)
self.g.run(query, head=head, tail=tail, rel_name=rel_name,
doc_id=self.document_id, kb_id=self.knowledge_base_id,
userid=self.userid)
count += 1
logger.debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id})")
debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id}, "
f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
except Exception as e:
logger.error(f"创建关系失败: {query}, 错误: {str(e)}")
logger.info(f"创建 {rel_type} 关系: {count}/{total}")
error(f"创建关系失败: {query}, 错误: {str(e)}")
info(f"创建 {rel_type} 关系: {count}/{total}")
return count
def create_graphnodes(self):
@ -151,7 +173,7 @@ class KnowledgeGraph:
total = 0
for label, nodes in nodes_by_label.items():
total += self.create_node(label, nodes)
logger.info(f"总计创建节点: {total}")
info(f"总计创建节点: {total}")
return total
def create_graphrels(self):
@ -160,15 +182,16 @@ class KnowledgeGraph:
total = 0
for rel_type, relations in relations_by_type.items():
total += self.create_relationship(rel_type, relations)
logger.info(f"总计创建关系: {total}")
info(f"总计创建关系: {total}")
return total
def export_data(self):
"""导出节点到文件,包含 document_id"""
"""导出节点到文件,包含 document_id、knowledge_base_id 和 userid"""
nodes_by_label, _, _ = self.read_nodes()
os.makedirs('dict', exist_ok=True)
for label, nodes in nodes_by_label.items():
with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f:
f.write('\n'.join(f"{name}\t{self.document_id}" for name in sorted(nodes)))
logger.info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)}")
f.write('\n'.join(f"{name}\t{self.document_id}\t{self.knowledge_base_id}\t{self.userid}"
for name in sorted(nodes)))
info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)}")
return

View File

@ -1,11 +1,9 @@
# Requires ltp>=0.2.0
from ltp import LTP
from typing import List
import logging
from appPublic.log import debug, info, error
from appPublic.worker import awaitify
from llmengine.base_entity import BaseLtp, ltp_register
logger = logging.getLogger(__name__)
import asyncio
class LtpEntity(BaseLtp):
def __init__(self, model_id):
@ -14,7 +12,7 @@ class LtpEntity(BaseLtp):
self.model_id = model_id
self.model_name = model_id.split('/')[-1]
def extract_entities(self, query: str) -> List[str]:
async def extract_entities(self, query: str) -> List[str]:
"""
从查询文本中抽取实体包括
- LTP NER 识别的实体所有类型
@ -26,7 +24,18 @@ class LtpEntity(BaseLtp):
if not query:
raise ValueError("查询文本不能为空")
result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"])
# 定义同步 pipeline 函数,正确传递 tasks 参数
def sync_pipeline(query, tasks):
return self.ltp.pipeline([query], tasks=tasks)
# 使用 run_in_executor 运行同步 pipeline
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
lambda: sync_pipeline(query, ["cws", "pos", "ner"])
)
# 解析结果
words = result.cws[0]
pos_list = result.pos[0]
ner = result.ner[0]
@ -34,7 +43,7 @@ class LtpEntity(BaseLtp):
entities = []
subword_set = set()
logger.debug(f"NER 结果: {ner}")
debug(f"NER 结果: {ner}")
for entity_type, entity, start, end in ner:
entities.append(entity)
@ -49,13 +58,13 @@ class LtpEntity(BaseLtp):
if combined:
entities.append(combined)
subword_set.update(combined_words)
logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
debug(f"合并连续名词: {combined}, 子词: {combined_words}")
combined = ""
combined_words = []
else:
combined = ""
combined_words = []
logger.debug(f"连续名词子词集合: {subword_set}")
debug(f"连续名词子词集合: {subword_set}")
for word, pos in zip(words, pos_list):
if pos == 'n' and word not in subword_set:
@ -66,11 +75,11 @@ class LtpEntity(BaseLtp):
entities.append(word)
unique_entities = list(dict.fromkeys(entities))
logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
return unique_entities
except Exception as e:
logger.error(f"实体抽取失败: {str(e)}")
return []
error(f"实体抽取失败: {str(e)}")
raise # 抛出异常以便调试,而不是返回空列表
ltp_register('LTP', LtpEntity)

View File

@ -16,6 +16,7 @@ from llmengine.kgc import KnowledgeGraph
import numpy as np
from py2neo import Graph
from scipy.spatial.distance import cosine
import time
# 嵌入缓存
EMBED_CACHE = {}
@ -182,7 +183,7 @@ class MilvusConnection:
if not userid:
return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name,
"document_id": "", "status_code": 400}
return await self.list_user_files(userid)
return await self._list_user_files(userid)
else:
return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name,
"document_id": "", "status_code": 400}
@ -814,20 +815,24 @@ class MilvusConnection:
error(f"实体识别服务调用失败: {str(e)}\n{exception()}")
return []
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]:
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, knowledge_base_id: str) -> List[Dict]:
"""匹配查询实体与 Neo4j 中的三元组"""
start_time = time.time() # 记录开始时间
matched_triplets = []
ENTITY_SIMILARITY_THRESHOLD = 0.8
try:
graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
debug(f"已连接到 Neo4j: {self.neo4j_uri}")
neo4j_connect_time = time.time() - start_time
debug(f"Neo4j 连接耗时: {neo4j_connect_time:.3f}")
matched_names = set()
entity_match_start = time.time()
for entity in query_entities:
normalized_entity = entity.lower().strip()
query = """
MATCH (n {document_id: $document_id})
MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
WHERE toLower(n.name) CONTAINS $entity
OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7
RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim
@ -835,24 +840,27 @@ class MilvusConnection:
LIMIT 100
"""
try:
results = graph.run(query, document_id=document_id, entity=normalized_entity).data()
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, entity=normalized_entity).data()
for record in results:
matched_names.add(record['n.name'])
debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})")
except Exception as e:
debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}")
continue
entity_match_time = time.time() - entity_match_start
debug(f"实体匹配耗时: {entity_match_time:.3f}")
triplets = []
if matched_names:
triplet_query_start = time.time()
query = """
MATCH (h {document_id: $document_id})-[r]->(t {document_id: $document_id})
MATCH (h {userid: $userid, knowledge_base_id: $knowledge_base_id})-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->(t {userid: $userid, knowledge_base_id: $knowledge_base_id})
WHERE h.name IN $matched_names OR t.name IN $matched_names
RETURN h.name AS head, r.name AS type, t.name AS tail
LIMIT 100
"""
try:
results = graph.run(query, document_id=document_id, matched_names=list(matched_names)).data()
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, matched_names=list(matched_names)).data()
seen = set()
for record in results:
head, type_, tail = record['head'], record['type'], record['tail']
@ -866,22 +874,28 @@ class MilvusConnection:
'head_type': '',
'tail_type': ''
})
debug(f"从 Neo4j 加载三元组: document_id={document_id}, 数量={len(triplets)}")
debug(f"从 Neo4j 加载三元组: knowledge_base_id={knowledge_base_id}, 数量={len(triplets)}")
except Exception as e:
error(f"检索三元组失败: document_id={document_id}, 错误: {str(e)}\n{exception()}")
error(f"检索三元组失败: knowledge_base_id={knowledge_base_id}, 错误: {str(e)}\n{exception()}")
return []
triplet_query_time = time.time() - triplet_query_start
debug(f"Neo4j 三元组查询耗时: {triplet_query_time:.3f}")
if not triplets:
debug(f"文档 document_id={document_id} 无匹配三元组")
debug(f"知识库 knowledge_base_id={knowledge_base_id} 无匹配三元组")
return []
embedding_start = time.time()
texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets]
embeddings = await self._get_embeddings(texts_to_embed)
entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)}
head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)}
tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)}
debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails")
embedding_time = time.time() - embedding_start
debug(f"嵌入向量生成耗时: {embedding_time:.3f}")
similarity_start = time.time()
for entity in query_entities:
entity_vec = entity_vectors[entity]
for d_triplet in triplets:
@ -894,6 +908,8 @@ class MilvusConnection:
matched_triplets.append(d_triplet)
debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
similarity_time = time.time() - similarity_start
debug(f"相似度计算耗时: {similarity_time:.3f}")
unique_matched = []
seen = set()
@ -903,6 +919,8 @@ class MilvusConnection:
seen.add(identifier)
unique_matched.append(t)
total_time = time.time() - start_time
debug(f"_match_triplets 总耗时: {total_time:.3f}")
info(f"找到 {len(unique_matched)} 个匹配的三元组")
return unique_matched
@ -957,9 +975,11 @@ class MilvusConnection:
return results
async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
offset: int = 0, use_rerank: bool = True) -> List[Dict]:
offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
"""融合搜索,将查询与所有三元组拼接后向量化搜索"""
start_time = time.time() # 记录开始时间
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
timing_stats = {} # 记录各步骤耗时
try:
info(
f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
@ -975,46 +995,38 @@ class MilvusConnection:
if not utility.has_collection(collection_name):
debug(f"集合 {collection_name} 不存在")
return []
return {"results": [], "timing": timing_stats}
try:
collection = Collection(collection_name)
collection.load()
debug(f"加载集合: {collection_name}")
timing_stats["collection_load"] = time.time() - start_time
debug(f"集合加载耗时: {timing_stats['collection_load']:.3f}")
except Exception as e:
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
return []
return {"results": [], "timing": timing_stats}
entity_extract_start = time.time()
query_entities = await self._extract_entities(query)
debug(f"提取实体: {query_entities}")
timing_stats["entity_extraction"] = time.time() - entity_extract_start
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f}")
documents = []
all_triplets = []
triplet_match_start = time.time()
for kb_id in knowledge_base_ids:
debug(f"处理知识库: {kb_id}")
matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id)
debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)}")
all_triplets.extend(matched_triplets)
timing_stats["triplet_matching"] = time.time() - triplet_match_start
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f}")
results = collection.query(
expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'",
output_fields=["document_id", "filename", "knowledge_base_id"],
limit=100 # 查询足够多的文档以支持后续过滤
)
if not results:
debug(f"未找到 userid {userid} 和 knowledge_base_id {kb_id} 对应的文档")
continue
documents.extend(results)
for doc in results:
document_id = doc["document_id"]
matched_triplets = await self._match_triplets(query, query_entities, userid, document_id)
debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)}")
all_triplets.extend(matched_triplets)
if not documents:
debug("未找到任何有效文档")
return []
info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}")
if not all_triplets:
debug("未找到任何匹配的三元组")
return {"results": [], "timing": timing_stats}
triplet_text_start = time.time()
triplet_texts = []
for triplet in all_triplets:
head = triplet.get('head', '')
@ -1029,13 +1041,18 @@ class MilvusConnection:
combined_text += " [三元组] " + "; ".join(triplet_texts)
debug(
f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f}")
embedding_start = time.time()
embeddings = await self._get_embeddings([combined_text])
query_vector = embeddings[0]
debug(f"拼接文本向量维度: {len(query_vector)}")
timing_stats["embedding_generation"] = time.time() - embedding_start
debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f}")
search_start = time.time()
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
expr = f"userid == '{userid}' and ({kb_expr})"
debug(f"搜索表达式: {expr}")
@ -1053,7 +1070,9 @@ class MilvusConnection:
)
except Exception as e:
error(f"向量搜索失败: {str(e)}\n{exception()}")
return []
return {"results": [], "timing": timing_stats}
timing_stats["vector_search"] = time.time() - search_start
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
search_results = []
for hits in results:
@ -1078,31 +1097,40 @@ class MilvusConnection:
unique_results = []
seen_texts = set()
dedup_start = time.time()
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
if result['text'] not in seen_texts:
unique_results.append(result)
seen_texts.add(result['text'])
timing_stats["deduplication"] = time.time() - dedup_start
debug(f"去重耗时: {timing_stats['deduplication']:.3f}")
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
if use_rerank and unique_results:
rerank_start = time.time()
debug("开始重排序")
unique_results = await self._rerank_results(combined_text, unique_results, limit) # 使用传入的 limit
unique_results = await self._rerank_results(combined_text, unique_results, limit)
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
timing_stats["reranking"] = time.time() - rerank_start
debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
else:
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
info(f"融合搜索完成,返回 {len(unique_results)} 条结果")
return unique_results[:limit]
timing_stats["total_time"] = time.time() - start_time
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
return {"results": unique_results[:limit], "timing": timing_stats}
except Exception as e:
error(f"融合搜索失败: {str(e)}\n{exception()}")
return []
return {"results": [], "timing": timing_stats}
async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5,
offset: int = 0, use_rerank: bool = True) -> List[Dict]:
offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
"""纯向量搜索,基于查询文本在指定知识库中搜索相关文本块"""
start_time = time.time() # 记录开始时间
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
timing_stats = {} # 记录各步骤耗时
try:
info(
f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
@ -1133,22 +1161,27 @@ class MilvusConnection:
if not utility.has_collection(collection_name):
debug(f"集合 {collection_name} 不存在")
return []
return {"results": [], "timing": timing_stats}
try:
collection = Collection(collection_name)
collection.load()
debug(f"加载集合: {collection_name}")
timing_stats["collection_load"] = time.time() - start_time
debug(f"集合加载耗时: {timing_stats['collection_load']:.3f}")
except Exception as e:
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
raise RuntimeError(f"加载集合失败: {str(e)}")
return {"results": [], "timing": timing_stats}
embedding_start = time.time()
embeddings = await self._get_embeddings([query])
query_vector = embeddings[0]
debug(f"查询向量维度: {len(query_vector)}")
timing_stats["embedding_generation"] = time.time() - embedding_start
debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f}")
search_start = time.time()
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
expr = f"userid == '{userid}' and ({kb_id_expr})"
debug(f"搜索表达式: {expr}")
@ -1166,7 +1199,9 @@ class MilvusConnection:
)
except Exception as e:
error(f"搜索失败: {str(e)}\n{exception()}")
raise RuntimeError(f"搜索失败: {str(e)}")
return {"results": [], "timing": timing_stats}
timing_stats["vector_search"] = time.time() - search_start
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f}")
search_results = []
for hits in results:
@ -1189,96 +1224,100 @@ class MilvusConnection:
debug(
f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}")
dedup_start = time.time()
unique_results = []
seen_texts = set()
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
if result['text'] not in seen_texts:
unique_results.append(result)
seen_texts.add(result['text'])
timing_stats["deduplication"] = time.time() - dedup_start
debug(f"去重耗时: {timing_stats['deduplication']:.3f}")
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
if use_rerank and unique_results:
rerank_start = time.time()
debug("开始重排序")
unique_results = await self._rerank_results(query, unique_results, limit)
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
timing_stats["reranking"] = time.time() - rerank_start
debug(f"重排序耗时: {timing_stats['reranking']:.3f}")
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
else:
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果")
return unique_results[:limit]
timing_stats["total_time"] = time.time() - start_time
info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f}")
return {"results": unique_results[:limit], "timing": timing_stats}
except Exception as e:
error(f"纯向量搜索失败: {str(e)}\n{exception()}")
return []
return {"results": [], "timing": timing_stats}
async def list_user_files(self, userid: str) -> List[Dict]:
"""根据 userid 返回用户的所有文件列表,从所有 ragdb_ 开头的集合中查询"""
async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]:
"""列出用户的所有知识库及其文件,按 knowledge_base_id 分组"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
info(f"开始查询用户文件列表: userid={userid}")
info(f"列出用户文件: userid={userid}, db_type={db_type}")
if not userid:
raise ValueError("userid 不能为空")
if "_" in userid:
raise ValueError("userid 不能包含下划线")
if len(userid) > 100:
raise ValueError("userid 长度超出限制")
if "_" in userid or (db_type and "_" in db_type):
raise ValueError("userid 和 db_type 不能包含下划线")
if (db_type and len(db_type) > 100) or len(userid) > 100:
raise ValueError("userid 或 db_type 的长度超出限制")
collections = utility.list_collections()
collections = [c for c in collections if c.startswith("ragdb")]
if not collections:
debug("未找到任何 ragdb 开头的集合")
return []
debug(f"找到集合: {collections}")
if not utility.has_collection(collection_name):
debug(f"集合 {collection_name} 不存在")
return {}
file_list = []
seen_files = set()
for collection_name in collections:
db_type = collection_name.replace("ragdb_", "") if collection_name != "ragdb" else ""
debug(f"处理集合: {collection_name}, db_type={db_type}")
try:
collection = Collection(collection_name)
collection.load()
debug(f"加载集合: {collection_name}")
except Exception as e:
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
return {}
try:
collection = Collection(collection_name)
collection.load()
debug(f"加载集合: {collection_name}")
except Exception as e:
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
continue
expr = f"userid == '{userid}'"
debug(f"查询表达式: {expr}")
try:
results = collection.query(
expr=f"userid == '{userid}'",
output_fields=["filename", "file_path", "upload_time", "file_type"],
limit=1000
)
debug(f"集合 {collection_name} 查询到 {len(results)} 个文本块")
except Exception as e:
error(f"查询集合 {collection_name} 失败: userid={userid}, 错误: {str(e)}\n{exception()}")
continue
try:
results = collection.query(
expr=expr,
output_fields=["document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"],
limit=1000
)
except Exception as e:
error(f"查询用户文件失败: {str(e)}\n{exception()}")
return {}
for result in results:
filename = result.get("filename")
file_path = result.get("file_path")
upload_time = result.get("upload_time")
file_type = result.get("file_type")
if (filename, file_path) not in seen_files:
seen_files.add((filename, file_path))
file_list.append({
"filename": filename,
"file_path": file_path,
"db_type": db_type,
"upload_time": upload_time,
"file_type": file_type
})
debug(
f"文件: filename={filename}, file_path={file_path}, db_type={db_type}, upload_time={upload_time}, file_type={file_type}")
files_by_kb = {}
seen_document_ids = set()
for result in results:
document_id = result.get("document_id")
kb_id = result.get("knowledge_base_id")
if document_id not in seen_document_ids:
seen_document_ids.add(document_id)
file_info = {
"document_id": document_id,
"filename": result.get("filename"),
"file_path": result.get("file_path"),
"upload_time": result.get("upload_time"),
"file_type": result.get("file_type"),
"knowledge_base_id": kb_id
}
if kb_id not in files_by_kb:
files_by_kb[kb_id] = []
files_by_kb[kb_id].append(file_info)
debug(f"找到文件: document_id={document_id}, filename={result.get('filename')}, knowledge_base_id={kb_id}")
info(f"返回 {len(file_list)} 个文件")
return sorted(file_list, key=lambda x: x["upload_time"], reverse=True)
info(f"找到 {len(seen_document_ids)} 个文件userid={userid}, 知识库数量={len(files_by_kb)}")
return files_by_kb
except Exception as e:
error(f"查询用户文件列表失败: userid={userid}, 错误: {str(e)}\n{exception()}")
return []
error(f"列出用户文件失败: {str(e)}\n{exception()}")
return {}
connection_register('Milvus', MilvusConnection)
info("MilvusConnection registered")

View File

@ -1,25 +1,25 @@
import os
import torch
import re
import traceback
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
from appPublic.log import debug, error, warning, info
from appPublic.worker import awaitify
from base_triple import BaseTripleExtractor, llm_register
logger = logging.getLogger(__name__)
class MRebelTripleExtractor(BaseTripleExtractor):
def __init__(self, model_path: str):
super().__init__(model_path)
try:
logger.debug(f"Loading tokenizer from {model_path}")
debug(f"Loading tokenizer from {model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.debug(f"Loading model from {model_path}")
debug(f"Loading model from {model_path}")
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
self.device = self.use_mps_if_possible()
self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
except Exception as e:
logger.error(f"Failed to load mREBEL model: {str(e)}")
error(f"Failed to load mREBEL model: {str(e)}")
raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
self.gen_kwargs = {
@ -47,13 +47,13 @@ class MRebelTripleExtractor(BaseTripleExtractor):
current_chunk = sentence
if current_chunk:
chunks.append(current_chunk)
logger.debug(f"Text split into: {len(chunks)} chunks")
debug(f"Text split into: {len(chunks)} chunks")
return chunks
def extract_triplets_typed(self, text: str) -> list:
"""Parse mREBEL generated text for triplets."""
triplets = []
logger.debug(f"Raw generated text: {text}")
debug(f"Raw generated text: {text}")
tokens = []
in_tag = False
@ -77,7 +77,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
tokens = [t for t in tokens if t not in special_tokens and t]
logger.debug(f"Processed tokens: {tokens}")
debug(f"Processed tokens: {tokens}")
i = 0
while i < len(tokens):
@ -96,7 +96,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
'tail': entity2.strip(),
'tail_type': type2
})
logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
i += 6
else:
i += 1
@ -110,11 +110,11 @@ class MRebelTripleExtractor(BaseTripleExtractor):
raise ValueError("Text cannot be empty")
text_chunks = self.split_document(text, max_chunk_size=150)
logger.debug(f"Text split into {len(text_chunks)} chunks")
debug(f"Text split into {len(text_chunks)} chunks")
all_triplets = []
for i, chunk in enumerate(text_chunks):
logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
model_inputs = self.tokenizer(
chunk,
@ -128,17 +128,17 @@ class MRebelTripleExtractor(BaseTripleExtractor):
generated_tokens = self.model.generate(
model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
**self.gen_kwargs,
**self.gen_kwargs
)
decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
for idx, sentence in enumerate(decoded_preds):
logger.debug(f"Chunk {i + 1} generated text: {sentence}")
debug(f"Chunk {i + 1} generated text: {sentence}")
triplets = self.extract_triplets_typed(sentence)
if triplets:
logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
all_triplets.extend(triplets)
except Exception as e:
logger.warning(f"Error processing chunk {i + 1}: {str(e)}")
warning(f"Error processing chunk {i + 1}: {str(e)}")
continue
unique_triplets = []
@ -149,13 +149,12 @@ class MRebelTripleExtractor(BaseTripleExtractor):
seen.add(identifier)
unique_triplets.append(t)
logger.info(f"Extracted {len(unique_triplets)} unique triplets")
info(f"Extracted {len(unique_triplets)} unique triplets")
return unique_triplets
except Exception as e:
logger.error(f"Failed to extract triplets: {str(e)}")
import traceback
logger.debug(traceback.format_exc())
error(f"Failed to extract triplets: {str(e)}")
debug(f"Traceback: {traceback.format_exc()}")
return []
llm_register("mrebel-large", MRebelTripleExtractor)