llmengine/llmengine/milvus_connection.py

1284 lines
65 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from appPublic.jsonConfig import getConfig
import os
from appPublic.log import debug, error, info, exception
import yaml
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
from threading import Lock
from llmengine.base_connection import connection_register
from typing import Dict, List, Any
import aiohttp
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import uuid
from datetime import datetime
from filetxt.loader import fileloader
from llmengine.kgc import KnowledgeGraph
import numpy as np
from py2neo import Graph
from scipy.spatial.distance import cosine
# Process-wide embedding cache shared by every MilvusConnection call:
# maps an input text to its normalized embedding vector (filled by _get_embeddings).
EMBED_CACHE = dict()
class MilvusConnection:
_instance = None
_lock = Lock()
def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super(MilvusConnection, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
    """Read connection settings from the app config and connect to Milvus once.

    ``__new__`` always hands back the shared instance, so this runs on every
    ``MilvusConnection()`` call. The ``_initialized`` flag turns every call
    after the first into a no-op.

    Raises:
        RuntimeError: if ``milvus_db`` or the ``neo4j`` ``uri``/``user``/
            ``password`` entries are missing from the configuration, or if
            the Milvus connection cannot be established.
    """
    if self._initialized:
        return
    try:
        cfg = getConfig()
        self.db_path = cfg['milvus_db']
        neo4j_cfg = cfg['neo4j']
        self.neo4j_uri = neo4j_cfg['uri']
        self.neo4j_user = neo4j_cfg['user']
        self.neo4j_password = neo4j_cfg['password']
    except KeyError as missing:
        error(f"配置文件缺少必要字段: {str(missing)}\n{exception()}")
        raise RuntimeError(f"配置文件缺少必要字段: {str(missing)}")
    self._initialize_connection()
    self._initialized = True
    info(f"MilvusConnection initialized with db_path: {self.db_path}")
def _initialize_connection(self):
    """Open the shared ``"default"`` Milvus (Lite) connection at ``self.db_path``.

    Makes sure the directory that will hold the database file exists and is
    writable. It then connects, unless a ``"default"`` connection is already
    registered.

    Raises:
        RuntimeError: if the directory cannot be created or is not writable,
            or if the connection attempt fails (the original error is chained).
    """
    try:
        # A bare file name such as "milvus.db" has an empty dirname. Treat it as
        # the current directory instead of calling os.makedirs(""), which raises.
        db_dir = os.path.dirname(self.db_path) or "."
        if not os.path.exists(db_dir):
            os.makedirs(db_dir, exist_ok=True)
            debug(f"创建 Milvus 目录: {db_dir}")
        if not os.access(db_dir, os.W_OK):
            raise RuntimeError(f"Milvus 目录 {db_dir} 不可写")
        if not connections.has_connection("default"):
            connections.connect("default", uri=self.db_path)
            debug(f"已连接到 Milvus Lite路径: {self.db_path}")
        else:
            debug("已存在 Milvus 连接,跳过重复连接")
    except Exception as e:
        error(f"连接 Milvus 失败: {str(e)}\n{exception()}")
        raise RuntimeError(f"连接 Milvus 失败: {str(e)}") from e
async def handle_connection(self, action: str, params: Dict = None) -> Dict:
    """Validate a request and dispatch it to the matching Milvus/Neo4j operation.

    Args:
        action: operation name. One of ``initialize``, ``get_params``,
            ``create_collection``, ``delete_collection``, ``insert_document``,
            ``delete_document``, ``delete_knowledge_base``, ``fused_search``,
            ``search_query`` or ``list_user_files``.
        params: request parameters. ``db_type`` selects the collection:
            ``ragdb`` when empty, otherwise ``ragdb_<db_type>``.

    Returns:
        The handler's result dict. When validation fails, the action is
        unknown, or a handler raises, returns an error dict with
        ``status_code`` 400 instead.
    """
    try:
        debug(f"处理操作: action={action}, params={params}")
        if not params:
            params = {}
        # db_type validation shared by every action.
        db_type = params.get("db_type", "")
        collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"

        def bad_request(message: str) -> Dict:
            # Uniform response shape for parameter-validation failures.
            return {"status": "error", "message": message, "collection_name": collection_name,
                    "document_id": "", "status_code": 400}

        if db_type and "_" in db_type:
            return bad_request("db_type 不能包含下划线")
        if db_type and len(db_type) > 100:
            return bad_request("db_type 的长度应小于 100")

        if action == "initialize":
            return {"status": "success", "message": f"Milvus 连接已初始化,路径: {self.db_path}"}
        if action == "get_params":
            return {"status": "success", "params": {"uri": self.db_path}}
        if action == "create_collection":
            return await self._create_collection(db_type)
        if action == "delete_collection":
            return await self._delete_collection(db_type)

        if action == "insert_document":
            file_path = params.get("file_path", "")
            userid = params.get("userid", "")
            knowledge_base_id = params.get("knowledge_base_id", "")
            if not file_path or not userid or not knowledge_base_id:
                return bad_request("file_path、userid 和 knowledge_base_id 不能为空")
            if "_" in userid or "_" in knowledge_base_id:
                return bad_request("userid 和 knowledge_base_id 不能包含下划线")
            # userid is stored in a VARCHAR(100) field, just like knowledge_base_id.
            # Reject over-long values here, as the delete actions already do.
            if len(userid) > 100:
                return bad_request("userid 的长度应小于 100")
            if len(knowledge_base_id) > 100:
                return bad_request("knowledge_base_id 的长度应小于 100")
            return await self._insert_document(file_path, userid, knowledge_base_id, db_type)

        if action == "delete_document":
            userid = params.get("userid", "")
            filename = params.get("filename", "")
            knowledge_base_id = params.get("knowledge_base_id", "")
            if not userid or not filename or not knowledge_base_id:
                return bad_request("userid、filename 和 knowledge_base_id 不能为空")
            if "_" in userid or "_" in knowledge_base_id:
                return bad_request("userid 和 knowledge_base_id 不能包含下划线")
            if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100:
                return bad_request("userid、filename 或 knowledge_base_id 的长度超出限制")
            return await self._delete_document(db_type, userid, filename, knowledge_base_id)

        if action == "delete_knowledge_base":
            userid = params.get("userid", "")
            knowledge_base_id = params.get("knowledge_base_id", "")
            if not userid or not knowledge_base_id:
                return bad_request("userid 和 knowledge_base_id 不能为空")
            if "_" in userid or "_" in knowledge_base_id:
                return bad_request("userid 和 knowledge_base_id 不能包含下划线")
            if len(userid) > 100 or len(knowledge_base_id) > 100:
                return bad_request("userid 或 knowledge_base_id 的长度超出限制")
            return await self._delete_knowledge_base(db_type, userid, knowledge_base_id)

        if action == "fused_search":
            query = params.get("query", "")
            userid = params.get("userid", "")
            knowledge_base_ids = params.get("knowledge_base_ids", [])
            limit = params.get("limit", 5)
            if not query or not userid or not knowledge_base_ids:
                return bad_request("query、userid 或 knowledge_base_ids 不能为空")
            if limit < 1 or limit > 16384:
                return bad_request("limit 必须在 1 到 16384 之间")
            return await self._fused_search(
                query,
                userid,
                db_type,
                knowledge_base_ids,
                limit,
                params.get("offset", 0),
                params.get("use_rerank", True)
            )

        if action == "search_query":
            query = params.get("query", "")
            userid = params.get("userid", "")
            # NOTE(review): unlike fused_search (default 5), a missing limit is
            # forwarded as "". Confirm that _search_query handles this.
            limit = params.get("limit", "")
            knowledge_base_ids = params.get("knowledge_base_ids", [])
            if not query or not userid or not knowledge_base_ids:
                return bad_request("query、userid 或 knowledge_base_ids 不能为空")
            return await self._search_query(
                query,
                userid,
                db_type,
                knowledge_base_ids,
                limit,
                params.get("offset", 0),
                params.get("use_rerank", True)
            )

        if action == "list_user_files":
            userid = params.get("userid", "")
            if not userid:
                return bad_request("userid 不能为空")
            return await self.list_user_files(userid)

        return bad_request(f"未知的 action: {action}")
    except Exception as e:
        error(f"处理操作失败: action={action}, 错误: {str(e)}")
        # Report the collection name the same way the validation branches do.
        # Previously the raw db_type value (or "") was returned here.
        failed_db_type = params.get("db_type", "") if isinstance(params, dict) else ""
        return {
            "status": "error",
            "message": f"服务器错误: {str(e)}",
            "collection_name": "ragdb" if not failed_db_type else f"ragdb_{failed_db_type}",
            "document_id": "",
            # Kept at 400 for backward compatibility with existing callers.
            "status_code": 400
        }
async def _create_collection(self, db_type: str = "") -> Dict:
    """Ensure the collection for ``db_type`` exists with the expected schema and is loaded.

    The collection is named ``ragdb`` (empty ``db_type``) or ``ragdb_<db_type>``.
    If it already exists with exactly the expected field names and a
    1024-dim FLOAT_VECTOR ``vector`` field, it is loaded and reused.
    Otherwise it is dropped and recreated, with a COSINE ``AUTOINDEX`` on
    ``vector`` and ``INVERTED`` indexes on the scalar metadata fields.

    Returns:
        ``{"status", "collection_name", "message"}``. ``status`` is
        ``"error"``, with the exception text as the message, on any failure.
    """
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    expected_dim = 1024
    try:
        if len(collection_name) > 255:
            raise ValueError(f"集合名称 {collection_name} 超过 255 个字符")
        if db_type and "_" in db_type:
            raise ValueError("db_type 不能包含下划线")
        if db_type and len(db_type) > 100:
            raise ValueError("db_type 的长度应小于 100")
        debug(f"集合名称: {collection_name}")
        fields = [
            FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True),
            FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100),
            FieldSchema(name="knowledge_base_id", dtype=DataType.VARCHAR, max_length=100),
            FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=expected_dim),
            FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=255),
            FieldSchema(name="file_path", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="upload_time", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="file_type", dtype=DataType.VARCHAR, max_length=64),
        ]
        schema = CollectionSchema(
            fields=fields,
            description="统一数据集合包含用户ID、知识库ID、document_id 和元数据字段",
            auto_id=True,
            primary_field="pk",
        )
        if utility.has_collection(collection_name):
            try:
                collection = Collection(collection_name)
                existing_fields = collection.schema.fields
                expected_fields = {f.name for f in fields}
                actual_fields = {f.name for f in existing_fields}
                fields_match = expected_fields == actual_fields
                vector_field = next((f for f in existing_fields if f.name == "vector"), None)
                # Bind dim before the checks: the diagnostics below read it even when
                # the field check fails. Previously that raised NameError, so an
                # incompatible collection was reported as an error, not rebuilt.
                dim = None
                if vector_field is not None and getattr(vector_field, "params", None):
                    dim = vector_field.params.get('dim', None)
                schema_compatible = (
                    fields_match
                    and vector_field is not None
                    and vector_field.dtype == DataType.FLOAT_VECTOR
                    and dim == expected_dim
                )
                debug(f"检查集合 {collection_name} 的 schema: 字段匹配={fields_match}, "
                      f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field is not None else ''}, "
                      f"dim={dim if dim is not None else '未定义'}")
                if not schema_compatible:
                    debug(f"集合 {collection_name} 的 schema 不兼容,原因: "
                          f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or ''}, "
                          f"vector_field: {vector_field is not None}, "
                          f"dtype: {vector_field.dtype if vector_field is not None else ''}, "
                          f"dim: {dim if dim is not None else '未定义'}")
                    # Fall through to recreate the collection with the expected schema.
                    utility.drop_collection(collection_name)
                else:
                    collection.load()
                    debug(f"集合 {collection_name} 已存在并加载成功")
                    return {
                        "status": "success",
                        "collection_name": collection_name,
                        "message": f"集合 {collection_name} 已存在"
                    }
            except Exception as e:
                error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
                return {
                    "status": "error",
                    "collection_name": collection_name,
                    "message": str(e)
                }
        try:
            collection = Collection(collection_name, schema)
            collection.create_index(
                field_name="vector",
                index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"}
            )
            for field in ["userid", "knowledge_base_id", "document_id", "filename", "file_path", "upload_time", "file_type"]:
                collection.create_index(
                    field_name=field,
                    index_params={"index_type": "INVERTED"}
                )
            collection.load()
            debug(f"成功创建并加载集合: {collection_name}")
            return {
                "status": "success",
                "collection_name": collection_name,
                "message": f"集合 {collection_name} 创建成功"
            }
        except Exception as e:
            error(f"创建集合 {collection_name} 失败: {str(e)}\n{exception()}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "message": str(e)
            }
    except Exception as e:
        error(f"创建集合失败: {str(e)}")
        return {
            "status": "error",
            "collection_name": collection_name,
            "message": str(e)
        }
