Add database-service features: recall text chunks via the database plus knowledge graph, and list a user's documents across different knowledge bases
This commit is contained in:
parent 1c783461bd
commit 564bddfcde
@ -80,7 +80,9 @@ data: {
"use_rerank": true
}
response:
- Success: HTTP 200, [
- Success: HTTP 200, {
"status": "success",
"results": [
{
"text": "<完整文本内容>",
"distance": 0.95,
@ -96,9 +98,25 @@ response:
}
},
...
]
- Error: HTTP 400, {"status": "error", "message": "<error message>", "collection_name": "<collection_name>"}

],
"timing": {
"collection_load": <float>, // 集合加载耗时(秒)
"entity_extraction": <float>, // 实体提取耗时(秒)
"triplet_matching": <float>, // 三元组匹配耗时(秒)
"triplet_text_combine": <float>, // 拼接三元组文本耗时(秒)
"embedding_generation": <float>, // 嵌入向量生成耗时(秒)
"vector_search": <float>, // 向量搜索耗时(秒)
"deduplication": <float>, // 去重耗时(秒)
"reranking": <float>, // 重排序耗时(秒,若 use_rerank=true)
"total_time": <float> // 总耗时(秒)
},
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
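For reference, a minimal client call against this fused-search endpoint could look like the sketch below. The base URL, the path /v1/fusedsearch (inferred from the fused_search_query handler, it may differ), and all payload values are assumptions for illustration, not part of this commit.

```python
# Hypothetical client sketch for the fused-search endpoint documented above.
# Assumptions: service at http://localhost:8080, path /v1/fusedsearch.
import json
import requests

payload = {
    "query": "张三的职务是什么?",        # placeholder query
    "userid": "user1",                   # placeholder user
    "db_type": "textdb",                 # optional, selects ragdb_textdb
    "knowledge_base_ids": ["kb123"],     # placeholder knowledge base
    "limit": 5,
    "offset": 0,
    "use_rerank": True,
}

resp = requests.post(
    "http://localhost:8080/v1/fusedsearch",
    headers={"Content-Type": "application/json"},
    data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
)
body = resp.json()
if body.get("status") == "success":
    for hit in body.get("results", []):
        print(round(hit.get("distance", 0.0), 3), hit.get("text", "")[:80])
    print("timing:", body.get("timing", {}))
else:
    print("error:", body.get("message"), "collection:", body.get("collection_name"))
```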
6. Search Query Endpoint:
path: /v1/searchquery
method: POST
@ -113,7 +131,9 @@ data: {
"use_rerank": true
}
response:
- Success: HTTP 200, [
- Success: HTTP 200, {
"status": "success",
"results": [
{
"text": "<完整文本内容>",
"distance": 0.95,
@ -129,29 +149,65 @@ response:
}
},
...
]
- Error: HTTP 400, {"status": "error", "message": "<error message>"}
],
"timing": {
"collection_load": <float>, // 集合加载耗时(秒)
"embedding_generation": <float>, // 嵌入向量生成耗时(秒)
"vector_search": <float>, // 向量搜索耗时(秒)
"deduplication": <float>, // 去重耗时(秒)
"reranking": <float>, // 重排序耗时(秒,若 use_rerank=true)
"total_time": <float> // 总耗时(秒)
},
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}

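Both search endpoints now return the per-stage timing block, so a caller can turn it into a quick percentage breakdown; a small sketch follows, with the timing values invented purely for illustration.

```python
# Sketch: summarize the "timing" block returned by /v1/searchquery or the
# fused-search endpoint. The numbers below are made up for illustration.
timing = {
    "collection_load": 0.12,
    "embedding_generation": 0.48,
    "vector_search": 0.09,
    "deduplication": 0.01,
    "reranking": 0.30,
    "total_time": 1.02,
}

total = timing.get("total_time") or sum(v for k, v in timing.items() if k != "total_time")
for stage, seconds in sorted(timing.items(), key=lambda kv: kv[1], reverse=True):
    if stage == "total_time":
        continue
    print(f"{stage:22s} {seconds:6.3f}s  {100.0 * seconds / total:5.1f}%")
```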
7. List User Files Endpoint:
path: /v1/listuserfiles
method: POST
headers: {"Content-Type": "application/json"}
data: {
"userid": "testuser2"
"userid": "user1",
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
}
response:
- Success: HTTP 200, [
- Success: HTTP 200, {
"status": "success",
"files_by_knowledge_base": {
"kb123": [
{
"filename": "file.txt",
"file_path": "/path/to/file.txt",
"db_type": "textdb",
"document_id": "<uuid>",
"filename": "file1.txt",
"file_path": "/path/to/file1.txt",
"upload_time": "<iso_timestamp>",
"file_type": "txt"
"file_type": "txt",
"knowledge_base_id": "kb123"
},
...
],
"kb456": [
{
"document_id": "<uuid>",
"filename": "file2.pdf",
"file_path": "/path/to/file2.pdf",
"upload_time": "<iso_timestamp>",
"file_type": "pdf",
"knowledge_base_id": "kb456"
},
...
]
- Error: HTTP 400, {"status": "error", "message": "<error message>"}

},
"collection_name": "ragdb" or "ragdb_textdb"
}
- Error: HTTP 400, {
"status": "error",
"message": "<error message>",
"collection_name": "ragdb" or "ragdb_textdb"
}
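A minimal client for /v1/listuserfiles that walks the grouped response could look like the sketch below; the base URL and the userid are placeholder assumptions.

```python
# Hypothetical client sketch for /v1/listuserfiles; base URL and userid are
# placeholders, error handling is kept minimal.
import requests

resp = requests.post(
    "http://localhost:8080/v1/listuserfiles",
    json={"userid": "user1", "db_type": "textdb"},  # db_type is optional
)
body = resp.json()
if body.get("status") == "success":
    for kb_id, files in body.get("files_by_knowledge_base", {}).items():
        print(f"knowledge base {kb_id}: {len(files)} file(s)")
        for f in files:
            print("  ", f.get("document_id"), f.get("filename"), f.get("upload_time"))
else:
    print("error:", body.get("message"))
```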
8. Connection Endpoint (for compatibility):
path: /v1/connection
method: POST
@ -378,7 +434,13 @@ async def fused_search_query(request, params_kw, *params, **kw):
"use_rerank": use_rerank
})
debug(f'{result=}')
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
response = {
"status": "success",
"results": result.get("results", []),
"timing": result.get("timing", {}),
"collection_name": collection_name
}
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
except Exception as e:
error(f'融合搜索失败: {str(e)}')
return web.json_response({
@ -396,7 +458,7 @@ async def search_query(request, params_kw, *params, **kw):
db_type = params_kw.get('db_type', '')
knowledge_base_ids = params_kw.get('knowledge_base_ids')
limit = params_kw.get('limit')
offset = params_kw.get('offset')
offset = params_kw.get('offset', 0)
use_rerank = params_kw.get('use_rerank', True)
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
@ -417,7 +479,13 @@ async def search_query(request, params_kw, *params, **kw):
"use_rerank": use_rerank
})
debug(f'{result=}')
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
response = {
"status": "success",
"results": result.get("results", []),
"timing": result.get("timing", {}),
"collection_name": collection_name
}
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
except Exception as e:
error(f'纯向量搜索失败: {str(e)}')
return web.json_response({
@ -431,23 +499,33 @@ async def list_user_files(request, params_kw, *params, **kw):
se = ServerEnv()
engine = se.engine
userid = params_kw.get('userid')
db_type = params_kw.get('db_type', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try:
if not userid:
debug(f'userid 未提供')
return web.json_response({
"status": "error",
"message": "userid 参数未提供"
"message": "userid 未提供",
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
try:
result = await engine.handle_connection("list_user_files", {
"userid": userid
"userid": userid,
"db_type": db_type
})
debug(f'{result=}')
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
response = {
"status": "success",
"files_by_knowledge_base": result,
"collection_name": collection_name
}
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
except Exception as e:
error(f'查询用户文件列表失败: {str(e)}')
error(f'列出用户文件失败: {str(e)}')
return web.json_response({
"status": "error",
"message": str(e)
"message": str(e),
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)

async def handle_connection(request, params_kw, *params, **kw):
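The three handlers above now wrap engine results in the same status/results/timing/collection_name envelope. If that pattern keeps repeating, it could be factored into a small helper along these lines; this is only a sketch, not part of the commit.

```python
# Sketch of a shared response helper; web.json_response comes from aiohttp,
# which the handlers above already use. Illustration only.
import json
from aiohttp import web

def ok_response(result: dict, collection_name: str) -> web.Response:
    payload = {
        "status": "success",
        "results": result.get("results", []),
        "timing": result.get("timing", {}),
        "collection_name": collection_name,
    }
    return web.json_response(payload, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))

def error_response(message: str, collection_name: str) -> web.Response:
    payload = {"status": "error", "message": message, "collection_name": collection_name}
    return web.json_response(payload, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
```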
@ -45,6 +45,7 @@ def init():
rf = RegisterFunction()
rf.register('entities', entities)
rf.register('docs', docs)
debug("注册路由: entities, docs")

async def docs(request, params_kw, *params, **kw):
return helptext
@ -53,12 +54,11 @@ async def entities(request, params_kw, *params, **kw):
debug(f'{params_kw.query=}')
se = ServerEnv()
engine = se.engine
f = awaitify(engine.extract_entities)
query = params_kw.query
if query is None:
e = exception(f'query is None')
raise e
entities = await f(query)
entities = await engine.extract_entities(query)
debug(f'{entities=}, type(entities)')
return {
"object": "list",
llmengine/kgc.py (119 changed lines)
@ -1,19 +1,23 @@
import logging
import os
import re
from py2neo import Graph, Node, Relationship
from typing import Set, List, Dict, Tuple
from appPublic.jsonConfig import getConfig
from appPublic.log import debug, error, info, exception

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class KnowledgeGraph:
def __init__(self, triples: List[Dict], document_id: str):
def __init__(self, triples: List[Dict], document_id: str, knowledge_base_id: str, userid: str):
self.triples = triples
self.document_id = document_id
self.g = Graph("bolt://10.18.34.18:7687", auth=('neo4j', '261229..wmh'))
logger.info(f"开始构建知识图谱,document_id: {self.document_id}, 三元组数量: {len(triples)}")
self.knowledge_base_id = knowledge_base_id
self.userid = userid
config = getConfig()
self.neo4j_uri = config['neo4j']['uri']
self.neo4j_user = config['neo4j']['user']
self.neo4j_password = config['neo4j']['password']
self.g = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
info(f"开始构建知识图谱,document_id: {self.document_id}, knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid}, 三元组数量: {len(triples)}")

def _normalize_label(self, entity_type: str) -> str:
"""规范化实体类型为 Neo4j 标签"""
@ -25,28 +29,29 @@ class KnowledgeGraph:
return label or 'Entity'

def _clean_relation(self, relation: str) -> Tuple[str, str]:
"""清洗关系,返回 (rel_type, rel_name)"""
"""清洗关系,返回 (rel_type, rel_name),确保 rel_type 合法"""
relation = relation.strip()
if not relation:
return 'RELATED_TO', '相关'
if relation.startswith('<') and relation.endswith('>'):
cleaned_relation = relation[1:-1]
rel_name = cleaned_relation
rel_type = re.sub(r'[^\w\s]', '', cleaned_relation).replace(' ', '_').upper()
else:
rel_name = relation
rel_type = re.sub(r'[^\w\s]', '', relation).replace(' ', '_').upper()

cleaned_relation = re.sub(r'[^\w\s]', '', relation).strip()
if not cleaned_relation:
return 'RELATED_TO', '相关'

if 'instance of' in relation.lower():
rel_type = 'INSTANCE_OF'
rel_name = '实例'
return 'INSTANCE_OF', '实例'
elif 'subclass of' in relation.lower():
rel_type = 'SUBCLASS_OF'
rel_name = '子类'
return 'SUBCLASS_OF', '子类'
elif 'part of' in relation.lower():
rel_type = 'PART_OF'
rel_name = '部分'
logger.debug(f"处理关系: {relation} -> {rel_type} ({rel_name})")
return rel_type, rel_name
return 'PART_OF', '部分'

rel_type = re.sub(r'\s+', '_', cleaned_relation).upper()
if rel_type and rel_type[0].isdigit():
rel_type = f'REL_{rel_type}'
if not re.match(r'^[A-Za-z][A-Za-z0-9_]*$', rel_type):
debug(f"非法关系类型 '{rel_type}',替换为 'RELATED_TO'")
return 'RELATED_TO', relation
return rel_type, relation
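With the rewritten _clean_relation, the expected mapping for a few representative relations is sketched below. The sample strings are illustrative only; the expectations follow directly from the code above, and the snippet assumes llmengine.kgc is importable in the current environment.

```python
# Quick check of the new _clean_relation mapping. No Neo4j connection is
# needed because __init__ is bypassed with object.__new__ and the method
# does not touch instance state. Test sketch only.
from llmengine.kgc import KnowledgeGraph

kg = object.__new__(KnowledgeGraph)
samples = ["<instance of>", "subclass of", "founded by", "2023 release", "位于", ""]
for rel in samples:
    print(repr(rel), "->", kg._clean_relation(rel))
# Expected, based on the code above:
#   '<instance of>' -> ('INSTANCE_OF', '实例')          special-cased
#   'subclass of'   -> ('SUBCLASS_OF', '子类')          special-cased
#   'founded by'    -> ('FOUNDED_BY', 'founded by')     spaces become '_', upper-cased
#   '2023 release'  -> ('REL_2023_RELEASE', '2023 release')  digit prefix gets REL_
#   '位于'           -> ('RELATED_TO', '位于')           non-ASCII type falls back
#   ''              -> ('RELATED_TO', '相关')            empty relation
```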
def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]:
"""从三元组列表中读取节点和关系"""
@ -57,14 +62,14 @@ class KnowledgeGraph:
try:
for triple in self.triples:
if not all(key in triple for key in ['head', 'head_type', 'type', 'tail', 'tail_type']):
logger.warning(f"无效三元组: {triple}")
debug(f"无效三元组: {triple}")
continue
head, relation, tail, head_type, tail_type = (
triple['head'], triple['type'], triple['tail'], triple['head_type'], triple['tail_type']
)
head_label = self._normalize_label(head_type)
tail_label = self._normalize_label(tail_type)
logger.debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")
debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")

if head_label not in nodes_by_label:
nodes_by_label[head_label] = set()
@ -92,33 +97,44 @@ class KnowledgeGraph:
'tail_type': tail_type
})

logger.info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个")
logger.info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条")
info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个")
info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条")
return nodes_by_label, relations_by_type, triples

except Exception as e:
logger.error(f"读取三元组失败: {str(e)}")
error(f"读取三元组失败: {str(e)}")
raise RuntimeError(f"读取三元组失败: {str(e)}")

def create_node(self, label: str, nodes: Set[str]):
"""创建节点,包含 document_id 属性"""
"""创建节点,包含 document_id、knowledge_base_id 和 userid 属性"""
count = 0
for node_name in nodes:
query = f"MATCH (n:{label} {{name: '{node_name}', document_id: '{self.document_id}'}}) RETURN n"
query = (
f"MATCH (n:{label} {{name: $name, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}) RETURN n"
)
try:
if self.g.run(query).data():
if self.g.run(query, name=node_name, doc_id=self.document_id,
kb_id=self.knowledge_base_id, userid=self.userid).data():
continue
node = Node(label, name=node_name, document_id=self.document_id)
node = Node(
label,
name=node_name,
document_id=self.document_id,
knowledge_base_id=self.knowledge_base_id,
userid=self.userid
)
self.g.create(node)
count += 1
logger.debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id})")
debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id}, "
f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
except Exception as e:
logger.error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
logger.info(f"创建 {label} 节点: {count}/{len(nodes)} 个")
error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
info(f"创建 {label} 节点: {count}/{len(nodes)} 个")
return count

def create_relationship(self, rel_type: str, relations: List[Dict]):
"""创建关系"""
"""创建关系,包含 document_id、knowledge_base_id 和 userid 属性"""
count = 0
total = len(relations)
seen_edges = set()
@ -132,17 +148,23 @@ class KnowledgeGraph:
seen_edges.add(edge_key)

query = (
f"MATCH (p:{head_label} {{name: '{head}', document_id: '{self.document_id}'}}), "
f"(q:{tail_label} {{name: '{tail}', document_id: '{self.document_id}'}}) "
f"CREATE (p)-[r:{rel_type} {{name: '{rel_name}'}}]->(q)"
f"MATCH (p:{head_label} {{name: $head, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}), "
f"(q:{tail_label} {{name: $tail, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}) "
f"CREATE (p)-[r:{rel_type} {{name: $rel_name, document_id: $doc_id, "
f"knowledge_base_id: $kb_id, userid: $userid}}]->(q)"
)
try:
self.g.run(query)
self.g.run(query, head=head, tail=tail, rel_name=rel_name,
doc_id=self.document_id, kb_id=self.knowledge_base_id,
userid=self.userid)
count += 1
logger.debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id})")
debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id}, "
f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
except Exception as e:
logger.error(f"创建关系失败: {query}, 错误: {str(e)}")
logger.info(f"创建 {rel_type} 关系: {count}/{total} 条")
error(f"创建关系失败: {query}, 错误: {str(e)}")
info(f"创建 {rel_type} 关系: {count}/{total} 条")
return count
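The switch from f-string interpolation to query parameters in create_node/create_relationship is what keeps node names containing quotes (for example O'Brien) from breaking the Cypher or injecting into it. In isolation the pattern looks like the sketch below; the URI and credentials are placeholders.

```python
# Minimal illustration of parameterized Cypher with py2neo, mirroring the
# pattern adopted above. Connection details are placeholders.
from py2neo import Graph

g = Graph("bolt://localhost:7687", auth=("neo4j", "password"))

# Unsafe: values spliced into the query string break on quotes and allow injection.
# g.run(f"MATCH (n:Person {{name: '{name}'}}) RETURN n")

# Safe: let the driver bind the values as parameters.
name = "O'Brien"
rows = g.run(
    "MATCH (n:Person {name: $name, document_id: $doc_id}) RETURN n",
    name=name,
    doc_id="doc-001",
).data()
print(rows)
```

Cypher cannot parameterize labels or relationship types, which is why the label and rel_type above still go through f-strings and why _clean_relation now validates rel_type against ^[A-Za-z][A-Za-z0-9_]*$.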
def create_graphnodes(self):
@ -151,7 +173,7 @@ class KnowledgeGraph:
total = 0
for label, nodes in nodes_by_label.items():
total += self.create_node(label, nodes)
logger.info(f"总计创建节点: {total} 个")
info(f"总计创建节点: {total} 个")
return total

def create_graphrels(self):
@ -160,15 +182,16 @@ class KnowledgeGraph:
total = 0
for rel_type, relations in relations_by_type.items():
total += self.create_relationship(rel_type, relations)
logger.info(f"总计创建关系: {total} 条")
info(f"总计创建关系: {total} 条")
return total

def export_data(self):
"""导出节点到文件,包含 document_id"""
"""导出节点到文件,包含 document_id、knowledge_base_id 和 userid"""
nodes_by_label, _, _ = self.read_nodes()
os.makedirs('dict', exist_ok=True)
for label, nodes in nodes_by_label.items():
with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f:
f.write('\n'.join(f"{name}\t{self.document_id}" for name in sorted(nodes)))
logger.info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个")
f.write('\n'.join(f"{name}\t{self.document_id}\t{self.knowledge_base_id}\t{self.userid}"
for name in sorted(nodes)))
info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个")
return
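With the new constructor signature, building a graph for one document now takes the knowledge base and user as well. A usage sketch follows; the triple, IDs and model of use are placeholders, and it assumes a reachable Neo4j plus the neo4j section in the appPublic config.

```python
# Hypothetical usage of the updated KnowledgeGraph class; placeholder data,
# requires a running Neo4j instance configured under config['neo4j'].
from llmengine.kgc import KnowledgeGraph

triples = [
    {"head": "张三", "head_type": "Person", "type": "就职于",
     "tail": "某公司", "tail_type": "Organization"},
]
kg = KnowledgeGraph(
    triples=triples,
    document_id="doc-001",
    knowledge_base_id="kb123",
    userid="user1",
)
kg.create_graphnodes()   # nodes carry document_id / knowledge_base_id / userid
kg.create_graphrels()    # relationships carry the same scoping properties
```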
@ -1,11 +1,9 @@
# Requires ltp>=0.2.0

from ltp import LTP
from typing import List
import logging
from appPublic.log import debug, info, error
from appPublic.worker import awaitify
from llmengine.base_entity import BaseLtp, ltp_register

logger = logging.getLogger(__name__)
import asyncio

class LtpEntity(BaseLtp):
def __init__(self, model_id):
@ -14,7 +12,7 @@ class LtpEntity(BaseLtp):
self.model_id = model_id
self.model_name = model_id.split('/')[-1]

def extract_entities(self, query: str) -> List[str]:
async def extract_entities(self, query: str) -> List[str]:
"""
从查询文本中抽取实体,包括:
- LTP NER 识别的实体(所有类型)。
@ -26,7 +24,18 @@ class LtpEntity(BaseLtp):
if not query:
raise ValueError("查询文本不能为空")

result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"])
# 定义同步 pipeline 函数,正确传递 tasks 参数
def sync_pipeline(query, tasks):
return self.ltp.pipeline([query], tasks=tasks)

# 使用 run_in_executor 运行同步 pipeline
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
lambda: sync_pipeline(query, ["cws", "pos", "ner"])
)

# 解析结果
words = result.cws[0]
pos_list = result.pos[0]
ner = result.ner[0]
@ -34,7 +43,7 @@ class LtpEntity(BaseLtp):
entities = []
subword_set = set()

logger.debug(f"NER 结果: {ner}")
debug(f"NER 结果: {ner}")
for entity_type, entity, start, end in ner:
entities.append(entity)

@ -49,13 +58,13 @@ class LtpEntity(BaseLtp):
if combined:
entities.append(combined)
subword_set.update(combined_words)
logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
debug(f"合并连续名词: {combined}, 子词: {combined_words}")
combined = ""
combined_words = []
else:
combined = ""
combined_words = []
logger.debug(f"连续名词子词集合: {subword_set}")
debug(f"连续名词子词集合: {subword_set}")

for word, pos in zip(words, pos_list):
if pos == 'n' and word not in subword_set:
@ -66,11 +75,11 @@ class LtpEntity(BaseLtp):
entities.append(word)

unique_entities = list(dict.fromkeys(entities))
logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
return unique_entities

except Exception as e:
logger.error(f"实体抽取失败: {str(e)}")
return []
error(f"实体抽取失败: {str(e)}")
raise # 抛出异常以便调试,而不是返回空列表

ltp_register('LTP', LtpEntity)
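Because extract_entities is now a coroutine that off-loads the LTP pipeline to a thread via run_in_executor, callers await it directly instead of wrapping it with awaitify, as the entities handler above already does. A small standalone usage sketch follows; the module path and model directory are assumptions.

```python
# Sketch of calling the now-async LtpEntity.extract_entities. The import path
# and model_id are placeholders; LTP weights must be available locally.
import asyncio
from llmengine.ltp_entity import LtpEntity   # module name assumed from context

async def main():
    extractor = LtpEntity("/models/LTP/small")   # placeholder model_id
    entities = await extractor.extract_entities("张三在北京大学读书")
    print(entities)

asyncio.run(main())
```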
@ -16,6 +16,7 @@ from llmengine.kgc import KnowledgeGraph
|
||||
import numpy as np
|
||||
from py2neo import Graph
|
||||
from scipy.spatial.distance import cosine
|
||||
import time
|
||||
|
||||
# 嵌入缓存
|
||||
EMBED_CACHE = {}
|
||||
@ -182,7 +183,7 @@ class MilvusConnection:
|
||||
if not userid:
|
||||
return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name,
|
||||
"document_id": "", "status_code": 400}
|
||||
return await self.list_user_files(userid)
|
||||
return await self._list_user_files(userid)
|
||||
else:
|
||||
return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name,
|
||||
"document_id": "", "status_code": 400}
|
||||
@ -814,20 +815,24 @@ class MilvusConnection:
|
||||
error(f"实体识别服务调用失败: {str(e)}\n{exception()}")
|
||||
return []
|
||||
|
||||
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]:
|
||||
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, knowledge_base_id: str) -> List[Dict]:
|
||||
"""匹配查询实体与 Neo4j 中的三元组"""
|
||||
start_time = time.time() # 记录开始时间
|
||||
matched_triplets = []
|
||||
ENTITY_SIMILARITY_THRESHOLD = 0.8
|
||||
|
||||
try:
|
||||
graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
|
||||
debug(f"已连接到 Neo4j: {self.neo4j_uri}")
|
||||
neo4j_connect_time = time.time() - start_time
|
||||
debug(f"Neo4j 连接耗时: {neo4j_connect_time:.3f} 秒")
|
||||
|
||||
matched_names = set()
|
||||
entity_match_start = time.time()
|
||||
for entity in query_entities:
|
||||
normalized_entity = entity.lower().strip()
|
||||
query = """
|
||||
MATCH (n {document_id: $document_id})
|
||||
MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
|
||||
WHERE toLower(n.name) CONTAINS $entity
|
||||
OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7
|
||||
RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim
|
||||
@ -835,24 +840,27 @@ class MilvusConnection:
|
||||
LIMIT 100
|
||||
"""
|
||||
try:
|
||||
results = graph.run(query, document_id=document_id, entity=normalized_entity).data()
|
||||
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, entity=normalized_entity).data()
|
||||
for record in results:
|
||||
matched_names.add(record['n.name'])
|
||||
debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})")
|
||||
except Exception as e:
|
||||
debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}")
|
||||
continue
|
||||
entity_match_time = time.time() - entity_match_start
|
||||
debug(f"实体匹配耗时: {entity_match_time:.3f} 秒")
|
||||
|
||||
triplets = []
|
||||
if matched_names:
|
||||
triplet_query_start = time.time()
|
||||
query = """
|
||||
MATCH (h {document_id: $document_id})-[r]->(t {document_id: $document_id})
|
||||
MATCH (h {userid: $userid, knowledge_base_id: $knowledge_base_id})-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->(t {userid: $userid, knowledge_base_id: $knowledge_base_id})
|
||||
WHERE h.name IN $matched_names OR t.name IN $matched_names
|
||||
RETURN h.name AS head, r.name AS type, t.name AS tail
|
||||
LIMIT 100
|
||||
"""
|
||||
try:
|
||||
results = graph.run(query, document_id=document_id, matched_names=list(matched_names)).data()
|
||||
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, matched_names=list(matched_names)).data()
|
||||
seen = set()
|
||||
for record in results:
|
||||
head, type_, tail = record['head'], record['type'], record['tail']
|
||||
@ -866,22 +874,28 @@ class MilvusConnection:
|
||||
'head_type': '',
|
||||
'tail_type': ''
|
||||
})
|
||||
debug(f"从 Neo4j 加载三元组: document_id={document_id}, 数量={len(triplets)}")
|
||||
debug(f"从 Neo4j 加载三元组: knowledge_base_id={knowledge_base_id}, 数量={len(triplets)}")
|
||||
except Exception as e:
|
||||
error(f"检索三元组失败: document_id={document_id}, 错误: {str(e)}\n{exception()}")
|
||||
error(f"检索三元组失败: knowledge_base_id={knowledge_base_id}, 错误: {str(e)}\n{exception()}")
|
||||
return []
|
||||
triplet_query_time = time.time() - triplet_query_start
|
||||
debug(f"Neo4j 三元组查询耗时: {triplet_query_time:.3f} 秒")
|
||||
|
||||
if not triplets:
|
||||
debug(f"文档 document_id={document_id} 无匹配三元组")
|
||||
debug(f"知识库 knowledge_base_id={knowledge_base_id} 无匹配三元组")
|
||||
return []
|
||||
|
||||
embedding_start = time.time()
|
||||
texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets]
|
||||
embeddings = await self._get_embeddings(texts_to_embed)
|
||||
entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)}
|
||||
head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)}
|
||||
tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)}
|
||||
debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)")
|
||||
embedding_time = time.time() - embedding_start
|
||||
debug(f"嵌入向量生成耗时: {embedding_time:.3f} 秒")
|
||||
|
||||
similarity_start = time.time()
|
||||
for entity in query_entities:
|
||||
entity_vec = entity_vectors[entity]
|
||||
for d_triplet in triplets:
|
||||
@ -894,6 +908,8 @@ class MilvusConnection:
|
||||
matched_triplets.append(d_triplet)
|
||||
debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
|
||||
f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
|
||||
similarity_time = time.time() - similarity_start
|
||||
debug(f"相似度计算耗时: {similarity_time:.3f} 秒")
|
||||
|
||||
unique_matched = []
|
||||
seen = set()
|
||||
@ -903,6 +919,8 @@ class MilvusConnection:
|
||||
seen.add(identifier)
|
||||
unique_matched.append(t)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
debug(f"_match_triplets 总耗时: {total_time:.3f} 秒")
|
||||
info(f"找到 {len(unique_matched)} 个匹配的三元组")
|
||||
return unique_matched
|
||||
|
||||
@ -957,9 +975,11 @@ class MilvusConnection:
|
||||
return results
|
||||
|
||||
async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
|
||||
offset: int = 0, use_rerank: bool = True) -> List[Dict]:
|
||||
offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
|
||||
"""融合搜索,将查询与所有三元组拼接后向量化搜索"""
|
||||
start_time = time.time() # 记录开始时间
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
timing_stats = {} # 记录各步骤耗时
|
||||
try:
|
||||
info(
|
||||
f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
|
||||
@ -975,46 +995,38 @@ class MilvusConnection:
|
||||
|
||||
if not utility.has_collection(collection_name):
|
||||
debug(f"集合 {collection_name} 不存在")
|
||||
return []
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
try:
|
||||
collection = Collection(collection_name)
|
||||
collection.load()
|
||||
debug(f"加载集合: {collection_name}")
|
||||
timing_stats["collection_load"] = time.time() - start_time
|
||||
debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒")
|
||||
except Exception as e:
|
||||
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
||||
return []
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
entity_extract_start = time.time()
|
||||
query_entities = await self._extract_entities(query)
|
||||
debug(f"提取实体: {query_entities}")
|
||||
timing_stats["entity_extraction"] = time.time() - entity_extract_start
|
||||
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
|
||||
|
||||
documents = []
|
||||
all_triplets = []
|
||||
triplet_match_start = time.time()
|
||||
for kb_id in knowledge_base_ids:
|
||||
debug(f"处理知识库: {kb_id}")
|
||||
|
||||
results = collection.query(
|
||||
expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'",
|
||||
output_fields=["document_id", "filename", "knowledge_base_id"],
|
||||
limit=100 # 查询足够多的文档以支持后续过滤
|
||||
)
|
||||
if not results:
|
||||
debug(f"未找到 userid {userid} 和 knowledge_base_id {kb_id} 对应的文档")
|
||||
continue
|
||||
documents.extend(results)
|
||||
|
||||
for doc in results:
|
||||
document_id = doc["document_id"]
|
||||
matched_triplets = await self._match_triplets(query, query_entities, userid, document_id)
|
||||
debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)} 条")
|
||||
matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id)
|
||||
debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)} 条")
|
||||
all_triplets.extend(matched_triplets)
|
||||
timing_stats["triplet_matching"] = time.time() - triplet_match_start
|
||||
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
|
||||
|
||||
if not documents:
|
||||
debug("未找到任何有效文档")
|
||||
return []
|
||||
|
||||
info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}")
|
||||
if not all_triplets:
|
||||
debug("未找到任何匹配的三元组")
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
triplet_text_start = time.time()
|
||||
triplet_texts = []
|
||||
for triplet in all_triplets:
|
||||
head = triplet.get('head', '')
|
||||
@ -1029,13 +1041,18 @@ class MilvusConnection:
|
||||
combined_text += " [三元组] " + "; ".join(triplet_texts)
|
||||
debug(
|
||||
f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
||||
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
|
||||
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
|
||||
|
||||
embedding_start = time.time()
|
||||
embeddings = await self._get_embeddings([combined_text])
|
||||
query_vector = embeddings[0]
|
||||
debug(f"拼接文本向量维度: {len(query_vector)}")
|
||||
timing_stats["embedding_generation"] = time.time() - embedding_start
|
||||
debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒")
|
||||
|
||||
search_start = time.time()
|
||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||
|
||||
kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
|
||||
expr = f"userid == '{userid}' and ({kb_expr})"
|
||||
debug(f"搜索表达式: {expr}")
|
||||
@ -1053,7 +1070,9 @@ class MilvusConnection:
|
||||
)
|
||||
except Exception as e:
|
||||
error(f"向量搜索失败: {str(e)}\n{exception()}")
|
||||
return []
|
||||
return {"results": [], "timing": timing_stats}
|
||||
timing_stats["vector_search"] = time.time() - search_start
|
||||
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
|
||||
|
||||
search_results = []
|
||||
for hits in results:
|
||||
@ -1078,31 +1097,40 @@ class MilvusConnection:
|
||||
|
||||
unique_results = []
|
||||
seen_texts = set()
|
||||
dedup_start = time.time()
|
||||
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
|
||||
if result['text'] not in seen_texts:
|
||||
unique_results.append(result)
|
||||
seen_texts.add(result['text'])
|
||||
timing_stats["deduplication"] = time.time() - dedup_start
|
||||
debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒")
|
||||
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
|
||||
|
||||
if use_rerank and unique_results:
|
||||
rerank_start = time.time()
|
||||
debug("开始重排序")
|
||||
unique_results = await self._rerank_results(combined_text, unique_results, limit) # 使用传入的 limit
|
||||
unique_results = await self._rerank_results(combined_text, unique_results, limit)
|
||||
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||
timing_stats["reranking"] = time.time() - rerank_start
|
||||
debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
|
||||
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
||||
else:
|
||||
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
||||
|
||||
info(f"融合搜索完成,返回 {len(unique_results)} 条结果")
|
||||
return unique_results[:limit]
|
||||
timing_stats["total_time"] = time.time() - start_time
|
||||
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||
return {"results": unique_results[:limit], "timing": timing_stats}
|
||||
|
||||
except Exception as e:
|
||||
error(f"融合搜索失败: {str(e)}\n{exception()}")
|
||||
return []
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5,
|
||||
offset: int = 0, use_rerank: bool = True) -> List[Dict]:
|
||||
offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
|
||||
"""纯向量搜索,基于查询文本在指定知识库中搜索相关文本块"""
|
||||
start_time = time.time() # 记录开始时间
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
timing_stats = {} # 记录各步骤耗时
|
||||
try:
|
||||
info(
|
||||
f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
|
||||
@ -1133,22 +1161,27 @@ class MilvusConnection:
|
||||
|
||||
if not utility.has_collection(collection_name):
|
||||
debug(f"集合 {collection_name} 不存在")
|
||||
return []
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
try:
|
||||
collection = Collection(collection_name)
|
||||
collection.load()
|
||||
debug(f"加载集合: {collection_name}")
|
||||
timing_stats["collection_load"] = time.time() - start_time
|
||||
debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒")
|
||||
except Exception as e:
|
||||
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
||||
raise RuntimeError(f"加载集合失败: {str(e)}")
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
embedding_start = time.time()
|
||||
embeddings = await self._get_embeddings([query])
|
||||
query_vector = embeddings[0]
|
||||
debug(f"查询向量维度: {len(query_vector)}")
|
||||
timing_stats["embedding_generation"] = time.time() - embedding_start
|
||||
debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒")
|
||||
|
||||
search_start = time.time()
|
||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||
|
||||
kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
|
||||
expr = f"userid == '{userid}' and ({kb_id_expr})"
|
||||
debug(f"搜索表达式: {expr}")
|
||||
@ -1166,7 +1199,9 @@ class MilvusConnection:
|
||||
)
|
||||
except Exception as e:
|
||||
error(f"搜索失败: {str(e)}\n{exception()}")
|
||||
raise RuntimeError(f"搜索失败: {str(e)}")
|
||||
return {"results": [], "timing": timing_stats}
|
||||
timing_stats["vector_search"] = time.time() - search_start
|
||||
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
|
||||
|
||||
search_results = []
|
||||
for hits in results:
|
||||
@ -1189,53 +1224,52 @@ class MilvusConnection:
|
||||
debug(
|
||||
f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}")
|
||||
|
||||
dedup_start = time.time()
|
||||
unique_results = []
|
||||
seen_texts = set()
|
||||
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
|
||||
if result['text'] not in seen_texts:
|
||||
unique_results.append(result)
|
||||
seen_texts.add(result['text'])
|
||||
timing_stats["deduplication"] = time.time() - dedup_start
|
||||
debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒")
|
||||
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
|
||||
|
||||
if use_rerank and unique_results:
|
||||
rerank_start = time.time()
|
||||
debug("开始重排序")
|
||||
unique_results = await self._rerank_results(query, unique_results, limit)
|
||||
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||
timing_stats["reranking"] = time.time() - rerank_start
|
||||
debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
|
||||
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
||||
else:
|
||||
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
||||
|
||||
info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果")
|
||||
return unique_results[:limit]
|
||||
timing_stats["total_time"] = time.time() - start_time
|
||||
info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||
return {"results": unique_results[:limit], "timing": timing_stats}
|
||||
|
||||
except Exception as e:
|
||||
error(f"纯向量搜索失败: {str(e)}\n{exception()}")
|
||||
return []
|
||||
return {"results": [], "timing": timing_stats}
|
||||
|
||||
async def list_user_files(self, userid: str) -> List[Dict]:
|
||||
"""根据 userid 返回用户的所有文件列表,从所有 ragdb_ 开头的集合中查询"""
|
||||
async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]:
|
||||
"""列出用户的所有知识库及其文件,按 knowledge_base_id 分组"""
|
||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||
try:
|
||||
info(f"开始查询用户文件列表: userid={userid}")
|
||||
info(f"列出用户文件: userid={userid}, db_type={db_type}")
|
||||
|
||||
if not userid:
|
||||
raise ValueError("userid 不能为空")
|
||||
if "_" in userid:
|
||||
raise ValueError("userid 不能包含下划线")
|
||||
if len(userid) > 100:
|
||||
raise ValueError("userid 长度超出限制")
|
||||
if "_" in userid or (db_type and "_" in db_type):
|
||||
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||
if (db_type and len(db_type) > 100) or len(userid) > 100:
|
||||
raise ValueError("userid 或 db_type 的长度超出限制")
|
||||
|
||||
collections = utility.list_collections()
|
||||
collections = [c for c in collections if c.startswith("ragdb")]
|
||||
if not collections:
|
||||
debug("未找到任何 ragdb 开头的集合")
|
||||
return []
|
||||
debug(f"找到集合: {collections}")
|
||||
|
||||
file_list = []
|
||||
seen_files = set()
|
||||
for collection_name in collections:
|
||||
db_type = collection_name.replace("ragdb_", "") if collection_name != "ragdb" else ""
|
||||
debug(f"处理集合: {collection_name}, db_type={db_type}")
|
||||
if not utility.has_collection(collection_name):
|
||||
debug(f"集合 {collection_name} 不存在")
|
||||
return {}
|
||||
|
||||
try:
|
||||
collection = Collection(collection_name)
|
||||
@ -1243,42 +1277,47 @@ class MilvusConnection:
|
||||
debug(f"加载集合: {collection_name}")
|
||||
except Exception as e:
|
||||
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
||||
continue
|
||||
return {}
|
||||
|
||||
expr = f"userid == '{userid}'"
|
||||
debug(f"查询表达式: {expr}")
|
||||
|
||||
try:
|
||||
results = collection.query(
|
||||
expr=f"userid == '{userid}'",
|
||||
output_fields=["filename", "file_path", "upload_time", "file_type"],
|
||||
expr=expr,
|
||||
output_fields=["document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"],
|
||||
limit=1000
|
||||
)
|
||||
debug(f"集合 {collection_name} 查询到 {len(results)} 个文本块")
|
||||
except Exception as e:
|
||||
error(f"查询集合 {collection_name} 失败: userid={userid}, 错误: {str(e)}\n{exception()}")
|
||||
continue
|
||||
error(f"查询用户文件失败: {str(e)}\n{exception()}")
|
||||
return {}
|
||||
|
||||
files_by_kb = {}
|
||||
seen_document_ids = set()
|
||||
for result in results:
|
||||
filename = result.get("filename")
|
||||
file_path = result.get("file_path")
|
||||
upload_time = result.get("upload_time")
|
||||
file_type = result.get("file_type")
|
||||
if (filename, file_path) not in seen_files:
|
||||
seen_files.add((filename, file_path))
|
||||
file_list.append({
|
||||
"filename": filename,
|
||||
"file_path": file_path,
|
||||
"db_type": db_type,
|
||||
"upload_time": upload_time,
|
||||
"file_type": file_type
|
||||
})
|
||||
debug(
|
||||
f"文件: filename={filename}, file_path={file_path}, db_type={db_type}, upload_time={upload_time}, file_type={file_type}")
|
||||
document_id = result.get("document_id")
|
||||
kb_id = result.get("knowledge_base_id")
|
||||
if document_id not in seen_document_ids:
|
||||
seen_document_ids.add(document_id)
|
||||
file_info = {
|
||||
"document_id": document_id,
|
||||
"filename": result.get("filename"),
|
||||
"file_path": result.get("file_path"),
|
||||
"upload_time": result.get("upload_time"),
|
||||
"file_type": result.get("file_type"),
|
||||
"knowledge_base_id": kb_id
|
||||
}
|
||||
if kb_id not in files_by_kb:
|
||||
files_by_kb[kb_id] = []
|
||||
files_by_kb[kb_id].append(file_info)
|
||||
debug(f"找到文件: document_id={document_id}, filename={result.get('filename')}, knowledge_base_id={kb_id}")
|
||||
|
||||
info(f"返回 {len(file_list)} 个文件")
|
||||
return sorted(file_list, key=lambda x: x["upload_time"], reverse=True)
|
||||
info(f"找到 {len(seen_document_ids)} 个文件,userid={userid}, 知识库数量={len(files_by_kb)}")
|
||||
return files_by_kb
|
||||
|
||||
except Exception as e:
|
||||
error(f"查询用户文件列表失败: userid={userid}, 错误: {str(e)}\n{exception()}")
|
||||
return []
|
||||
error(f"列出用户文件失败: {str(e)}\n{exception()}")
|
||||
return {}
|
||||
|
||||
connection_register('Milvus', MilvusConnection)
|
||||
info("MilvusConnection registered")
|
@ -1,25 +1,25 @@
|
||||
import os
|
||||
import torch
|
||||
import re
|
||||
import traceback
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
import logging
|
||||
from appPublic.log import debug, error, warning, info
|
||||
from appPublic.worker import awaitify
|
||||
from base_triple import BaseTripleExtractor, llm_register
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
def __init__(self, model_path: str):
|
||||
super().__init__(model_path)
|
||||
try:
|
||||
logger.debug(f"Loading tokenizer from {model_path}")
|
||||
debug(f"Loading tokenizer from {model_path}")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
logger.debug(f"Loading model from {model_path}")
|
||||
debug(f"Loading model from {model_path}")
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
||||
self.device = self.use_mps_if_possible()
|
||||
self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
|
||||
logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
|
||||
debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load mREBEL model: {str(e)}")
|
||||
error(f"Failed to load mREBEL model: {str(e)}")
|
||||
raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
|
||||
|
||||
self.gen_kwargs = {
|
||||
@ -47,13 +47,13 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
current_chunk = sentence
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
logger.debug(f"Text split into: {len(chunks)} chunks")
|
||||
debug(f"Text split into: {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
def extract_triplets_typed(self, text: str) -> list:
|
||||
"""Parse mREBEL generated text for triplets."""
|
||||
triplets = []
|
||||
logger.debug(f"Raw generated text: {text}")
|
||||
debug(f"Raw generated text: {text}")
|
||||
|
||||
tokens = []
|
||||
in_tag = False
|
||||
@ -77,7 +77,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
|
||||
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
|
||||
tokens = [t for t in tokens if t not in special_tokens and t]
|
||||
logger.debug(f"Processed tokens: {tokens}")
|
||||
debug(f"Processed tokens: {tokens}")
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
@ -96,7 +96,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
'tail': entity2.strip(),
|
||||
'tail_type': type2
|
||||
})
|
||||
logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
|
||||
debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
|
||||
i += 6
|
||||
else:
|
||||
i += 1
|
||||
@ -110,11 +110,11 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
raise ValueError("Text cannot be empty")
|
||||
|
||||
text_chunks = self.split_document(text, max_chunk_size=150)
|
||||
logger.debug(f"Text split into {len(text_chunks)} chunks")
|
||||
debug(f"Text split into {len(text_chunks)} chunks")
|
||||
|
||||
all_triplets = []
|
||||
for i, chunk in enumerate(text_chunks):
|
||||
logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
|
||||
debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
|
||||
|
||||
model_inputs = self.tokenizer(
|
||||
chunk,
|
||||
@ -128,17 +128,17 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
generated_tokens = self.model.generate(
|
||||
model_inputs["input_ids"],
|
||||
attention_mask=model_inputs["attention_mask"],
|
||||
**self.gen_kwargs,
|
||||
**self.gen_kwargs
|
||||
)
|
||||
decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
|
||||
for idx, sentence in enumerate(decoded_preds):
|
||||
logger.debug(f"Chunk {i + 1} generated text: {sentence}")
|
||||
debug(f"Chunk {i + 1} generated text: {sentence}")
|
||||
triplets = self.extract_triplets_typed(sentence)
|
||||
if triplets:
|
||||
logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
|
||||
debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
|
||||
all_triplets.extend(triplets)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing chunk {i + 1}: {str(e)}")
|
||||
warning(f"Error processing chunk {i + 1}: {str(e)}")
|
||||
continue
|
||||
|
||||
unique_triplets = []
|
||||
@ -149,13 +149,12 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
||||
seen.add(identifier)
|
||||
unique_triplets.append(t)
|
||||
|
||||
logger.info(f"Extracted {len(unique_triplets)} unique triplets")
|
||||
info(f"Extracted {len(unique_triplets)} unique triplets")
|
||||
return unique_triplets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract triplets: {str(e)}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
error(f"Failed to extract triplets: {str(e)}")
|
||||
debug(f"Traceback: {traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
llm_register("mrebel-large", MRebelTripleExtractor)
|