增加数据库服务化基于数据库+知识图谱召回文本块、列出用户不同知识库中的文档功能
This commit is contained in:
parent
1c783461bd
commit
564bddfcde
@ -80,25 +80,43 @@ data: {
|
|||||||
"use_rerank": true
|
"use_rerank": true
|
||||||
}
|
}
|
||||||
response:
|
response:
|
||||||
- Success: HTTP 200, [
|
- Success: HTTP 200, {
|
||||||
{
|
"status": "success",
|
||||||
"text": "<完整文本内容>",
|
"results": [
|
||||||
"distance": 0.95,
|
{
|
||||||
"source": "fused_query_with_triplets",
|
"text": "<完整文本内容>",
|
||||||
"rerank_score": 0.92, // 若 use_rerank=true
|
"distance": 0.95,
|
||||||
"metadata": {
|
"source": "fused_query_with_triplets",
|
||||||
"userid": "user1",
|
"rerank_score": 0.92, // 若 use_rerank=true
|
||||||
"document_id": "<uuid>",
|
"metadata": {
|
||||||
"filename": "file.txt",
|
"userid": "user1",
|
||||||
"file_path": "/path/to/file.txt",
|
"document_id": "<uuid>",
|
||||||
"upload_time": "<iso_timestamp>",
|
"filename": "file.txt",
|
||||||
"file_type": "txt"
|
"file_path": "/path/to/file.txt",
|
||||||
}
|
"upload_time": "<iso_timestamp>",
|
||||||
|
"file_type": "txt"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"timing": {
|
||||||
|
"collection_load": <float>, // 集合加载耗时(秒)
|
||||||
|
"entity_extraction": <float>, // 实体提取耗时(秒)
|
||||||
|
"triplet_matching": <float>, // 三元组匹配耗时(秒)
|
||||||
|
"triplet_text_combine": <float>, // 拼接三元组文本耗时(秒)
|
||||||
|
"embedding_generation": <float>, // 嵌入向量生成耗时(秒)
|
||||||
|
"vector_search": <float>, // 向量搜索耗时(秒)
|
||||||
|
"deduplication": <float>, // 去重耗时(秒)
|
||||||
|
"reranking": <float>, // 重排序耗时(秒,若 use_rerank=true)
|
||||||
|
"total_time": <float> // 总耗时(秒)
|
||||||
},
|
},
|
||||||
...
|
"collection_name": "ragdb" or "ragdb_textdb"
|
||||||
]
|
}
|
||||||
- Error: HTTP 400, {"status": "error", "message": "<error message>", "collection_name": "<collection_name>"}
|
- Error: HTTP 400, {
|
||||||
|
"status": "error",
|
||||||
|
"message": "<error message>",
|
||||||
|
"collection_name": "ragdb" or "ragdb_textdb"
|
||||||
|
}
|
||||||
6. Search Query Endpoint:
|
6. Search Query Endpoint:
|
||||||
path: /v1/searchquery
|
path: /v1/searchquery
|
||||||
method: POST
|
method: POST
|
||||||
@ -113,45 +131,83 @@ data: {
|
|||||||
"use_rerank": true
|
"use_rerank": true
|
||||||
}
|
}
|
||||||
response:
|
response:
|
||||||
- Success: HTTP 200, [
|
- Success: HTTP 200, {
|
||||||
{
|
"status": "success",
|
||||||
"text": "<完整文本内容>",
|
"results": [
|
||||||
"distance": 0.95,
|
{
|
||||||
"source": "vector_query",
|
"text": "<完整文本内容>",
|
||||||
"rerank_score": 0.92, // 若 use_rerank=true
|
"distance": 0.95,
|
||||||
"metadata": {
|
"source": "vector_query",
|
||||||
"userid": "user1",
|
"rerank_score": 0.92, // 若 use_rerank=true
|
||||||
"document_id": "<uuid>",
|
"metadata": {
|
||||||
"filename": "file.txt",
|
"userid": "user1",
|
||||||
"file_path": "/path/to/file.txt",
|
"document_id": "<uuid>",
|
||||||
"upload_time": "<iso_timestamp>",
|
"filename": "file.txt",
|
||||||
"file_type": "txt"
|
"file_path": "/path/to/file.txt",
|
||||||
}
|
"upload_time": "<iso_timestamp>",
|
||||||
|
"file_type": "txt"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"timing": {
|
||||||
|
"collection_load": <float>, // 集合加载耗时(秒)
|
||||||
|
"embedding_generation": <float>, // 嵌入向量生成耗时(秒)
|
||||||
|
"vector_search": <float>, // 向量搜索耗时(秒)
|
||||||
|
"deduplication": <float>, // 去重耗时(秒)
|
||||||
|
"reranking": <float>, // 重排序耗时(秒,若 use_rerank=true)
|
||||||
|
"total_time": <float> // 总耗时(秒)
|
||||||
},
|
},
|
||||||
...
|
"collection_name": "ragdb" or "ragdb_textdb"
|
||||||
]
|
}
|
||||||
- Error: HTTP 400, {"status": "error", "message": "<error message>"}
|
- Error: HTTP 400, {
|
||||||
|
"status": "error",
|
||||||
|
"message": "<error message>",
|
||||||
|
"collection_name": "ragdb" or "ragdb_textdb"
|
||||||
|
}
|
||||||
|
|
||||||
7. List User Files Endpoint:
|
7. List User Files Endpoint:
|
||||||
path: /v1/listuserfiles
|
path: /v1/listuserfiles
|
||||||
method: POST
|
method: POST
|
||||||
headers: {"Content-Type": "application/json"}
|
headers: {"Content-Type": "application/json"}
|
||||||
data: {
|
data: {
|
||||||
"userid": "testuser2"
|
"userid": "user1",
|
||||||
|
"db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
|
||||||
}
|
}
|
||||||
response:
|
response:
|
||||||
- Success: HTTP 200, [
|
- Success: HTTP 200, {
|
||||||
{
|
"status": "success",
|
||||||
"filename": "file.txt",
|
"files_by_knowledge_base": {
|
||||||
"file_path": "/path/to/file.txt",
|
"kb123": [
|
||||||
"db_type": "textdb",
|
{
|
||||||
"upload_time": "<iso_timestamp>",
|
"document_id": "<uuid>",
|
||||||
"file_type": "txt"
|
"filename": "file1.txt",
|
||||||
|
"file_path": "/path/to/file1.txt",
|
||||||
|
"upload_time": "<iso_timestamp>",
|
||||||
|
"file_type": "txt",
|
||||||
|
"knowledge_base_id": "kb123"
|
||||||
|
},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"kb456": [
|
||||||
|
{
|
||||||
|
"document_id": "<uuid>",
|
||||||
|
"filename": "file2.pdf",
|
||||||
|
"file_path": "/path/to/file2.pdf",
|
||||||
|
"upload_time": "<iso_timestamp>",
|
||||||
|
"file_type": "pdf",
|
||||||
|
"knowledge_base_id": "kb456"
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
},
|
},
|
||||||
...
|
"collection_name": "ragdb" or "ragdb_textdb"
|
||||||
]
|
}
|
||||||
- Error: HTTP 400, {"status": "error", "message": "<error message>"}
|
- Error: HTTP 400, {
|
||||||
|
"status": "error",
|
||||||
|
"message": "<error message>",
|
||||||
|
"collection_name": "ragdb" or "ragdb_textdb"
|
||||||
|
}
|
||||||
8. Connection Endpoint (for compatibility):
|
8. Connection Endpoint (for compatibility):
|
||||||
path: /v1/connection
|
path: /v1/connection
|
||||||
method: POST
|
method: POST
|
||||||
@ -378,7 +434,13 @@ async def fused_search_query(request, params_kw, *params, **kw):
|
|||||||
"use_rerank": use_rerank
|
"use_rerank": use_rerank
|
||||||
})
|
})
|
||||||
debug(f'{result=}')
|
debug(f'{result=}')
|
||||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
response = {
|
||||||
|
"status": "success",
|
||||||
|
"results": result.get("results", []),
|
||||||
|
"timing": result.get("timing", {}),
|
||||||
|
"collection_name": collection_name
|
||||||
|
}
|
||||||
|
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f'融合搜索失败: {str(e)}')
|
error(f'融合搜索失败: {str(e)}')
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
@ -396,7 +458,7 @@ async def search_query(request, params_kw, *params, **kw):
|
|||||||
db_type = params_kw.get('db_type', '')
|
db_type = params_kw.get('db_type', '')
|
||||||
knowledge_base_ids = params_kw.get('knowledge_base_ids')
|
knowledge_base_ids = params_kw.get('knowledge_base_ids')
|
||||||
limit = params_kw.get('limit')
|
limit = params_kw.get('limit')
|
||||||
offset = params_kw.get('offset')
|
offset = params_kw.get('offset', 0)
|
||||||
use_rerank = params_kw.get('use_rerank', True)
|
use_rerank = params_kw.get('use_rerank', True)
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
try:
|
try:
|
||||||
@ -417,7 +479,13 @@ async def search_query(request, params_kw, *params, **kw):
|
|||||||
"use_rerank": use_rerank
|
"use_rerank": use_rerank
|
||||||
})
|
})
|
||||||
debug(f'{result=}')
|
debug(f'{result=}')
|
||||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
response = {
|
||||||
|
"status": "success",
|
||||||
|
"results": result.get("results", []),
|
||||||
|
"timing": result.get("timing", {}),
|
||||||
|
"collection_name": collection_name
|
||||||
|
}
|
||||||
|
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f'纯向量搜索失败: {str(e)}')
|
error(f'纯向量搜索失败: {str(e)}')
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
@ -431,23 +499,33 @@ async def list_user_files(request, params_kw, *params, **kw):
|
|||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
engine = se.engine
|
engine = se.engine
|
||||||
userid = params_kw.get('userid')
|
userid = params_kw.get('userid')
|
||||||
if not userid:
|
db_type = params_kw.get('db_type', '')
|
||||||
debug(f'userid 未提供')
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
return web.json_response({
|
|
||||||
"status": "error",
|
|
||||||
"message": "userid 参数未提供"
|
|
||||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
|
||||||
try:
|
try:
|
||||||
|
if not userid:
|
||||||
|
debug(f'userid 未提供')
|
||||||
|
return web.json_response({
|
||||||
|
"status": "error",
|
||||||
|
"message": "userid 未提供",
|
||||||
|
"collection_name": collection_name
|
||||||
|
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||||
result = await engine.handle_connection("list_user_files", {
|
result = await engine.handle_connection("list_user_files", {
|
||||||
"userid": userid
|
"userid": userid,
|
||||||
|
"db_type": db_type
|
||||||
})
|
})
|
||||||
debug(f'{result=}')
|
debug(f'{result=}')
|
||||||
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
response = {
|
||||||
|
"status": "success",
|
||||||
|
"files_by_knowledge_base": result,
|
||||||
|
"collection_name": collection_name
|
||||||
|
}
|
||||||
|
return web.json_response(response, dumps=lambda obj: json.dumps(obj, ensure_ascii=False))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f'查询用户文件列表失败: {str(e)}')
|
error(f'列出用户文件失败: {str(e)}')
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": str(e)
|
"message": str(e),
|
||||||
|
"collection_name": collection_name
|
||||||
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
|
||||||
|
|
||||||
async def handle_connection(request, params_kw, *params, **kw):
|
async def handle_connection(request, params_kw, *params, **kw):
|
||||||
|
@ -45,6 +45,7 @@ def init():
|
|||||||
rf = RegisterFunction()
|
rf = RegisterFunction()
|
||||||
rf.register('entities', entities)
|
rf.register('entities', entities)
|
||||||
rf.register('docs', docs)
|
rf.register('docs', docs)
|
||||||
|
debug("注册路由: entities, docs")
|
||||||
|
|
||||||
async def docs(request, params_kw, *params, **kw):
|
async def docs(request, params_kw, *params, **kw):
|
||||||
return helptext
|
return helptext
|
||||||
@ -53,12 +54,11 @@ async def entities(request, params_kw, *params, **kw):
|
|||||||
debug(f'{params_kw.query=}')
|
debug(f'{params_kw.query=}')
|
||||||
se = ServerEnv()
|
se = ServerEnv()
|
||||||
engine = se.engine
|
engine = se.engine
|
||||||
f = awaitify(engine.extract_entities)
|
|
||||||
query = params_kw.query
|
query = params_kw.query
|
||||||
if query is None:
|
if query is None:
|
||||||
e = exception(f'query is None')
|
e = exception(f'query is None')
|
||||||
raise e
|
raise e
|
||||||
entities = await f(query)
|
entities = await engine.extract_entities(query)
|
||||||
debug(f'{entities=}, type(entities)')
|
debug(f'{entities=}, type(entities)')
|
||||||
return {
|
return {
|
||||||
"object": "list",
|
"object": "list",
|
||||||
|
125
llmengine/kgc.py
125
llmengine/kgc.py
@ -1,19 +1,23 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from py2neo import Graph, Node, Relationship
|
from py2neo import Graph, Node, Relationship
|
||||||
from typing import Set, List, Dict, Tuple
|
from typing import Set, List, Dict, Tuple
|
||||||
|
from appPublic.jsonConfig import getConfig
|
||||||
|
from appPublic.log import debug, error, info, exception
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class KnowledgeGraph:
|
class KnowledgeGraph:
|
||||||
def __init__(self, triples: List[Dict], document_id: str):
|
def __init__(self, triples: List[Dict], document_id: str, knowledge_base_id: str, userid: str):
|
||||||
self.triples = triples
|
self.triples = triples
|
||||||
self.document_id = document_id
|
self.document_id = document_id
|
||||||
self.g = Graph("bolt://10.18.34.18:7687", auth=('neo4j', '261229..wmh'))
|
self.knowledge_base_id = knowledge_base_id
|
||||||
logger.info(f"开始构建知识图谱,document_id: {self.document_id}, 三元组数量: {len(triples)}")
|
self.userid = userid
|
||||||
|
config = getConfig()
|
||||||
|
self.neo4j_uri = config['neo4j']['uri']
|
||||||
|
self.neo4j_user = config['neo4j']['user']
|
||||||
|
self.neo4j_password = config['neo4j']['password']
|
||||||
|
self.g = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
|
||||||
|
info(f"开始构建知识图谱,document_id: {self.document_id}, knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid}, 三元组数量: {len(triples)}")
|
||||||
|
|
||||||
def _normalize_label(self, entity_type: str) -> str:
|
def _normalize_label(self, entity_type: str) -> str:
|
||||||
"""规范化实体类型为 Neo4j 标签"""
|
"""规范化实体类型为 Neo4j 标签"""
|
||||||
@ -25,28 +29,29 @@ class KnowledgeGraph:
|
|||||||
return label or 'Entity'
|
return label or 'Entity'
|
||||||
|
|
||||||
def _clean_relation(self, relation: str) -> Tuple[str, str]:
|
def _clean_relation(self, relation: str) -> Tuple[str, str]:
|
||||||
"""清洗关系,返回 (rel_type, rel_name)"""
|
"""清洗关系,返回 (rel_type, rel_name),确保 rel_type 合法"""
|
||||||
relation = relation.strip()
|
relation = relation.strip()
|
||||||
if not relation:
|
if not relation:
|
||||||
return 'RELATED_TO', '相关'
|
return 'RELATED_TO', '相关'
|
||||||
if relation.startswith('<') and relation.endswith('>'):
|
|
||||||
cleaned_relation = relation[1:-1]
|
cleaned_relation = re.sub(r'[^\w\s]', '', relation).strip()
|
||||||
rel_name = cleaned_relation
|
if not cleaned_relation:
|
||||||
rel_type = re.sub(r'[^\w\s]', '', cleaned_relation).replace(' ', '_').upper()
|
return 'RELATED_TO', '相关'
|
||||||
else:
|
|
||||||
rel_name = relation
|
if 'instance of' in relation.lower():
|
||||||
rel_type = re.sub(r'[^\w\s]', '', relation).replace(' ', '_').upper()
|
return 'INSTANCE_OF', '实例'
|
||||||
if 'instance of' in relation.lower():
|
elif 'subclass of' in relation.lower():
|
||||||
rel_type = 'INSTANCE_OF'
|
return 'SUBCLASS_OF', '子类'
|
||||||
rel_name = '实例'
|
elif 'part of' in relation.lower():
|
||||||
elif 'subclass of' in relation.lower():
|
return 'PART_OF', '部分'
|
||||||
rel_type = 'SUBCLASS_OF'
|
|
||||||
rel_name = '子类'
|
rel_type = re.sub(r'\s+', '_', cleaned_relation).upper()
|
||||||
elif 'part of' in relation.lower():
|
if rel_type and rel_type[0].isdigit():
|
||||||
rel_type = 'PART_OF'
|
rel_type = f'REL_{rel_type}'
|
||||||
rel_name = '部分'
|
if not re.match(r'^[A-Za-z][A-Za-z0-9_]*$', rel_type):
|
||||||
logger.debug(f"处理关系: {relation} -> {rel_type} ({rel_name})")
|
debug(f"非法关系类型 '{rel_type}',替换为 'RELATED_TO'")
|
||||||
return rel_type, rel_name
|
return 'RELATED_TO', relation
|
||||||
|
return rel_type, relation
|
||||||
|
|
||||||
def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]:
|
def read_nodes(self) -> Tuple[Dict[str, Set], Dict[str, List], List[Dict]]:
|
||||||
"""从三元组列表中读取节点和关系"""
|
"""从三元组列表中读取节点和关系"""
|
||||||
@ -57,14 +62,14 @@ class KnowledgeGraph:
|
|||||||
try:
|
try:
|
||||||
for triple in self.triples:
|
for triple in self.triples:
|
||||||
if not all(key in triple for key in ['head', 'head_type', 'type', 'tail', 'tail_type']):
|
if not all(key in triple for key in ['head', 'head_type', 'type', 'tail', 'tail_type']):
|
||||||
logger.warning(f"无效三元组: {triple}")
|
debug(f"无效三元组: {triple}")
|
||||||
continue
|
continue
|
||||||
head, relation, tail, head_type, tail_type = (
|
head, relation, tail, head_type, tail_type = (
|
||||||
triple['head'], triple['type'], triple['tail'], triple['head_type'], triple['tail_type']
|
triple['head'], triple['type'], triple['tail'], triple['head_type'], triple['tail_type']
|
||||||
)
|
)
|
||||||
head_label = self._normalize_label(head_type)
|
head_label = self._normalize_label(head_type)
|
||||||
tail_label = self._normalize_label(tail_type)
|
tail_label = self._normalize_label(tail_type)
|
||||||
logger.debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")
|
debug(f"实体类型: {head_type} -> {head_label}, {tail_type} -> {tail_label}")
|
||||||
|
|
||||||
if head_label not in nodes_by_label:
|
if head_label not in nodes_by_label:
|
||||||
nodes_by_label[head_label] = set()
|
nodes_by_label[head_label] = set()
|
||||||
@ -92,33 +97,44 @@ class KnowledgeGraph:
|
|||||||
'tail_type': tail_type
|
'tail_type': tail_type
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个")
|
info(f"读取节点: {sum(len(nodes) for nodes in nodes_by_label.values())} 个")
|
||||||
logger.info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条")
|
info(f"读取关系: {sum(len(rels) for rels in relations_by_type.values())} 条")
|
||||||
return nodes_by_label, relations_by_type, triples
|
return nodes_by_label, relations_by_type, triples
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取三元组失败: {str(e)}")
|
error(f"读取三元组失败: {str(e)}")
|
||||||
raise RuntimeError(f"读取三元组失败: {str(e)}")
|
raise RuntimeError(f"读取三元组失败: {str(e)}")
|
||||||
|
|
||||||
def create_node(self, label: str, nodes: Set[str]):
|
def create_node(self, label: str, nodes: Set[str]):
|
||||||
"""创建节点,包含 document_id 属性"""
|
"""创建节点,包含 document_id、knowledge_base_id 和 userid 属性"""
|
||||||
count = 0
|
count = 0
|
||||||
for node_name in nodes:
|
for node_name in nodes:
|
||||||
query = f"MATCH (n:{label} {{name: '{node_name}', document_id: '{self.document_id}'}}) RETURN n"
|
query = (
|
||||||
|
f"MATCH (n:{label} {{name: $name, document_id: $doc_id, "
|
||||||
|
f"knowledge_base_id: $kb_id, userid: $userid}}) RETURN n"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if self.g.run(query).data():
|
if self.g.run(query, name=node_name, doc_id=self.document_id,
|
||||||
|
kb_id=self.knowledge_base_id, userid=self.userid).data():
|
||||||
continue
|
continue
|
||||||
node = Node(label, name=node_name, document_id=self.document_id)
|
node = Node(
|
||||||
|
label,
|
||||||
|
name=node_name,
|
||||||
|
document_id=self.document_id,
|
||||||
|
knowledge_base_id=self.knowledge_base_id,
|
||||||
|
userid=self.userid
|
||||||
|
)
|
||||||
self.g.create(node)
|
self.g.create(node)
|
||||||
count += 1
|
count += 1
|
||||||
logger.debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id})")
|
debug(f"创建节点: {label} - {node_name} (document_id: {self.document_id}, "
|
||||||
|
f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
|
error(f"创建节点失败: {label} - {node_name}, 错误: {str(e)}")
|
||||||
logger.info(f"创建 {label} 节点: {count}/{len(nodes)} 个")
|
info(f"创建 {label} 节点: {count}/{len(nodes)} 个")
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def create_relationship(self, rel_type: str, relations: List[Dict]):
|
def create_relationship(self, rel_type: str, relations: List[Dict]):
|
||||||
"""创建关系"""
|
"""创建关系,包含 document_id、knowledge_base_id 和 userid 属性"""
|
||||||
count = 0
|
count = 0
|
||||||
total = len(relations)
|
total = len(relations)
|
||||||
seen_edges = set()
|
seen_edges = set()
|
||||||
@ -132,17 +148,23 @@ class KnowledgeGraph:
|
|||||||
seen_edges.add(edge_key)
|
seen_edges.add(edge_key)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (p:{head_label} {{name: '{head}', document_id: '{self.document_id}'}}), "
|
f"MATCH (p:{head_label} {{name: $head, document_id: $doc_id, "
|
||||||
f"(q:{tail_label} {{name: '{tail}', document_id: '{self.document_id}'}}) "
|
f"knowledge_base_id: $kb_id, userid: $userid}}), "
|
||||||
f"CREATE (p)-[r:{rel_type} {{name: '{rel_name}'}}]->(q)"
|
f"(q:{tail_label} {{name: $tail, document_id: $doc_id, "
|
||||||
|
f"knowledge_base_id: $kb_id, userid: $userid}}) "
|
||||||
|
f"CREATE (p)-[r:{rel_type} {{name: $rel_name, document_id: $doc_id, "
|
||||||
|
f"knowledge_base_id: $kb_id, userid: $userid}}]->(q)"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self.g.run(query)
|
self.g.run(query, head=head, tail=tail, rel_name=rel_name,
|
||||||
|
doc_id=self.document_id, kb_id=self.knowledge_base_id,
|
||||||
|
userid=self.userid)
|
||||||
count += 1
|
count += 1
|
||||||
logger.debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id})")
|
debug(f"创建关系: {head} -[{rel_type}]-> {tail} (document_id: {self.document_id}, "
|
||||||
|
f"knowledge_base_id: {self.knowledge_base_id}, userid: {self.userid})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建关系失败: {query}, 错误: {str(e)}")
|
error(f"创建关系失败: {query}, 错误: {str(e)}")
|
||||||
logger.info(f"创建 {rel_type} 关系: {count}/{total} 条")
|
info(f"创建 {rel_type} 关系: {count}/{total} 条")
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def create_graphnodes(self):
|
def create_graphnodes(self):
|
||||||
@ -151,7 +173,7 @@ class KnowledgeGraph:
|
|||||||
total = 0
|
total = 0
|
||||||
for label, nodes in nodes_by_label.items():
|
for label, nodes in nodes_by_label.items():
|
||||||
total += self.create_node(label, nodes)
|
total += self.create_node(label, nodes)
|
||||||
logger.info(f"总计创建节点: {total} 个")
|
info(f"总计创建节点: {total} 个")
|
||||||
return total
|
return total
|
||||||
|
|
||||||
def create_graphrels(self):
|
def create_graphrels(self):
|
||||||
@ -160,15 +182,16 @@ class KnowledgeGraph:
|
|||||||
total = 0
|
total = 0
|
||||||
for rel_type, relations in relations_by_type.items():
|
for rel_type, relations in relations_by_type.items():
|
||||||
total += self.create_relationship(rel_type, relations)
|
total += self.create_relationship(rel_type, relations)
|
||||||
logger.info(f"总计创建关系: {total} 条")
|
info(f"总计创建关系: {total} 条")
|
||||||
return total
|
return total
|
||||||
|
|
||||||
def export_data(self):
|
def export_data(self):
|
||||||
"""导出节点到文件,包含 document_id"""
|
"""导出节点到文件,包含 document_id、knowledge_base_id 和 userid"""
|
||||||
nodes_by_label, _, _ = self.read_nodes()
|
nodes_by_label, _, _ = self.read_nodes()
|
||||||
os.makedirs('dict', exist_ok=True)
|
os.makedirs('dict', exist_ok=True)
|
||||||
for label, nodes in nodes_by_label.items():
|
for label, nodes in nodes_by_label.items():
|
||||||
with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f:
|
with open(f'dict/{label.lower()}.txt', 'w', encoding='utf-8') as f:
|
||||||
f.write('\n'.join(f"{name}\t{self.document_id}" for name in sorted(nodes)))
|
f.write('\n'.join(f"{name}\t{self.document_id}\t{self.knowledge_base_id}\t{self.userid}"
|
||||||
logger.info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个")
|
for name in sorted(nodes)))
|
||||||
|
info(f"导出 {label} 节点到 dict/{label.lower()}.txt: {len(nodes)} 个")
|
||||||
return
|
return
|
@ -1,11 +1,9 @@
|
|||||||
# Requires ltp>=0.2.0
|
|
||||||
|
|
||||||
from ltp import LTP
|
from ltp import LTP
|
||||||
from typing import List
|
from typing import List
|
||||||
import logging
|
from appPublic.log import debug, info, error
|
||||||
|
from appPublic.worker import awaitify
|
||||||
from llmengine.base_entity import BaseLtp, ltp_register
|
from llmengine.base_entity import BaseLtp, ltp_register
|
||||||
|
import asyncio
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class LtpEntity(BaseLtp):
|
class LtpEntity(BaseLtp):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
@ -14,7 +12,7 @@ class LtpEntity(BaseLtp):
|
|||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.model_name = model_id.split('/')[-1]
|
self.model_name = model_id.split('/')[-1]
|
||||||
|
|
||||||
def extract_entities(self, query: str) -> List[str]:
|
async def extract_entities(self, query: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
从查询文本中抽取实体,包括:
|
从查询文本中抽取实体,包括:
|
||||||
- LTP NER 识别的实体(所有类型)。
|
- LTP NER 识别的实体(所有类型)。
|
||||||
@ -26,7 +24,18 @@ class LtpEntity(BaseLtp):
|
|||||||
if not query:
|
if not query:
|
||||||
raise ValueError("查询文本不能为空")
|
raise ValueError("查询文本不能为空")
|
||||||
|
|
||||||
result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"])
|
# 定义同步 pipeline 函数,正确传递 tasks 参数
|
||||||
|
def sync_pipeline(query, tasks):
|
||||||
|
return self.ltp.pipeline([query], tasks=tasks)
|
||||||
|
|
||||||
|
# 使用 run_in_executor 运行同步 pipeline
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: sync_pipeline(query, ["cws", "pos", "ner"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析结果
|
||||||
words = result.cws[0]
|
words = result.cws[0]
|
||||||
pos_list = result.pos[0]
|
pos_list = result.pos[0]
|
||||||
ner = result.ner[0]
|
ner = result.ner[0]
|
||||||
@ -34,7 +43,7 @@ class LtpEntity(BaseLtp):
|
|||||||
entities = []
|
entities = []
|
||||||
subword_set = set()
|
subword_set = set()
|
||||||
|
|
||||||
logger.debug(f"NER 结果: {ner}")
|
debug(f"NER 结果: {ner}")
|
||||||
for entity_type, entity, start, end in ner:
|
for entity_type, entity, start, end in ner:
|
||||||
entities.append(entity)
|
entities.append(entity)
|
||||||
|
|
||||||
@ -49,13 +58,13 @@ class LtpEntity(BaseLtp):
|
|||||||
if combined:
|
if combined:
|
||||||
entities.append(combined)
|
entities.append(combined)
|
||||||
subword_set.update(combined_words)
|
subword_set.update(combined_words)
|
||||||
logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
|
debug(f"合并连续名词: {combined}, 子词: {combined_words}")
|
||||||
combined = ""
|
combined = ""
|
||||||
combined_words = []
|
combined_words = []
|
||||||
else:
|
else:
|
||||||
combined = ""
|
combined = ""
|
||||||
combined_words = []
|
combined_words = []
|
||||||
logger.debug(f"连续名词子词集合: {subword_set}")
|
debug(f"连续名词子词集合: {subword_set}")
|
||||||
|
|
||||||
for word, pos in zip(words, pos_list):
|
for word, pos in zip(words, pos_list):
|
||||||
if pos == 'n' and word not in subword_set:
|
if pos == 'n' and word not in subword_set:
|
||||||
@ -66,11 +75,11 @@ class LtpEntity(BaseLtp):
|
|||||||
entities.append(word)
|
entities.append(word)
|
||||||
|
|
||||||
unique_entities = list(dict.fromkeys(entities))
|
unique_entities = list(dict.fromkeys(entities))
|
||||||
logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
|
info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
|
||||||
return unique_entities
|
return unique_entities
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"实体抽取失败: {str(e)}")
|
error(f"实体抽取失败: {str(e)}")
|
||||||
return []
|
raise # 抛出异常以便调试,而不是返回空列表
|
||||||
|
|
||||||
ltp_register('LTP', LtpEntity)
|
ltp_register('LTP', LtpEntity)
|
@ -16,6 +16,7 @@ from llmengine.kgc import KnowledgeGraph
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from py2neo import Graph
|
from py2neo import Graph
|
||||||
from scipy.spatial.distance import cosine
|
from scipy.spatial.distance import cosine
|
||||||
|
import time
|
||||||
|
|
||||||
# 嵌入缓存
|
# 嵌入缓存
|
||||||
EMBED_CACHE = {}
|
EMBED_CACHE = {}
|
||||||
@ -182,7 +183,7 @@ class MilvusConnection:
|
|||||||
if not userid:
|
if not userid:
|
||||||
return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name,
|
return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name,
|
||||||
"document_id": "", "status_code": 400}
|
"document_id": "", "status_code": 400}
|
||||||
return await self.list_user_files(userid)
|
return await self._list_user_files(userid)
|
||||||
else:
|
else:
|
||||||
return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name,
|
return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name,
|
||||||
"document_id": "", "status_code": 400}
|
"document_id": "", "status_code": 400}
|
||||||
@ -814,20 +815,24 @@ class MilvusConnection:
|
|||||||
error(f"实体识别服务调用失败: {str(e)}\n{exception()}")
|
error(f"实体识别服务调用失败: {str(e)}\n{exception()}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]:
|
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, knowledge_base_id: str) -> List[Dict]:
|
||||||
"""匹配查询实体与 Neo4j 中的三元组"""
|
"""匹配查询实体与 Neo4j 中的三元组"""
|
||||||
|
start_time = time.time() # 记录开始时间
|
||||||
matched_triplets = []
|
matched_triplets = []
|
||||||
ENTITY_SIMILARITY_THRESHOLD = 0.8
|
ENTITY_SIMILARITY_THRESHOLD = 0.8
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
|
graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
|
||||||
debug(f"已连接到 Neo4j: {self.neo4j_uri}")
|
debug(f"已连接到 Neo4j: {self.neo4j_uri}")
|
||||||
|
neo4j_connect_time = time.time() - start_time
|
||||||
|
debug(f"Neo4j 连接耗时: {neo4j_connect_time:.3f} 秒")
|
||||||
|
|
||||||
matched_names = set()
|
matched_names = set()
|
||||||
|
entity_match_start = time.time()
|
||||||
for entity in query_entities:
|
for entity in query_entities:
|
||||||
normalized_entity = entity.lower().strip()
|
normalized_entity = entity.lower().strip()
|
||||||
query = """
|
query = """
|
||||||
MATCH (n {document_id: $document_id})
|
MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
|
||||||
WHERE toLower(n.name) CONTAINS $entity
|
WHERE toLower(n.name) CONTAINS $entity
|
||||||
OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7
|
OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7
|
||||||
RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim
|
RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim
|
||||||
@ -835,24 +840,27 @@ class MilvusConnection:
|
|||||||
LIMIT 100
|
LIMIT 100
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = graph.run(query, document_id=document_id, entity=normalized_entity).data()
|
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, entity=normalized_entity).data()
|
||||||
for record in results:
|
for record in results:
|
||||||
matched_names.add(record['n.name'])
|
matched_names.add(record['n.name'])
|
||||||
debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})")
|
debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}")
|
debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}")
|
||||||
continue
|
continue
|
||||||
|
entity_match_time = time.time() - entity_match_start
|
||||||
|
debug(f"实体匹配耗时: {entity_match_time:.3f} 秒")
|
||||||
|
|
||||||
triplets = []
|
triplets = []
|
||||||
if matched_names:
|
if matched_names:
|
||||||
|
triplet_query_start = time.time()
|
||||||
query = """
|
query = """
|
||||||
MATCH (h {document_id: $document_id})-[r]->(t {document_id: $document_id})
|
MATCH (h {userid: $userid, knowledge_base_id: $knowledge_base_id})-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->(t {userid: $userid, knowledge_base_id: $knowledge_base_id})
|
||||||
WHERE h.name IN $matched_names OR t.name IN $matched_names
|
WHERE h.name IN $matched_names OR t.name IN $matched_names
|
||||||
RETURN h.name AS head, r.name AS type, t.name AS tail
|
RETURN h.name AS head, r.name AS type, t.name AS tail
|
||||||
LIMIT 100
|
LIMIT 100
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = graph.run(query, document_id=document_id, matched_names=list(matched_names)).data()
|
results = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id, matched_names=list(matched_names)).data()
|
||||||
seen = set()
|
seen = set()
|
||||||
for record in results:
|
for record in results:
|
||||||
head, type_, tail = record['head'], record['type'], record['tail']
|
head, type_, tail = record['head'], record['type'], record['tail']
|
||||||
@ -866,22 +874,28 @@ class MilvusConnection:
|
|||||||
'head_type': '',
|
'head_type': '',
|
||||||
'tail_type': ''
|
'tail_type': ''
|
||||||
})
|
})
|
||||||
debug(f"从 Neo4j 加载三元组: document_id={document_id}, 数量={len(triplets)}")
|
debug(f"从 Neo4j 加载三元组: knowledge_base_id={knowledge_base_id}, 数量={len(triplets)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"检索三元组失败: document_id={document_id}, 错误: {str(e)}\n{exception()}")
|
error(f"检索三元组失败: knowledge_base_id={knowledge_base_id}, 错误: {str(e)}\n{exception()}")
|
||||||
return []
|
return []
|
||||||
|
triplet_query_time = time.time() - triplet_query_start
|
||||||
|
debug(f"Neo4j 三元组查询耗时: {triplet_query_time:.3f} 秒")
|
||||||
|
|
||||||
if not triplets:
|
if not triplets:
|
||||||
debug(f"文档 document_id={document_id} 无匹配三元组")
|
debug(f"知识库 knowledge_base_id={knowledge_base_id} 无匹配三元组")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
embedding_start = time.time()
|
||||||
texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets]
|
texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets]
|
||||||
embeddings = await self._get_embeddings(texts_to_embed)
|
embeddings = await self._get_embeddings(texts_to_embed)
|
||||||
entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)}
|
entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)}
|
||||||
head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)}
|
head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)}
|
||||||
tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)}
|
tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)}
|
||||||
debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)")
|
debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)")
|
||||||
|
embedding_time = time.time() - embedding_start
|
||||||
|
debug(f"嵌入向量生成耗时: {embedding_time:.3f} 秒")
|
||||||
|
|
||||||
|
similarity_start = time.time()
|
||||||
for entity in query_entities:
|
for entity in query_entities:
|
||||||
entity_vec = entity_vectors[entity]
|
entity_vec = entity_vectors[entity]
|
||||||
for d_triplet in triplets:
|
for d_triplet in triplets:
|
||||||
@ -894,6 +908,8 @@ class MilvusConnection:
|
|||||||
matched_triplets.append(d_triplet)
|
matched_triplets.append(d_triplet)
|
||||||
debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
|
debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
|
||||||
f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
|
f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
|
||||||
|
similarity_time = time.time() - similarity_start
|
||||||
|
debug(f"相似度计算耗时: {similarity_time:.3f} 秒")
|
||||||
|
|
||||||
unique_matched = []
|
unique_matched = []
|
||||||
seen = set()
|
seen = set()
|
||||||
@ -903,6 +919,8 @@ class MilvusConnection:
|
|||||||
seen.add(identifier)
|
seen.add(identifier)
|
||||||
unique_matched.append(t)
|
unique_matched.append(t)
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
debug(f"_match_triplets 总耗时: {total_time:.3f} 秒")
|
||||||
info(f"找到 {len(unique_matched)} 个匹配的三元组")
|
info(f"找到 {len(unique_matched)} 个匹配的三元组")
|
||||||
return unique_matched
|
return unique_matched
|
||||||
|
|
||||||
@ -957,9 +975,11 @@ class MilvusConnection:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
|
async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
|
||||||
offset: int = 0, use_rerank: bool = True) -> List[Dict]:
|
offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
|
||||||
"""融合搜索,将查询与所有三元组拼接后向量化搜索"""
|
"""融合搜索,将查询与所有三元组拼接后向量化搜索"""
|
||||||
|
start_time = time.time() # 记录开始时间
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
|
timing_stats = {} # 记录各步骤耗时
|
||||||
try:
|
try:
|
||||||
info(
|
info(
|
||||||
f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
|
f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
|
||||||
@ -975,46 +995,38 @@ class MilvusConnection:
|
|||||||
|
|
||||||
if not utility.has_collection(collection_name):
|
if not utility.has_collection(collection_name):
|
||||||
debug(f"集合 {collection_name} 不存在")
|
debug(f"集合 {collection_name} 不存在")
|
||||||
return []
|
return {"results": [], "timing": timing_stats}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
collection = Collection(collection_name)
|
collection = Collection(collection_name)
|
||||||
collection.load()
|
collection.load()
|
||||||
debug(f"加载集合: {collection_name}")
|
debug(f"加载集合: {collection_name}")
|
||||||
|
timing_stats["collection_load"] = time.time() - start_time
|
||||||
|
debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
||||||
return []
|
return {"results": [], "timing": timing_stats}
|
||||||
|
|
||||||
|
entity_extract_start = time.time()
|
||||||
query_entities = await self._extract_entities(query)
|
query_entities = await self._extract_entities(query)
|
||||||
debug(f"提取实体: {query_entities}")
|
timing_stats["entity_extraction"] = time.time() - entity_extract_start
|
||||||
|
debug(f"提取实体: {query_entities}, 耗时: {timing_stats['entity_extraction']:.3f} 秒")
|
||||||
|
|
||||||
documents = []
|
|
||||||
all_triplets = []
|
all_triplets = []
|
||||||
|
triplet_match_start = time.time()
|
||||||
for kb_id in knowledge_base_ids:
|
for kb_id in knowledge_base_ids:
|
||||||
debug(f"处理知识库: {kb_id}")
|
debug(f"处理知识库: {kb_id}")
|
||||||
|
matched_triplets = await self._match_triplets(query, query_entities, userid, kb_id)
|
||||||
|
debug(f"知识库 {kb_id} 匹配三元组: {len(matched_triplets)} 条")
|
||||||
|
all_triplets.extend(matched_triplets)
|
||||||
|
timing_stats["triplet_matching"] = time.time() - triplet_match_start
|
||||||
|
debug(f"三元组匹配总耗时: {timing_stats['triplet_matching']:.3f} 秒")
|
||||||
|
|
||||||
results = collection.query(
|
if not all_triplets:
|
||||||
expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'",
|
debug("未找到任何匹配的三元组")
|
||||||
output_fields=["document_id", "filename", "knowledge_base_id"],
|
return {"results": [], "timing": timing_stats}
|
||||||
limit=100 # 查询足够多的文档以支持后续过滤
|
|
||||||
)
|
|
||||||
if not results:
|
|
||||||
debug(f"未找到 userid {userid} 和 knowledge_base_id {kb_id} 对应的文档")
|
|
||||||
continue
|
|
||||||
documents.extend(results)
|
|
||||||
|
|
||||||
for doc in results:
|
|
||||||
document_id = doc["document_id"]
|
|
||||||
matched_triplets = await self._match_triplets(query, query_entities, userid, document_id)
|
|
||||||
debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)} 条")
|
|
||||||
all_triplets.extend(matched_triplets)
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
debug("未找到任何有效文档")
|
|
||||||
return []
|
|
||||||
|
|
||||||
info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}")
|
|
||||||
|
|
||||||
|
triplet_text_start = time.time()
|
||||||
triplet_texts = []
|
triplet_texts = []
|
||||||
for triplet in all_triplets:
|
for triplet in all_triplets:
|
||||||
head = triplet.get('head', '')
|
head = triplet.get('head', '')
|
||||||
@ -1029,13 +1041,18 @@ class MilvusConnection:
|
|||||||
combined_text += " [三元组] " + "; ".join(triplet_texts)
|
combined_text += " [三元组] " + "; ".join(triplet_texts)
|
||||||
debug(
|
debug(
|
||||||
f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
|
||||||
|
timing_stats["triplet_text_combine"] = time.time() - triplet_text_start
|
||||||
|
debug(f"拼接三元组文本耗时: {timing_stats['triplet_text_combine']:.3f} 秒")
|
||||||
|
|
||||||
|
embedding_start = time.time()
|
||||||
embeddings = await self._get_embeddings([combined_text])
|
embeddings = await self._get_embeddings([combined_text])
|
||||||
query_vector = embeddings[0]
|
query_vector = embeddings[0]
|
||||||
debug(f"拼接文本向量维度: {len(query_vector)}")
|
debug(f"拼接文本向量维度: {len(query_vector)}")
|
||||||
|
timing_stats["embedding_generation"] = time.time() - embedding_start
|
||||||
|
debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒")
|
||||||
|
|
||||||
|
search_start = time.time()
|
||||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||||
|
|
||||||
kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
|
kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
|
||||||
expr = f"userid == '{userid}' and ({kb_expr})"
|
expr = f"userid == '{userid}' and ({kb_expr})"
|
||||||
debug(f"搜索表达式: {expr}")
|
debug(f"搜索表达式: {expr}")
|
||||||
@ -1053,7 +1070,9 @@ class MilvusConnection:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"向量搜索失败: {str(e)}\n{exception()}")
|
error(f"向量搜索失败: {str(e)}\n{exception()}")
|
||||||
return []
|
return {"results": [], "timing": timing_stats}
|
||||||
|
timing_stats["vector_search"] = time.time() - search_start
|
||||||
|
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
|
||||||
|
|
||||||
search_results = []
|
search_results = []
|
||||||
for hits in results:
|
for hits in results:
|
||||||
@ -1078,31 +1097,40 @@ class MilvusConnection:
|
|||||||
|
|
||||||
unique_results = []
|
unique_results = []
|
||||||
seen_texts = set()
|
seen_texts = set()
|
||||||
|
dedup_start = time.time()
|
||||||
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
|
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
|
||||||
if result['text'] not in seen_texts:
|
if result['text'] not in seen_texts:
|
||||||
unique_results.append(result)
|
unique_results.append(result)
|
||||||
seen_texts.add(result['text'])
|
seen_texts.add(result['text'])
|
||||||
|
timing_stats["deduplication"] = time.time() - dedup_start
|
||||||
|
debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒")
|
||||||
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
|
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
|
||||||
|
|
||||||
if use_rerank and unique_results:
|
if use_rerank and unique_results:
|
||||||
|
rerank_start = time.time()
|
||||||
debug("开始重排序")
|
debug("开始重排序")
|
||||||
unique_results = await self._rerank_results(combined_text, unique_results, limit) # 使用传入的 limit
|
unique_results = await self._rerank_results(combined_text, unique_results, limit)
|
||||||
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||||
|
timing_stats["reranking"] = time.time() - rerank_start
|
||||||
|
debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
|
||||||
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
||||||
else:
|
else:
|
||||||
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
||||||
|
|
||||||
info(f"融合搜索完成,返回 {len(unique_results)} 条结果")
|
timing_stats["total_time"] = time.time() - start_time
|
||||||
return unique_results[:limit]
|
info(f"融合搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||||
|
return {"results": unique_results[:limit], "timing": timing_stats}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"融合搜索失败: {str(e)}\n{exception()}")
|
error(f"融合搜索失败: {str(e)}\n{exception()}")
|
||||||
return []
|
return {"results": [], "timing": timing_stats}
|
||||||
|
|
||||||
async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5,
|
async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5,
|
||||||
offset: int = 0, use_rerank: bool = True) -> List[Dict]:
|
offset: int = 0, use_rerank: bool = True) -> Dict[str, Any]:
|
||||||
"""纯向量搜索,基于查询文本在指定知识库中搜索相关文本块"""
|
"""纯向量搜索,基于查询文本在指定知识库中搜索相关文本块"""
|
||||||
|
start_time = time.time() # 记录开始时间
|
||||||
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
|
timing_stats = {} # 记录各步骤耗时
|
||||||
try:
|
try:
|
||||||
info(
|
info(
|
||||||
f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
|
f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
|
||||||
@ -1133,22 +1161,27 @@ class MilvusConnection:
|
|||||||
|
|
||||||
if not utility.has_collection(collection_name):
|
if not utility.has_collection(collection_name):
|
||||||
debug(f"集合 {collection_name} 不存在")
|
debug(f"集合 {collection_name} 不存在")
|
||||||
return []
|
return {"results": [], "timing": timing_stats}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
collection = Collection(collection_name)
|
collection = Collection(collection_name)
|
||||||
collection.load()
|
collection.load()
|
||||||
debug(f"加载集合: {collection_name}")
|
debug(f"加载集合: {collection_name}")
|
||||||
|
timing_stats["collection_load"] = time.time() - start_time
|
||||||
|
debug(f"集合加载耗时: {timing_stats['collection_load']:.3f} 秒")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
||||||
raise RuntimeError(f"加载集合失败: {str(e)}")
|
return {"results": [], "timing": timing_stats}
|
||||||
|
|
||||||
|
embedding_start = time.time()
|
||||||
embeddings = await self._get_embeddings([query])
|
embeddings = await self._get_embeddings([query])
|
||||||
query_vector = embeddings[0]
|
query_vector = embeddings[0]
|
||||||
debug(f"查询向量维度: {len(query_vector)}")
|
debug(f"查询向量维度: {len(query_vector)}")
|
||||||
|
timing_stats["embedding_generation"] = time.time() - embedding_start
|
||||||
|
debug(f"嵌入向量生成耗时: {timing_stats['embedding_generation']:.3f} 秒")
|
||||||
|
|
||||||
|
search_start = time.time()
|
||||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
||||||
|
|
||||||
kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
|
kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
|
||||||
expr = f"userid == '{userid}' and ({kb_id_expr})"
|
expr = f"userid == '{userid}' and ({kb_id_expr})"
|
||||||
debug(f"搜索表达式: {expr}")
|
debug(f"搜索表达式: {expr}")
|
||||||
@ -1166,7 +1199,9 @@ class MilvusConnection:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"搜索失败: {str(e)}\n{exception()}")
|
error(f"搜索失败: {str(e)}\n{exception()}")
|
||||||
raise RuntimeError(f"搜索失败: {str(e)}")
|
return {"results": [], "timing": timing_stats}
|
||||||
|
timing_stats["vector_search"] = time.time() - search_start
|
||||||
|
debug(f"向量搜索耗时: {timing_stats['vector_search']:.3f} 秒")
|
||||||
|
|
||||||
search_results = []
|
search_results = []
|
||||||
for hits in results:
|
for hits in results:
|
||||||
@ -1189,96 +1224,100 @@ class MilvusConnection:
|
|||||||
debug(
|
debug(
|
||||||
f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}")
|
f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}")
|
||||||
|
|
||||||
|
dedup_start = time.time()
|
||||||
unique_results = []
|
unique_results = []
|
||||||
seen_texts = set()
|
seen_texts = set()
|
||||||
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
|
for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
|
||||||
if result['text'] not in seen_texts:
|
if result['text'] not in seen_texts:
|
||||||
unique_results.append(result)
|
unique_results.append(result)
|
||||||
seen_texts.add(result['text'])
|
seen_texts.add(result['text'])
|
||||||
|
timing_stats["deduplication"] = time.time() - dedup_start
|
||||||
|
debug(f"去重耗时: {timing_stats['deduplication']:.3f} 秒")
|
||||||
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
|
info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
|
||||||
|
|
||||||
if use_rerank and unique_results:
|
if use_rerank and unique_results:
|
||||||
|
rerank_start = time.time()
|
||||||
debug("开始重排序")
|
debug("开始重排序")
|
||||||
unique_results = await self._rerank_results(query, unique_results, limit)
|
unique_results = await self._rerank_results(query, unique_results, limit)
|
||||||
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
|
||||||
|
timing_stats["reranking"] = time.time() - rerank_start
|
||||||
|
debug(f"重排序耗时: {timing_stats['reranking']:.3f} 秒")
|
||||||
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
|
||||||
else:
|
else:
|
||||||
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
|
||||||
|
|
||||||
info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果")
|
timing_stats["total_time"] = time.time() - start_time
|
||||||
return unique_results[:limit]
|
info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果,总耗时: {timing_stats['total_time']:.3f} 秒")
|
||||||
|
return {"results": unique_results[:limit], "timing": timing_stats}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"纯向量搜索失败: {str(e)}\n{exception()}")
|
error(f"纯向量搜索失败: {str(e)}\n{exception()}")
|
||||||
return []
|
return {"results": [], "timing": timing_stats}
|
||||||
|
|
||||||
async def list_user_files(self, userid: str) -> List[Dict]:
|
async def _list_user_files(self, userid: str, db_type: str = "") -> Dict[str, List[Dict]]:
|
||||||
"""根据 userid 返回用户的所有文件列表,从所有 ragdb_ 开头的集合中查询"""
|
"""列出用户的所有知识库及其文件,按 knowledge_base_id 分组"""
|
||||||
|
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
|
||||||
try:
|
try:
|
||||||
info(f"开始查询用户文件列表: userid={userid}")
|
info(f"列出用户文件: userid={userid}, db_type={db_type}")
|
||||||
|
|
||||||
if not userid:
|
if not userid:
|
||||||
raise ValueError("userid 不能为空")
|
raise ValueError("userid 不能为空")
|
||||||
if "_" in userid:
|
if "_" in userid or (db_type and "_" in db_type):
|
||||||
raise ValueError("userid 不能包含下划线")
|
raise ValueError("userid 和 db_type 不能包含下划线")
|
||||||
if len(userid) > 100:
|
if (db_type and len(db_type) > 100) or len(userid) > 100:
|
||||||
raise ValueError("userid 长度超出限制")
|
raise ValueError("userid 或 db_type 的长度超出限制")
|
||||||
|
|
||||||
collections = utility.list_collections()
|
if not utility.has_collection(collection_name):
|
||||||
collections = [c for c in collections if c.startswith("ragdb")]
|
debug(f"集合 {collection_name} 不存在")
|
||||||
if not collections:
|
return {}
|
||||||
debug("未找到任何 ragdb 开头的集合")
|
|
||||||
return []
|
|
||||||
debug(f"找到集合: {collections}")
|
|
||||||
|
|
||||||
file_list = []
|
try:
|
||||||
seen_files = set()
|
collection = Collection(collection_name)
|
||||||
for collection_name in collections:
|
collection.load()
|
||||||
db_type = collection_name.replace("ragdb_", "") if collection_name != "ragdb" else ""
|
debug(f"加载集合: {collection_name}")
|
||||||
debug(f"处理集合: {collection_name}, db_type={db_type}")
|
except Exception as e:
|
||||||
|
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
||||||
|
return {}
|
||||||
|
|
||||||
try:
|
expr = f"userid == '{userid}'"
|
||||||
collection = Collection(collection_name)
|
debug(f"查询表达式: {expr}")
|
||||||
collection.load()
|
|
||||||
debug(f"加载集合: {collection_name}")
|
|
||||||
except Exception as e:
|
|
||||||
error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = collection.query(
|
results = collection.query(
|
||||||
expr=f"userid == '{userid}'",
|
expr=expr,
|
||||||
output_fields=["filename", "file_path", "upload_time", "file_type"],
|
output_fields=["document_id", "filename", "file_path", "upload_time", "file_type", "knowledge_base_id"],
|
||||||
limit=1000
|
limit=1000
|
||||||
)
|
)
|
||||||
debug(f"集合 {collection_name} 查询到 {len(results)} 个文本块")
|
except Exception as e:
|
||||||
except Exception as e:
|
error(f"查询用户文件失败: {str(e)}\n{exception()}")
|
||||||
error(f"查询集合 {collection_name} 失败: userid={userid}, 错误: {str(e)}\n{exception()}")
|
return {}
|
||||||
continue
|
|
||||||
|
|
||||||
for result in results:
|
files_by_kb = {}
|
||||||
filename = result.get("filename")
|
seen_document_ids = set()
|
||||||
file_path = result.get("file_path")
|
for result in results:
|
||||||
upload_time = result.get("upload_time")
|
document_id = result.get("document_id")
|
||||||
file_type = result.get("file_type")
|
kb_id = result.get("knowledge_base_id")
|
||||||
if (filename, file_path) not in seen_files:
|
if document_id not in seen_document_ids:
|
||||||
seen_files.add((filename, file_path))
|
seen_document_ids.add(document_id)
|
||||||
file_list.append({
|
file_info = {
|
||||||
"filename": filename,
|
"document_id": document_id,
|
||||||
"file_path": file_path,
|
"filename": result.get("filename"),
|
||||||
"db_type": db_type,
|
"file_path": result.get("file_path"),
|
||||||
"upload_time": upload_time,
|
"upload_time": result.get("upload_time"),
|
||||||
"file_type": file_type
|
"file_type": result.get("file_type"),
|
||||||
})
|
"knowledge_base_id": kb_id
|
||||||
debug(
|
}
|
||||||
f"文件: filename={filename}, file_path={file_path}, db_type={db_type}, upload_time={upload_time}, file_type={file_type}")
|
if kb_id not in files_by_kb:
|
||||||
|
files_by_kb[kb_id] = []
|
||||||
|
files_by_kb[kb_id].append(file_info)
|
||||||
|
debug(f"找到文件: document_id={document_id}, filename={result.get('filename')}, knowledge_base_id={kb_id}")
|
||||||
|
|
||||||
info(f"返回 {len(file_list)} 个文件")
|
info(f"找到 {len(seen_document_ids)} 个文件,userid={userid}, 知识库数量={len(files_by_kb)}")
|
||||||
return sorted(file_list, key=lambda x: x["upload_time"], reverse=True)
|
return files_by_kb
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error(f"查询用户文件列表失败: userid={userid}, 错误: {str(e)}\n{exception()}")
|
error(f"列出用户文件失败: {str(e)}\n{exception()}")
|
||||||
return []
|
return {}
|
||||||
|
|
||||||
connection_register('Milvus', MilvusConnection)
|
connection_register('Milvus', MilvusConnection)
|
||||||
info("MilvusConnection registered")
|
info("MilvusConnection registered")
|
@ -1,25 +1,25 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
import traceback
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
import logging
|
from appPublic.log import debug, error, warning, info
|
||||||
|
from appPublic.worker import awaitify
|
||||||
from base_triple import BaseTripleExtractor, llm_register
|
from base_triple import BaseTripleExtractor, llm_register
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class MRebelTripleExtractor(BaseTripleExtractor):
|
class MRebelTripleExtractor(BaseTripleExtractor):
|
||||||
def __init__(self, model_path: str):
|
def __init__(self, model_path: str):
|
||||||
super().__init__(model_path)
|
super().__init__(model_path)
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Loading tokenizer from {model_path}")
|
debug(f"Loading tokenizer from {model_path}")
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
logger.debug(f"Loading model from {model_path}")
|
debug(f"Loading model from {model_path}")
|
||||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
||||||
self.device = self.use_mps_if_possible()
|
self.device = self.use_mps_if_possible()
|
||||||
self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
|
self.triplet_id = self.tokenizer.convert_tokens_to_ids("<triplet>")
|
||||||
logger.debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
|
debug(f"Loaded mREBEL model, triplet_id: {self.triplet_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load mREBEL model: {str(e)}")
|
error(f"Failed to load mREBEL model: {str(e)}")
|
||||||
raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
|
raise RuntimeError(f"Failed to load mREBEL model: {str(e)}")
|
||||||
|
|
||||||
self.gen_kwargs = {
|
self.gen_kwargs = {
|
||||||
@ -47,13 +47,13 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
|||||||
current_chunk = sentence
|
current_chunk = sentence
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunks.append(current_chunk)
|
chunks.append(current_chunk)
|
||||||
logger.debug(f"Text split into: {len(chunks)} chunks")
|
debug(f"Text split into: {len(chunks)} chunks")
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def extract_triplets_typed(self, text: str) -> list:
|
def extract_triplets_typed(self, text: str) -> list:
|
||||||
"""Parse mREBEL generated text for triplets."""
|
"""Parse mREBEL generated text for triplets."""
|
||||||
triplets = []
|
triplets = []
|
||||||
logger.debug(f"Raw generated text: {text}")
|
debug(f"Raw generated text: {text}")
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
in_tag = False
|
in_tag = False
|
||||||
@ -77,7 +77,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
|||||||
|
|
||||||
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
|
special_tokens = ["<s>", "<pad>", "</s>", "tp_XX", "__en__", "__zh__", "zh_CN"]
|
||||||
tokens = [t for t in tokens if t not in special_tokens and t]
|
tokens = [t for t in tokens if t not in special_tokens and t]
|
||||||
logger.debug(f"Processed tokens: {tokens}")
|
debug(f"Processed tokens: {tokens}")
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
@ -96,7 +96,7 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
|||||||
'tail': entity2.strip(),
|
'tail': entity2.strip(),
|
||||||
'tail_type': type2
|
'tail_type': type2
|
||||||
})
|
})
|
||||||
logger.debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
|
debug(f"Added triplet: {entity1}({type1}) - {relation} - {entity2}({type2})")
|
||||||
i += 6
|
i += 6
|
||||||
else:
|
else:
|
||||||
i += 1
|
i += 1
|
||||||
@ -110,11 +110,11 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
|||||||
raise ValueError("Text cannot be empty")
|
raise ValueError("Text cannot be empty")
|
||||||
|
|
||||||
text_chunks = self.split_document(text, max_chunk_size=150)
|
text_chunks = self.split_document(text, max_chunk_size=150)
|
||||||
logger.debug(f"Text split into {len(text_chunks)} chunks")
|
debug(f"Text split into {len(text_chunks)} chunks")
|
||||||
|
|
||||||
all_triplets = []
|
all_triplets = []
|
||||||
for i, chunk in enumerate(text_chunks):
|
for i, chunk in enumerate(text_chunks):
|
||||||
logger.debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
|
debug(f"Processing chunk {i + 1}/{len(text_chunks)}: {chunk[:50]}...")
|
||||||
|
|
||||||
model_inputs = self.tokenizer(
|
model_inputs = self.tokenizer(
|
||||||
chunk,
|
chunk,
|
||||||
@ -128,17 +128,17 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
|||||||
generated_tokens = self.model.generate(
|
generated_tokens = self.model.generate(
|
||||||
model_inputs["input_ids"],
|
model_inputs["input_ids"],
|
||||||
attention_mask=model_inputs["attention_mask"],
|
attention_mask=model_inputs["attention_mask"],
|
||||||
**self.gen_kwargs,
|
**self.gen_kwargs
|
||||||
)
|
)
|
||||||
decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
|
decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
|
||||||
for idx, sentence in enumerate(decoded_preds):
|
for idx, sentence in enumerate(decoded_preds):
|
||||||
logger.debug(f"Chunk {i + 1} generated text: {sentence}")
|
debug(f"Chunk {i + 1} generated text: {sentence}")
|
||||||
triplets = self.extract_triplets_typed(sentence)
|
triplets = self.extract_triplets_typed(sentence)
|
||||||
if triplets:
|
if triplets:
|
||||||
logger.debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
|
debug(f"Chunk {i + 1} extracted {len(triplets)} triplets")
|
||||||
all_triplets.extend(triplets)
|
all_triplets.extend(triplets)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error processing chunk {i + 1}: {str(e)}")
|
warning(f"Error processing chunk {i + 1}: {str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
unique_triplets = []
|
unique_triplets = []
|
||||||
@ -149,13 +149,12 @@ class MRebelTripleExtractor(BaseTripleExtractor):
|
|||||||
seen.add(identifier)
|
seen.add(identifier)
|
||||||
unique_triplets.append(t)
|
unique_triplets.append(t)
|
||||||
|
|
||||||
logger.info(f"Extracted {len(unique_triplets)} unique triplets")
|
info(f"Extracted {len(unique_triplets)} unique triplets")
|
||||||
return unique_triplets
|
return unique_triplets
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to extract triplets: {str(e)}")
|
error(f"Failed to extract triplets: {str(e)}")
|
||||||
import traceback
|
debug(f"Traceback: {traceback.format_exc()}")
|
||||||
logger.debug(traceback.format_exc())
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
llm_register("mrebel-large", MRebelTripleExtractor)
|
llm_register("mrebel-large", MRebelTripleExtractor)
|
Loading…
Reference in New Issue
Block a user