async def _delete_collection(self, db_type: str = "") -> Dict:
    """Drop the collection selected by ``db_type`` (``ragdb`` or ``ragdb_<db_type>``).

    Dropping a collection that does not exist counts as success.

    Returns:
        ``{"status", "collection_name", "message"}``. ``status`` is
        ``"error"`` when validation or the drop itself fails.
    """
    name = f"ragdb_{db_type}" if db_type else "ragdb"

    def outcome(status: str, message: str) -> Dict:
        return {"status": status, "collection_name": name, "message": message}

    try:
        if len(name) > 255:
            raise ValueError(f"集合名称 {name} 超过 255 个字符")
        if db_type and "_" in db_type:
            raise ValueError("db_type 不能包含下划线")
        if db_type and len(db_type) > 100:
            raise ValueError("db_type 的长度应小于 100")
        debug(f"集合名称: {name}")
        if not utility.has_collection(name):
            debug(f"集合 {name} 不存在")
            return outcome("success", f"集合 {name} 不存在,无需删除")
        try:
            utility.drop_collection(name)
            debug(f"成功删除集合: {name}")
            return outcome("success", f"集合 {name} 删除成功")
        except Exception as drop_err:
            error(f"删除集合 {name} 失败: {str(drop_err)}")
            return outcome("error", str(drop_err))
    except Exception as exc:
        error(f"删除集合失败: {str(exc)}")
        return outcome("error", str(exc))
async def _insert_document(self, file_path: str, userid: str, knowledge_base_id: str,
                           db_type: str = "") -> Dict[str, Any]:
    """Chunk a file, store the chunks with embeddings in Milvus, and push its triples to Neo4j.

    Steps:
    1. Read the file with ``fileloader``.
    2. Split it into 2000-character chunks with 200 characters of overlap.
    3. Tag each chunk with user, knowledge-base and document metadata.
    4. Insert the chunks into ``ragdb`` / ``ragdb_<db_type>``, creating the
       collection if missing.
    5. Run triple extraction. A failure here is reported in the message but
       does not fail the call, because the Milvus insert has already succeeded.

    Returns:
        ``{"status", "document_id", "collection_name", "message",
        "status_code"}``. ``status`` is ``"error"`` (code 400) when the file is
        missing, unsupported or empty, or when the Milvus insert fails.
    """
    document_id = str(uuid.uuid4())
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    debug(
        f'Inserting document: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}')
    try:
        if not os.path.exists(file_path):
            raise ValueError(f"文件 {file_path} 不存在")
        supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'}
        # Parse the extension from the file name only. Splitting the whole path
        # mis-reads extension-less files inside dotted directories
        # (e.g. "/data/v1.2/README" used to yield the extension "2/readme").
        base_name = os.path.basename(file_path)
        if '.' in base_name:
            filename, ext = base_name.rsplit('.', 1)
            ext = ext.lower()
        else:
            filename, ext = base_name, ''
        if ext not in supported_formats:
            raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}")
        info(f"生成 document_id: {document_id} for file: {file_path}")
        debug(f"加载文件: {file_path}")
        text = fileloader(file_path)
        if not text or not text.strip():
            raise ValueError(f"文件 {file_path} 加载为空")
        document = Document(page_content=text)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200,
            length_function=len,
        )
        debug("开始分片文件内容")
        chunks = text_splitter.split_documents([document])
        if not chunks:
            raise ValueError(f"文件 {file_path} 未生成任何文档块")
        debug(f"文件 {file_path} 分割为 {len(chunks)} 个文档块")
        upload_time = datetime.now().isoformat()
        documents = []
        for i, chunk in enumerate(chunks):
            chunk.metadata.update({
                'userid': userid,
                'knowledge_base_id': knowledge_base_id,
                'document_id': document_id,
                'filename': filename + '.' + ext,
                'file_path': file_path,
                'upload_time': upload_time,
                'file_type': ext,
            })
            documents.append(chunk)
            debug(f"文档块 {i} 元数据: {chunk.metadata}")
        debug(f"确保集合 {collection_name} 存在")
        create_result = await self._create_collection(db_type)
        if create_result["status"] == "error":
            raise RuntimeError(f"集合创建失败: {create_result['message']}")
        debug("调用嵌入服务生成向量")
        texts = [doc.page_content for doc in documents]
        embeddings = await self._get_embeddings(texts)
        await self._insert_to_milvus(collection_name, documents, embeddings)
        info(f"成功插入 {len(documents)} 个文档块到 {collection_name}")
        debug("调用三元组抽取服务")
        try:
            triples = await self._extract_triples(text)
            if triples:
                debug(f"抽取到 {len(triples)} 个三元组,插入 Neo4j")
                kg = KnowledgeGraph(triples=triples, document_id=document_id, knowledge_base_id=knowledge_base_id, userid=userid)
                kg.create_graphnodes()
                kg.create_graphrels()
                kg.export_data()
                info(f"文件 {file_path} 三元组成功插入 Neo4j")
            else:
                debug(f"文件 {file_path} 未抽取到三元组")
        except Exception as e:
            # Degraded success: the chunks are stored, only the graph step failed.
            error(f"处理三元组失败: {str(e)}")
            return {
                "status": "success",
                "document_id": document_id,
                "collection_name": collection_name,
                "message": f"文件 {file_path} 成功嵌入,但三元组处理失败: {str(e)}",
                "status_code": 200
            }
        return {
            "status": "success",
            "document_id": document_id,
            "collection_name": collection_name,
            "message": f"文件 {file_path} 成功嵌入并处理三元组",
            "status_code": 200
        }
    except Exception as e:
        error(f"插入文档失败: {str(e)}")
        return {
            "status": "error",
            "document_id": document_id,
            "collection_name": collection_name,
            "message": str(e),
            "status_code": 400
        }
async def _get_embeddings(self, texts: List[str], max_cache_entries: int = 10000) -> List[List[float]]:
    """Return one L2-normalised embedding per entry of ``texts``, in the same order.

    Texts already in the module-level ``EMBED_CACHE`` are served from it.
    Each remaining distinct text is sent once, in a single request, to the
    embedding service at ``http://localhost:9998/v1/embeddings``. Returned
    vectors are divided by their L2 norm (zero vectors are left unchanged)
    and cached. After an update, the oldest cache entries are evicted so the
    cache never holds more than ``max_cache_entries`` items.

    Args:
        texts: texts to embed; duplicates are allowed.
        max_cache_entries: upper bound on the size of ``EMBED_CACHE``.

    Returns:
        numpy arrays aligned one-to-one with ``texts``.

    Raises:
        RuntimeError: if the service call fails, returns a malformed payload,
            or returns a different number of vectors than requested.
    """
    try:
        # Snapshot cached vectors now, so evictions by concurrent calls during
        # the awaited request cannot break the final lookup.
        resolved = {text: EMBED_CACHE[text] for text in texts if text in EMBED_CACHE}
        # Request each distinct uncached text once, in first-seen order.
        uncached_texts = [text for text in dict.fromkeys(texts) if text not in resolved]
        if uncached_texts:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                        "http://localhost:9998/v1/embeddings",
                        headers={"Content-Type": "application/json"},
                        json={"input": uncached_texts}
                ) as response:
                    if response.status != 200:
                        error(f"嵌入服务调用失败,状态码: {response.status}")
                        raise RuntimeError(f"嵌入服务调用失败: {response.status}")
                    result = await response.json()
            if result.get("object") != "list" or not result.get("data"):
                error(f"嵌入服务响应格式错误: {result}")
                raise RuntimeError("嵌入服务响应格式错误")
            items = result["data"]
            # zip() would silently drop texts if the service returned fewer vectors.
            if len(items) != len(uncached_texts):
                error(f"嵌入服务返回向量数量不匹配: 期望 {len(uncached_texts)}, 实际 {len(items)}")
                raise RuntimeError("嵌入服务返回向量数量不匹配")
            for text, item in zip(uncached_texts, items):
                vector = np.asarray(item["embedding"], dtype=float)
                norm = np.linalg.norm(vector)
                if norm > 0:  # avoid NaNs from a zero vector
                    vector = vector / norm
                resolved[text] = vector
                EMBED_CACHE[text] = vector
            overflow = len(EMBED_CACHE) - max_cache_entries
            if overflow > 0:
                # dicts preserve insertion order, so the first keys are the oldest entries.
                stale_keys = [key for _, key in zip(range(overflow), EMBED_CACHE)]
                for key in stale_keys:
                    del EMBED_CACHE[key]
            debug(f"成功获取 {len(items)} 个新嵌入向量,缓存大小: {len(EMBED_CACHE)}")
        return [resolved[text] for text in texts]
    except Exception as e:
        error(f"嵌入服务调用失败: {str(e)}\n{exception()}")
        raise RuntimeError(f"嵌入服务调用失败: {str(e)}")
async def _extract_triples(self, text: str) -> List[Dict]:
    """Send ``text`` to the triple-extraction service and return the triples it finds.

    POSTs ``{"text": text}`` to ``http://localhost:9991/v1/triples`` and
    returns the response's ``data`` list. An empty ``data`` list is a valid
    "no triples found" answer and is returned as ``[]``, so callers can tell
    "nothing extracted" apart from a broken response.

    Raises:
        RuntimeError: on a non-200 status, or when the response is not an
            ``{"object": "list", "data": [...]}`` payload.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    "http://localhost:9991/v1/triples",
                    headers={"Content-Type": "application/json; charset=utf-8"},
                    json={"text": text}
            ) as response:
                if response.status != 200:
                    error(f"三元组抽取服务调用失败,状态码: {response.status}")
                    raise RuntimeError(f"三元组抽取服务调用失败: {response.status}")
                result = await response.json()
                # Validate the type rather than truthiness: [] means "no triples".
                if result.get("object") != "list" or not isinstance(result.get("data"), list):
                    error(f"三元组抽取服务响应格式错误: {result}")
                    raise RuntimeError("三元组抽取服务响应格式错误")
                triples = result["data"]
                debug(f"成功抽取 {len(triples)} 个三元组")
                return triples
    except Exception as e:
        error(f"三元组抽取服务调用失败: {str(e)}\n{exception()}")
        raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}")
async def _insert_to_milvus(self, collection_name: str, documents: List[Document],
                            embeddings: List[List[float]]) -> None:
    """Write chunk documents and their vectors into ``collection_name``.

    Column data is built from each document's ``page_content`` and metadata.
    It is passed to ``Collection.insert`` in the collection's schema field
    order, skipping the auto-generated ``pk`` field, and then flushed.

    Raises:
        RuntimeError: wrapping any failure while connecting, loading,
            inserting or flushing.
    """
    try:
        if not connections.has_connection("default"):
            self._initialize_connection()
        collection = Collection(collection_name)
        collection.load()
        metadata_columns = ("userid", "knowledge_base_id", "document_id")
        columns = {key: [doc.metadata[key] for doc in documents] for key in metadata_columns}
        columns["text"] = [doc.page_content for doc in documents]
        columns["vector"] = embeddings
        for key in ("filename", "file_path", "upload_time", "file_type"):
            columns[key] = [doc.metadata[key] for doc in documents]
        ordered_columns = [
            columns[field.name]
            for field in collection.schema.fields
            if field.name != "pk"
        ]
        collection.insert(ordered_columns)
        collection.flush()
        debug(f"成功插入 {len(documents)} 个文档到集合 {collection_name}")
    except Exception as exc:
        error(f"插入 Milvus 失败: {str(exc)}")
        raise RuntimeError(f"插入 Milvus 失败: {str(exc)}")
async def _delete_document(self, db_type: str, userid: str, filename: str,
                           knowledge_base_id: str) -> Dict[str, Any]:
    """Delete a user's file from Milvus and its extracted triples from Neo4j.

    Steps:
    1. Query chunk rows matching ``userid``, ``filename`` and
       ``knowledge_base_id`` (at most 1000 rows are inspected) and collect
       their distinct ``document_id`` values.
    2. For each id, delete all of its Milvus rows.
    3. For each id, delete the Neo4j nodes and relationships carrying that
       ``document_id``.

    A Neo4j failure is logged but neither undoes the Milvus deletion nor
    fails the call.

    Returns:
        ``{"status", "collection_name", "document_id", "message",
        "status_code"}``. When rows were deleted, ``document_id`` is the
        comma-joined list of ids processed.
    """
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    # DISTINCT: a node with several matching relationships appears once per
    # relationship row, which previously inflated the reported node count.
    neo4j_delete_query = """
    MATCH (n {document_id: $document_id})
    OPTIONAL MATCH (n)-[r {document_id: $document_id}]->()
    WITH collect(DISTINCT r) AS rels, collect(DISTINCT n) AS nodes
    FOREACH (r IN rels | DELETE r)
    FOREACH (n IN nodes | DELETE n)
    RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types
    """

    def quote(value: Any) -> str:
        """Render ``value`` as a single-quoted Milvus filter string literal."""
        # Escape backslashes first, then quotes. Otherwise a value such as
        # "a' or userid != '" could close the literal and widen the filter
        # to other users' rows.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    try:
        if not utility.has_collection(collection_name):
            debug(f"集合 {collection_name} 不存在")
            return {
                "status": "success",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"集合 {collection_name} 不存在,无需删除",
                "status_code": 200
            }
        try:
            collection = Collection(collection_name)
            collection.load()
            debug(f"加载集合: {collection_name}")
        except Exception as e:
            error(f"加载集合 {collection_name} 失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"加载集合失败: {str(e)}",
                "status_code": 400
            }
        expr = (f"userid == {quote(userid)} and filename == {quote(filename)} "
                f"and knowledge_base_id == {quote(knowledge_base_id)}")
        debug(f"查询表达式: {expr}")
        try:
            results = collection.query(
                expr=expr,
                output_fields=["document_id"],
                limit=1000
            )
            if not results:
                debug(
                    f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录")
                return {
                    "status": "success",
                    "collection_name": collection_name,
                    "document_id": "",
                    "message": f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录,无需删除",
                    "status_code": 200
                }
            # One row per chunk: keep each document_id once, in first-seen order.
            document_ids = list(dict.fromkeys(row["document_id"] for row in results if "document_id" in row))
            debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
        except Exception as e:
            error(f"查询 document_id 失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"查询失败: {str(e)}",
                "status_code": 400
            }
        total_deleted = 0
        neo4j_deleted_nodes = 0
        neo4j_deleted_rels = 0
        graph = None  # opened lazily once and reused for every document_id
        for doc_id in document_ids:
            try:
                # Delete the Milvus rows.
                delete_expr = f"document_id == {quote(doc_id)}"
                debug(f"删除表达式: {delete_expr}")
                delete_result = collection.delete(delete_expr)
                deleted_count = delete_result.delete_count
                total_deleted += deleted_count
                info(f"成功删除 document_id={doc_id}{deleted_count} 条 Milvus 记录")
                # Delete the Neo4j triples.
                try:
                    if graph is None:
                        graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
                    result = graph.run(neo4j_delete_query, document_id=doc_id).data()
                    nodes_deleted = result[0]['node_count'] if result else 0
                    rels_deleted = result[0]['rel_count'] if result else 0
                    rel_types = result[0]['rel_types'] if result else []
                    info(
                        f"成功删除 document_id={doc_id}{nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}")
                    neo4j_deleted_nodes += nodes_deleted
                    neo4j_deleted_rels += rels_deleted
                except Exception as e:
                    error(f"删除 document_id={doc_id} 的 Neo4j 三元组失败: {str(e)}")
            except Exception as e:
                error(f"删除 document_id={doc_id} 的 Milvus 记录失败: {str(e)}")
        if total_deleted == 0:
            debug(
                f"没有删除任何 Milvus 记录userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}")
            return {
                "status": "success",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"没有删除任何记录userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}",
                "status_code": 200
            }
        info(
            f"总计删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}")
        return {
            "status": "success",
            "collection_name": collection_name,
            "document_id": ",".join(document_ids),
            "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}",
            "status_code": 200
        }
    except Exception as e:
        error(f"删除文档失败: {str(e)}")
        return {
            "status": "error",
            "collection_name": collection_name,
            "document_id": "",
            "message": f"删除文档失败: {str(e)}",
            "status_code": 400
        }
async def _delete_knowledge_base(self, db_type: str, userid: str, knowledge_base_id: str) -> Dict[str, Any]:
    """Delete a user's whole knowledge base from Milvus and Neo4j.

    Steps:
    1. Delete every Milvus row matching ``userid`` and ``knowledge_base_id``.
    2. Delete every Neo4j node and relationship tagged with the same pair.
       This runs even when Milvus held no rows.

    A Neo4j failure is reported in the message but still returns
    ``status`` ``"success"``, because the Milvus deletion already happened.

    Returns:
        ``{"status", "collection_name", "document_id", "message",
        "status_code"}``. ``document_id`` is the comma-joined list of distinct
        ids seen in the (at most 1000) inspected rows. The delete itself
        covers all matching rows.
    """
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    # DISTINCT: a node with several matching relationships appears once per
    # relationship row, which previously inflated the reported node count.
    neo4j_delete_query = """
    MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id})
    OPTIONAL MATCH (n)-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->()
    WITH collect(DISTINCT r) AS rels, collect(DISTINCT n) AS nodes
    FOREACH (r IN rels | DELETE r)
    FOREACH (n IN nodes | DELETE n)
    RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types
    """

    def quote(value: Any) -> str:
        """Render ``value`` as a single-quoted Milvus filter string literal."""
        # Escape backslashes first, then quotes. Otherwise a crafted
        # knowledge_base_id could close the literal and widen the delete
        # to other users' data.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    try:
        if not utility.has_collection(collection_name):
            debug(f"集合 {collection_name} 不存在")
            return {
                "status": "success",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"集合 {collection_name} 不存在,无需删除",
                "status_code": 200
            }
        try:
            collection = Collection(collection_name)
            # query() needs a loaded collection (as _delete_document already does).
            collection.load()
            debug(f"加载集合: {collection_name}")
        except Exception as e:
            error(f"加载集合 {collection_name} 失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"加载集合失败: {str(e)}",
                "status_code": 400
            }
        # Look up the document_ids stored in Milvus.
        scope_expr = f"userid == {quote(userid)} and knowledge_base_id == {quote(knowledge_base_id)}"
        debug(f"查询表达式: {scope_expr}")
        try:
            results = collection.query(
                expr=scope_expr,
                output_fields=["document_id"],
                limit=1000
            )
        except Exception as e:
            error(f"查询 document_id 失败: {str(e)}")
            return {
                "status": "error",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"查询失败: {str(e)}",
                "status_code": 400
            }
        # One row per chunk: keep each document_id once, in first-seen order.
        document_ids = list(dict.fromkeys(row["document_id"] for row in results if "document_id" in row))
        if results:
            debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}")
        else:
            # Continue anyway: Neo4j may still hold data for this knowledge base.
            debug(f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录")
        # Delete the Milvus rows.
        total_deleted = 0
        if results:
            try:
                debug(f"删除表达式: {scope_expr}")
                delete_result = collection.delete(scope_expr)
                total_deleted = delete_result.delete_count
                info(f"成功删除 {total_deleted} 条 Milvus 记录")
            except Exception as e:
                error(f"删除 Milvus 记录失败: {str(e)}")
                return {
                    "status": "error",
                    "collection_name": collection_name,
                    "document_id": "",
                    "message": f"删除 Milvus 记录失败: {str(e)}",
                    "status_code": 400
                }
        # Delete the Neo4j data.
        neo4j_deleted_nodes = 0
        neo4j_deleted_rels = 0
        try:
            debug(f"尝试连接 Neo4j: uri={self.neo4j_uri}, user={self.neo4j_user}")
            graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
            debug("Neo4j 连接成功")
            result = graph.run(neo4j_delete_query, userid=userid, knowledge_base_id=knowledge_base_id).data()
            nodes_deleted = result[0]['node_count'] if result else 0
            rels_deleted = result[0]['rel_count'] if result else 0
            rel_types = result[0]['rel_types'] if result else []
            neo4j_deleted_nodes += nodes_deleted
            neo4j_deleted_rels += rels_deleted
            info(f"成功删除 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}")
        except Exception as e:
            error(f"删除 Neo4j 数据失败: {str(e)}")
            # Still report success: the Milvus deletion has already been applied.
            return {
                "status": "success",
                "collection_name": collection_name,
                "document_id": ",".join(document_ids) if document_ids else "",
                "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {str(e)}",
                "status_code": 200
            }
        if total_deleted == 0 and neo4j_deleted_nodes == 0 and neo4j_deleted_rels == 0:
            debug(f"没有删除任何记录userid={userid}, knowledge_base_id={knowledge_base_id}")
            return {
                "status": "success",
                "collection_name": collection_name,
                "document_id": "",
                "message": f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录,无需删除",
                "status_code": 200
            }
        info(
            f"总计删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系userid={userid}, knowledge_base_id={knowledge_base_id}")
        return {
            "status": "success",
            "collection_name": collection_name,
            "document_id": ",".join(document_ids) if document_ids else "",
            "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系userid={userid}, knowledge_base_id={knowledge_base_id}",
            "status_code": 200
        }
    except Exception as e:
        error(f"删除知识库失败: {str(e)}")
        return {
            "status": "error",
            "collection_name": collection_name,
            "document_id": "",
            "message": f"删除知识库失败: {str(e)}",
            "status_code": 400
        }
async def _extract_entities(self, query: str) -> List[str]:
    """Ask the entity-recognition service which entities appear in ``query``.

    POSTs ``{"query": query}`` to ``http://localhost:9990/v1/entities`` and
    returns the ``data`` list with duplicates removed, keeping the first
    occurrence of each entity. Every failure (empty query, non-200 status,
    malformed payload, network error) is logged and yields ``[]`` instead
    of raising.
    """
    endpoint = "http://localhost:9990/v1/entities"
    try:
        if not query:
            raise ValueError("查询文本不能为空")
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    endpoint,
                    headers={"Content-Type": "application/json"},
                    json={"query": query}
            ) as resp:
                if resp.status != 200:
                    error(f"实体识别服务调用失败,状态码: {resp.status}")
                    raise RuntimeError(f"实体识别服务调用失败: {resp.status}")
                payload = await resp.json()
                if payload.get("object") != "list" or not payload.get("data"):
                    error(f"实体识别服务响应格式错误: {payload}")
                    raise RuntimeError("实体识别服务响应格式错误")
                seen = set()
                unique_entities = []
                for ent in payload["data"]:
                    if ent in seen:
                        continue
                    seen.add(ent)
                    unique_entities.append(ent)
                debug(f"成功提取 {len(unique_entities)} 个唯一实体: {unique_entities}")
                return unique_entities
    except Exception as exc:
        error(f"实体识别服务调用失败: {str(exc)}\n{exception()}")
        return []
async def _match_triplets(self, query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]:
    """Find the Neo4j triples of one document that relate to the query entities.

    Stage 1: for each entity, fuzzy-match node names within ``document_id``.
    A node matches when its lower-cased name contains the entity, or when
    the APOC Levenshtein similarity is > 0.7. Up to 100 nodes are taken per
    entity.

    Stage 2: load up to 100 relationships touching the matched nodes. Keep a
    triple when the embedding cosine similarity between some query entity
    and its head or tail is >= 0.8.

    ``query`` and ``userid`` are accepted for interface compatibility. The
    lookup does not use them; it is scoped by ``document_id`` only.

    Returns:
        Triples de-duplicated case-insensitively on head/type/tail, as dicts
        with ``head``, ``type``, ``tail``, ``head_type`` and ``tail_type``
        (the two ``*_type`` fields are always empty strings). Returns ``[]``
        on any error.
    """
    entity_similarity_threshold = 0.8
    node_match_cypher = """
    MATCH (n {document_id: $document_id})
    WHERE toLower(n.name) CONTAINS $entity
    OR apoc.text.levenshteinSimilarity(toLower(n.name), $entity) > 0.7
    RETURN n.name, apoc.text.levenshteinSimilarity(toLower(n.name), $entity) AS sim
    ORDER BY sim DESC
    LIMIT 100
    """
    triplet_cypher = """
    MATCH (h {document_id: $document_id})-[r]->(t {document_id: $document_id})
    WHERE h.name IN $matched_names OR t.name IN $matched_names
    RETURN h.name AS head, r.name AS type, t.name AS tail
    LIMIT 100
    """

    def as_text(value: Any) -> str:
        # Names read from Neo4j may be null (e.g. a relationship without a
        # ``name`` property). Map them to "" instead of crashing on .lower(),
        # which previously discarded every triple of the document.
        return "" if value is None else str(value)

    try:
        graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password))
        debug(f"已连接到 Neo4j: {self.neo4j_uri}")
        matched_names = set()
        for entity in query_entities:
            normalized_entity = entity.lower().strip()
            try:
                rows = graph.run(node_match_cypher, document_id=document_id, entity=normalized_entity).data()
                for record in rows:
                    matched_names.add(record['n.name'])
                    debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})")
            except Exception as e:
                debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}")
                continue
        triplets = []
        if matched_names:
            try:
                rows = graph.run(triplet_cypher, document_id=document_id, matched_names=list(matched_names)).data()
                seen = set()
                for record in rows:
                    head, type_, tail = record['head'], as_text(record['type']), record['tail']
                    triplet_key = (as_text(head).lower(), type_.lower(), as_text(tail).lower())
                    if triplet_key not in seen:
                        seen.add(triplet_key)
                        triplets.append({
                            'head': head,
                            'type': type_,
                            'tail': tail,
                            'head_type': '',
                            'tail_type': ''
                        })
                debug(f"从 Neo4j 加载三元组: document_id={document_id}, 数量={len(triplets)}")
            except Exception as e:
                error(f"检索三元组失败: document_id={document_id}, 错误: {str(e)}\n{exception()}")
                return []
        if not triplets:
            debug(f"文档 document_id={document_id} 无匹配三元组")
            return []
        # Embed each distinct text once. The same text always maps to the same
        # vector, so a text -> vector lookup replaces positional indexing.
        texts_to_embed = list(dict.fromkeys(
            list(query_entities) + [t['head'] for t in triplets] + [t['tail'] for t in triplets]
        ))
        embeddings = await self._get_embeddings(texts_to_embed)
        vector_of = dict(zip(texts_to_embed, embeddings))
        debug(f"成功获取 {len(embeddings)} 个嵌入向量(去重自 {len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)")
        unique_matched = []
        seen_matches = set()
        for entity in query_entities:
            entity_vec = vector_of[entity]
            for d_triplet in triplets:
                head_similarity = 1 - cosine(entity_vec, vector_of[d_triplet['head']])
                tail_similarity = 1 - cosine(entity_vec, vector_of[d_triplet['tail']])
                if head_similarity >= entity_similarity_threshold or tail_similarity >= entity_similarity_threshold:
                    debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} "
                          f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})")
                    identifier = (as_text(d_triplet['head']).lower(), d_triplet['type'].lower(),
                                  as_text(d_triplet['tail']).lower())
                    if identifier not in seen_matches:
                        seen_matches.add(identifier)
                        unique_matched.append(d_triplet)
        info(f"找到 {len(unique_matched)} 个匹配的三元组")
        return unique_matched
    except Exception as e:
        error(f"匹配三元组失败: {str(e)}\n{exception()}")
        return []
async def _rerank_results(self, query: str, results: List[Dict], top_n: int) -> List[Dict]:
    """Reorder search hits with the rerank service and keep the best ``top_n``.

    ``top_n`` is clamped to ``len(results)``. A non-int value, or a value
    below 1, means "all results". Each kept hit dict is annotated in place
    with ``rerank_score``. On any failure the original ``results`` list is
    returned, neither reordered nor truncated.
    """
    try:
        if not results:
            debug("无结果需要重排序")
            return results
        if isinstance(top_n, int) and top_n >= 1:
            top_n = min(top_n, len(results))
        else:
            debug(f"无效的 top_n 参数: {top_n},使用 len(results)={len(results)}")
            top_n = len(results)
        debug(f"重排序 top_n={top_n}, 原始结果数={len(results)}")
        request_body = {
            "model": "rerank-001",
            "query": query,
            "documents": [hit["text"] for hit in results],
            "top_n": top_n
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    "http://localhost:9997/v1/rerank",
                    headers={"Content-Type": "application/json"},
                    json=request_body
            ) as response:
                if response.status != 200:
                    error(f"重排序服务调用失败,状态码: {response.status}")
                    raise RuntimeError(f"重排序服务调用失败: {response.status}")
                payload = await response.json()
                if payload.get("object") != "rerank.result" or not payload.get("data"):
                    error(f"重排序服务响应格式错误: {payload}")
                    raise RuntimeError("重排序服务响应格式错误")
                ordered_hits = []
                for entry in payload["data"]:
                    position = entry["index"]
                    if position >= len(results):
                        continue
                    score = entry["relevance_score"]
                    hit = results[position]
                    hit["rerank_score"] = score
                    ordered_hits.append(hit)
                debug(f"成功重排序 {len(ordered_hits)} 条结果")
                return ordered_hits[:top_n]
    except Exception as exc:
        error(f"重排序服务调用失败: {str(exc)}\n{exception()}")
        return results
async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
                        offset: int = 0, use_rerank: bool = True) -> List[Dict]:
    """Vector search using the query text fused with matching graph triplets.

    Extracts entities from ``query``, collects the triplets that
    ``self._match_triplets`` matches for each of the user's documents in the
    requested knowledge bases, and appends them to the query text. It then
    embeds that combined text once and runs a COSINE vector search restricted
    to the user and knowledge bases. Hits are de-duplicated by text and
    optionally reranked.

    Args:
        query: Query text; must be non-empty.
        userid: Owner id; non-empty, at most 100 characters, and no
            underscores, single quotes or backslashes.
        db_type: Collection suffix: ``""`` selects ``ragdb``, otherwise
            ``ragdb_{db_type}``. No underscores, at most 100 characters.
        knowledge_base_ids: Knowledge bases to search; must be non-empty, and
            no id may contain a single quote or backslash.
        limit: Maximum number of results, 1..16384.
        offset: Offset passed to the Milvus search; ``limit + offset`` may not
            exceed 16384.
        use_rerank: When true, reorder hits with ``self._rerank_results``,
            using the combined query + triplet text as the rerank query.

    Returns:
        At most ``limit`` dicts with ``text``, ``distance``, ``source`` and
        ``metadata`` keys (plus ``rerank_score`` when reranked). Every error,
        including invalid arguments, is logged and yields ``[]``.
    """
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    try:
        info(
            f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
        if not query or not userid or not knowledge_base_ids:
            raise ValueError("query、userid 和 knowledge_base_ids 不能为空")
        if "_" in userid or (db_type and "_" in db_type):
            raise ValueError("userid 和 db_type 不能包含下划线")
        if (db_type and len(db_type) > 100) or len(userid) > 100:
            raise ValueError("db_type 或 userid 的长度超出限制")
        if limit < 1 or limit > 16384 or offset < 0:
            raise ValueError("limit 必须在 1 到 16384 之间offset 必须大于或等于 0")
        # Milvus rejects searches whose limit + offset exceeds 16384.
        if limit + offset > 16384:
            raise ValueError("limit + offset 不能超过 16384")
        # userid and knowledge-base ids are interpolated into single-quoted
        # Milvus filter literals below; a quote or backslash would let a value
        # break out of the literal and widen the filter (e.g. to other users).
        for value in [userid, *knowledge_base_ids]:
            if "'" in str(value) or "\\" in str(value):
                raise ValueError(f"userid 或 knowledge_base_id 包含非法字符: {value}")
        if not utility.has_collection(collection_name):
            debug(f"集合 {collection_name} 不存在")
            return []
        try:
            collection = Collection(collection_name)
            collection.load()
            debug(f"加载集合: {collection_name}")
        except Exception as e:
            error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
            return []
        query_entities = await self._extract_entities(query)
        debug(f"提取实体: {query_entities}")

        documents = []
        seen_document_ids = set()
        all_triplets = []
        for kb_id in knowledge_base_ids:
            debug(f"处理知识库: {kb_id}")
            results = collection.query(
                expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'",
                output_fields=["document_id", "filename", "knowledge_base_id"],
                limit=100  # NOTE: caps chunk rows scanned per knowledge base, not documents
            )
            if not results:
                debug(f"未找到 userid {userid} 和 knowledge_base_id {kb_id} 对应的文档")
                continue
            for doc in results:
                document_id = doc["document_id"]
                # Rows are text chunks, so many share one document_id. Match
                # triplets (graph lookup + embedding call) once per document
                # instead of once per chunk, which also avoids duplicated
                # triplets in the fused text.
                if document_id in seen_document_ids:
                    continue
                seen_document_ids.add(document_id)
                documents.append(doc)
                matched_triplets = await self._match_triplets(query, query_entities, userid, document_id)
                debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)}")
                all_triplets.extend(matched_triplets)
        if not documents:
            debug("未找到任何有效文档")
            return []
        info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}")

        triplet_texts = []
        for triplet in all_triplets:
            head = triplet.get('head', '')
            type_ = triplet.get('type', '')
            tail = triplet.get('tail', '')
            if head and type_ and tail:
                triplet_texts.append(f"{head} {type_} {tail}")
            else:
                debug(f"无效三元组: {triplet}")
        # Different documents can yield the same triplet; repeating it only
        # adds noise to the embedded text. dict.fromkeys keeps first-seen order.
        triplet_texts = list(dict.fromkeys(triplet_texts))
        combined_text = query
        if triplet_texts:
            combined_text += " [三元组] " + "; ".join(triplet_texts)
        debug(
            f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})")
        embeddings = await self._get_embeddings([combined_text])
        query_vector = embeddings[0]
        debug(f"拼接文本向量维度: {len(query_vector)}")

        search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
        kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
        expr = f"userid == '{userid}' and ({kb_expr})"
        debug(f"搜索表达式: {expr}")
        # Fetch at least 100 candidates for de-duplication/reranking, but never
        # fewer than `limit` (a fixed 100 silently capped larger limits), and
        # stay inside Milvus' limit + offset <= 16384 window.
        search_limit = min(max(limit, 100), 16384 - offset)
        try:
            results = collection.search(
                data=[query_vector],
                anns_field="vector",
                param=search_params,
                limit=search_limit,
                expr=expr,
                output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
                               "file_type"],
                offset=offset
            )
        except Exception as e:
            error(f"向量搜索失败: {str(e)}\n{exception()}")
            return []

        search_results = []
        for hits in results:
            for hit in hits:
                metadata = {
                    "userid": hit.entity.get("userid"),
                    "document_id": hit.entity.get("document_id"),
                    "filename": hit.entity.get("filename"),
                    "file_path": hit.entity.get("file_path"),
                    "upload_time": hit.entity.get("upload_time"),
                    "file_type": hit.entity.get("file_type")
                }
                result = {
                    "text": hit.entity.get("text"),
                    "distance": hit.distance,
                    "source": "fused_query_with_triplets",
                    "metadata": metadata
                }
                search_results.append(result)
                debug(
                    f"搜索命中: text={result['text'][:100]}..., distance={hit.distance}, source={result['source']}")

        # With the COSINE metric a larger distance means more similar, so sort
        # descending and keep the best-scoring hit for each distinct text.
        unique_results = []
        seen_texts = set()
        for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
            if result['text'] not in seen_texts:
                unique_results.append(result)
                seen_texts.add(result['text'])
        info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
        if use_rerank and unique_results:
            debug("开始重排序")
            unique_results = await self._rerank_results(combined_text, unique_results, limit)
            unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
        else:
            unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
        info(f"融合搜索完成,返回 {len(unique_results)} 条结果")
        return unique_results[:limit]
    except Exception as e:
        error(f"融合搜索失败: {str(e)}\n{exception()}")
        return []
async def _search_query(self, query: str, userid: str, db_type: str = "",
                        knowledge_base_ids: "List[str] | None" = None, limit: int = 5,
                        offset: int = 0, use_rerank: bool = True) -> List[Dict]:
    """Plain vector search for ``query`` inside the given knowledge bases.

    Embeds ``query``, runs a COSINE vector search restricted to ``userid`` and
    ``knowledge_base_ids``, de-duplicates hits by text and optionally reranks
    them.

    Args:
        query: Query text; must be non-empty.
        userid: Owner id; non-empty, at most 100 characters, and no
            underscores, single quotes or backslashes.
        db_type: Collection suffix: ``""`` selects ``ragdb``, otherwise
            ``ragdb_{db_type}``. No underscores, at most 100 characters.
        knowledge_base_ids: Knowledge bases to search. Required and non-empty
            (``None`` is rejected like an empty list). Each id must be a
            string of at most 100 characters, with no underscores, single
            quotes or backslashes.
        limit: Maximum number of results, 1..16384.
        offset: Offset passed to the Milvus search; ``limit + offset`` may not
            exceed 16384.
        use_rerank: When true, reorder hits with ``self._rerank_results``.

    Returns:
        At most ``limit`` dicts with ``text``, ``distance``, ``source`` and
        ``metadata`` keys (plus ``rerank_score`` when reranked). Every error,
        including invalid arguments, is logged and yields ``[]``.
    """
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    try:
        info(
            f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
        if not query:
            raise ValueError("查询文本不能为空")
        if not userid:
            raise ValueError("userid 不能为空")
        if "_" in userid or (db_type and "_" in db_type):
            raise ValueError("userid 和 db_type 不能包含下划线")
        if (db_type and len(db_type) > 100) or len(userid) > 100:
            raise ValueError("userid 或 db_type 的长度超出限制")
        # userid and the knowledge-base ids are interpolated into single-quoted
        # Milvus filter literals below; a quote or backslash would let a value
        # break out of the literal and widen the filter (e.g. to other users).
        if "'" in userid or "\\" in userid:
            raise ValueError(f"userid 包含非法字符: {userid}")
        if limit <= 0 or limit > 16384:
            raise ValueError("limit 必须在 1 到 16384 之间")
        if offset < 0:
            raise ValueError("offset 不能为负数")
        if limit + offset > 16384:
            raise ValueError("limit + offset 不能超过 16384")
        if not knowledge_base_ids:
            raise ValueError("knowledge_base_ids 不能为空")
        for kb_id in knowledge_base_ids:
            if not isinstance(kb_id, str):
                raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}")
            if len(kb_id) > 100:
                raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}")
            if "_" in kb_id:
                raise ValueError(f"knowledge_base_id 不能包含下划线: {kb_id}")
            if "'" in kb_id or "\\" in kb_id:
                raise ValueError(f"knowledge_base_id 包含非法字符: {kb_id}")
        if not utility.has_collection(collection_name):
            debug(f"集合 {collection_name} 不存在")
            return []
        try:
            collection = Collection(collection_name)
            collection.load()
            debug(f"加载集合: {collection_name}")
        except Exception as e:
            error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
            raise RuntimeError(f"加载集合失败: {str(e)}")
        embeddings = await self._get_embeddings([query])
        query_vector = embeddings[0]
        debug(f"查询向量维度: {len(query_vector)}")

        search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
        kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
        expr = f"userid == '{userid}' and ({kb_id_expr})"
        debug(f"搜索表达式: {expr}")
        # Fetch at least 100 candidates for de-duplication/reranking, but never
        # fewer than `limit` (a fixed 100 silently capped larger limits), and
        # stay inside the limit + offset <= 16384 window checked above.
        search_limit = min(max(limit, 100), 16384 - offset)
        try:
            results = collection.search(
                data=[query_vector],
                anns_field="vector",
                param=search_params,
                limit=search_limit,
                expr=expr,
                output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
                               "file_type"],
                offset=offset
            )
        except Exception as e:
            error(f"搜索失败: {str(e)}\n{exception()}")
            raise RuntimeError(f"搜索失败: {str(e)}")

        search_results = []
        for hits in results:
            for hit in hits:
                metadata = {
                    "userid": hit.entity.get("userid"),
                    "document_id": hit.entity.get("document_id"),
                    "filename": hit.entity.get("filename"),
                    "file_path": hit.entity.get("file_path"),
                    "upload_time": hit.entity.get("upload_time"),
                    "file_type": hit.entity.get("file_type")
                }
                result = {
                    "text": hit.entity.get("text"),
                    "distance": hit.distance,
                    "source": "vector_query",
                    "metadata": metadata
                }
                search_results.append(result)
                debug(
                    f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}")

        # With the COSINE metric a larger distance means more similar, so sort
        # descending and keep the best-scoring hit for each distinct text.
        unique_results = []
        seen_texts = set()
        for result in sorted(search_results, key=lambda x: x['distance'], reverse=True):
            if result['text'] not in seen_texts:
                unique_results.append(result)
                seen_texts.add(result['text'])
        info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})")
        if use_rerank and unique_results:
            debug("开始重排序")
            unique_results = await self._rerank_results(query, unique_results, limit)
            unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
            debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
        else:
            unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results]
        info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果")
        return unique_results[:limit]
    except Exception as e:
        error(f"纯向量搜索失败: {str(e)}\n{exception()}")
        return []
async def list_user_files(self, userid: str) -> List[Dict]:
    """List the distinct files a user has uploaded, across all RAG collections.

    Scans the ``ragdb`` collection and every ``ragdb_<db_type>`` collection.
    For each one it reads the chunk rows belonging to ``userid`` and keeps one
    entry per ``(filename, file_path)`` pair; the first collection in which a
    pair is seen supplies its ``db_type``.

    Args:
        userid: Owner id; non-empty, at most 100 characters, and no
            underscores, single quotes or backslashes.

    Returns:
        Dicts with ``filename``, ``file_path``, ``db_type``, ``upload_time`` and
        ``file_type``, sorted by ``upload_time`` descending. Entries whose
        ``upload_time`` is missing are placed last. Collections that fail to
        load or query are skipped; any other error is logged and yields ``[]``.
    """
    try:
        info(f"开始查询用户文件列表: userid={userid}")
        if not userid:
            raise ValueError("userid 不能为空")
        if "_" in userid:
            raise ValueError("userid 不能包含下划线")
        if len(userid) > 100:
            raise ValueError("userid 长度超出限制")
        # userid is interpolated into a single-quoted Milvus filter literal; a
        # quote or backslash could break out of it and match other users' rows.
        if "'" in userid or "\\" in userid:
            raise ValueError("userid 包含非法字符")
        prefix = "ragdb_"
        # RAG collections are named "ragdb" or "ragdb_<db_type>" (the naming the
        # search methods use). A bare startswith("ragdb") also matched
        # unrelated names such as "ragdbx".
        collections = [c for c in utility.list_collections() if c == "ragdb" or c.startswith(prefix)]
        if not collections:
            debug("未找到任何 ragdb 开头的集合")
            return []
        debug(f"找到集合: {collections}")
        file_list = []
        seen_files = set()
        for collection_name in collections:
            db_type = collection_name[len(prefix):] if collection_name != "ragdb" else ""
            debug(f"处理集合: {collection_name}, db_type={db_type}")
            try:
                collection = Collection(collection_name)
                collection.load()
                debug(f"加载集合: {collection_name}")
            except Exception as e:
                error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}")
                continue
            try:
                # Rows are text chunks, not files. The previous cap of 1000
                # rows silently dropped files for users with many chunks;
                # 16384 is Milvus' maximum result window for one query.
                # NOTE(review): users with more chunks than that in a single
                # collection can still miss files. Consider paginating
                # (e.g. query_iterator) if that happens.
                results = collection.query(
                    expr=f"userid == '{userid}'",
                    output_fields=["filename", "file_path", "upload_time", "file_type"],
                    limit=16384
                )
                debug(f"集合 {collection_name} 查询到 {len(results)} 个文本块")
            except Exception as e:
                error(f"查询集合 {collection_name} 失败: userid={userid}, 错误: {str(e)}\n{exception()}")
                continue
            for result in results:
                filename = result.get("filename")
                file_path = result.get("file_path")
                upload_time = result.get("upload_time")
                file_type = result.get("file_type")
                if (filename, file_path) not in seen_files:
                    seen_files.add((filename, file_path))
                    file_list.append({
                        "filename": filename,
                        "file_path": file_path,
                        "db_type": db_type,
                        "upload_time": upload_time,
                        "file_type": file_type
                    })
                    debug(
                        f"文件: filename={filename}, file_path={file_path}, db_type={db_type}, upload_time={upload_time}, file_type={file_type}")
        info(f"返回 {len(file_list)} 个文件")
        # Comparing None with a real upload_time raises TypeError, which would
        # discard the whole list via the except below. Sort on
        # (has_time, time) so missing times go last; when two times are both
        # None, the tuples compare equal without ever evaluating None < None.
        return sorted(
            file_list,
            key=lambda f: (f["upload_time"] is not None, f["upload_time"]),
            reverse=True,
        )
    except Exception as e:
        error(f"查询用户文件列表失败: userid={userid}, 错误: {str(e)}\n{exception()}")
        return []
# Module-level side effect, executed when this module is first imported:
# hand MilvusConnection to llmengine.base_connection.connection_register under
# the key 'Milvus'. Presumably this lets callers obtain the connection by that
# name; see base_connection for the lookup semantics.
connection_register('Milvus', MilvusConnection)
info("MilvusConnection registered")