From 7b6a0aaa62726167494b081cc89838d0fe1b561d Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Thu, 3 Jul 2025 18:34:01 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0/=E5=88=A0=E9=99=A4=E9=9B=86?= =?UTF-8?q?=E5=90=88=EF=BC=8C=E6=8F=92=E5=85=A5/=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llmengine/connection.py | 232 +++++---- llmengine/milvus_connection.py | 838 +++++++++++++++++++-------------- 2 files changed, 633 insertions(+), 437 deletions(-) diff --git a/llmengine/connection.py b/llmengine/connection.py index 5c52c0c..2ead782 100644 --- a/llmengine/connection.py +++ b/llmengine/connection.py @@ -1,18 +1,15 @@ import milvus_connection from traceback import format_exc import argparse -import logging from aiohttp import web from llmengine.base_connection import get_connection_class from appPublic.registerfunction import RegisterFunction -from appPublic.log import debug, exception +from appPublic.log import debug, error, info, exception from ahserver.serverenv import ServerEnv from ahserver.webapp import webserver import os import json -logger = logging.getLogger(__name__) - helptext = """Milvus Connection Service API (using pymilvus Collection API): 1. Create Collection Endpoint: @@ -20,48 +17,54 @@ path: /v1/createcollection method: POST headers: {"Content-Type": "application/json"} data: { - "db_type": "textdb" + "db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb } response: -- Success: HTTP 200, {"status": "success", "collection_name": "ragdb_textdb", "message": "集合 ragdb_textdb 创建成功"} -- Error: HTTP 400, {"status": "error", "collection_name": "ragdb_textdb", "message": ""} +- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 创建成功"} +- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": ""} 2. Delete Collection Endpoint: path: /v1/deletecollection method: POST headers: {"Content-Type": "application/json"} data: { - "db_type": "textdb" + "db_type": "textdb" // 可选,若不提供则删除默认集合 ragdb } response: -- Success: HTTP 200, {"status": "success", "collection_name": "ragdb_textdb", "message": "集合 ragdb_textdb 删除成功"} -- Error: HTTP 400, {"status": "error", "collection_name": "ragdb_textdb", "message": ""} +- Success: HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 删除成功"} +- Success (collection does not exist): HTTP 200, {"status": "success", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 ragdb 或 ragdb_textdb 不存在,无需删除"} +- Error: HTTP 400, {"status": "error", "collection_name": "ragdb" or "ragdb_textdb", "message": ""} 3. Insert File Endpoint: path: /v1/insertfile method: POST headers: {"Content-Type": "application/json"} data: { - "file_path": "/path/to/file.txt", - "userid": "user1", - "db_type": "textdb" + "file_path": "/path/to/file.txt", // 必填,文件路径 + "userid": "user123", // 必填,用户 ID + "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb + "knowledge_base_id": "kb123" // 必填,知识库 ID } response: -- Success: HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb_textdb", "message": "文件 /path/to/file.txt 成功嵌入并处理三元组"} -- Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb_textdb", "message": ""} +- Success: HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "文件 成功嵌入并处理三元组", "status_code": 200} +- Success (triples failed): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "文件 成功嵌入,但三元组处理失败: ", "status_code": 200} +- Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "", "status_code": 400} -4. Delete File Endpoint: +4. Delete Document Endpoint: path: /v1/deletefile method: POST headers: {"Content-Type": "application/json"} data: { - "db_type": "textdb", - "userid": "user1", - "filename": "test.txt" + "userid": "user123", // 必填,用户 ID + "filename": "file.txt", // 必填,文件名 + "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb + "knowledge_base_id": "kb123" // 必填,知识库 ID } response: -- Success: HTTP 200, {"status": "success", "collection_name": "ragdb_textdb", "message": "成功删除 X 条记录,userid=user1, filename=test.txt"} -- Error: HTTP 400, {"status": "error", "collection_name": "ragdb_textdb", "message": ""} +- Success: HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 条 Milvus 记录, 个 Neo4j 节点, 个 Neo4j 关系,userid=, filename=", "status_code": 200} +- Success (no records): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=, filename=, knowledge_base_id= 的记录,无需删除", "status_code": 200} +- Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 不存在,无需删除", "status_code": 200} +- Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "", "status_code": 400} 5. Fused Search Query Endpoint: path: /v1/fusedsearchquery @@ -70,7 +73,7 @@ headers: {"Content-Type": "application/json"} data: { "query": "苹果公司在北京开设新店", "userid": "user1", - "db_type": "textdb", + "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb "file_paths": ["/path/to/file.txt"], "limit": 5, "offset": 0, @@ -103,7 +106,7 @@ headers: {"Content-Type": "application/json"} data: { "query": "知识图谱的知识融合是什么?", "userid": "user1", - "db_type": "textdb", + "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb "file_paths": ["/path/to/file.txt"], "limit": 5, "offset": 0, @@ -165,6 +168,21 @@ response: path: /v1/docs method: GET response: This help text + +10.Delete Knowledge Base Endpoint: +path: /v1/deleteknowledgebase +method: POST +headers: {"Content-Type": "application/json"} +data: { + "userid": "user123", // 必填,用户 ID + "knowledge_base_id": "kb123",// 必填,知识库 ID + "db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb +} +response: +- Success: HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 条 Milvus 记录, 个 Neo4j 节点, 个 Neo4j 关系,userid=, knowledge_base_id=", "status_code": 200} +- Success (no records): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=, knowledge_base_id= 的记录,无需删除", "status_code": 200} +- Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 不存在,无需删除", "status_code": 200} +- Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "", "status_code": 400} """ def init(): @@ -173,6 +191,7 @@ def init(): rf.register('deletecollection', delete_collection) rf.register('insertfile', insert_file) rf.register('deletefile', delete_file) + rf.register('deleteknowledgebase', delete_knowledge_base) rf.register('fusedsearchquery', fused_search_query) rf.register('searchquery', search_query) rf.register('listuserfiles', list_user_files) @@ -192,22 +211,17 @@ async def create_collection(request, params_kw, *params, **kw): debug(f'{params_kw=}') se = ServerEnv() engine = se.engine - db_type = params_kw.get('db_type') - if db_type is None: - debug(f'db_type 未提供') - return web.json_response({ - "status": "error", - "message": "db_type 参数未提供" - }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + db_type = params_kw.get('db_type', '') + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: result = await engine.handle_connection("create_collection", {"db_type": db_type}) debug(f'{result=}') return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: - debug(f'创建集合失败: {str(e)}') + error(f'创建集合失败: {str(e)}') return web.json_response({ "status": "error", - "collection_name": f"ragdb_{db_type}", + "collection_name": collection_name, "message": str(e) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) @@ -215,82 +229,128 @@ async def delete_collection(request, params_kw, *params, **kw): debug(f'{params_kw=}') se = ServerEnv() engine = se.engine - db_type = params_kw.get('db_type') - if db_type is None: - debug(f'db_type 未提供') - return web.json_response({ - "status": "error", - "message": "db_type 参数未提供" - }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + db_type = params_kw.get('db_type', '') + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: result = await engine.handle_connection("delete_collection", {"db_type": db_type}) debug(f'{result=}') return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: - debug(f'删除集合失败: {str(e)}') + error(f'删除集合失败: {str(e)}') return web.json_response({ "status": "error", - "collection_name": f"ragdb_{db_type}", + "collection_name": collection_name, "message": str(e) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + async def insert_file(request, params_kw, *params, **kw): - debug(f'{params_kw=}') + debug(f'Received params: {params_kw=}') se = ServerEnv() engine = se.engine - file_path = params_kw.get('file_path') - userid = params_kw.get('userid') - db_type = params_kw.get('db_type') - if not all([file_path, userid, db_type]): - debug(f'file_path, userid 或 db_type 未提供') - return web.json_response({ - "status": "error", - "message": "file_path, userid 或 db_type 未提供" - }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + file_path = params_kw.get('file_path', '') + userid = params_kw.get('userid', '') + db_type = params_kw.get('db_type', '') + knowledge_base_id = params_kw.get('knowledge_base_id', '') + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: + # 仅检查必填字段是否存在 + required_fields = ['file_path', 'userid', 'knowledge_base_id'] + missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] + if missing_fields: + raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") + + debug( + f'Calling insert_document with: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}') result = await engine.handle_connection("insert_document", { "file_path": file_path, "userid": userid, - "db_type": db_type + "db_type": db_type, + "knowledge_base_id": knowledge_base_id }) - debug(f'{result=}') - return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) + debug(f'Insert result: {result=}') + # 根据 result 的 status 设置 HTTP 状态码 + status = 200 if result.get("status") == "success" else 400 + return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status) except Exception as e: - debug(f'插入文件失败: {str(e)}') + error(f'插入文件失败: {str(e)}') return web.json_response({ "status": "error", + "collection_name": collection_name, "document_id": "", - "collection_name": f"ragdb_{db_type}", "message": str(e) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + async def delete_file(request, params_kw, *params, **kw): - debug(f'{params_kw=}') + debug(f'Received delete_file params: {params_kw=}') se = ServerEnv() engine = se.engine - db_type = params_kw.get('db_type') - userid = params_kw.get('userid') - filename = params_kw.get('filename') - if not all([db_type, userid, filename]): - debug(f'db_type, userid 或 filename 未提供') - return web.json_response({ - "status": "error", - "message": "db_type, userid 或 filename 未提供" - }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + userid = params_kw.get('userid', '') + filename = params_kw.get('filename', '') + db_type = params_kw.get('db_type', '') + knowledge_base_id = params_kw.get('knowledge_base_id', '') + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: + required_fields = ['userid', 'filename', 'knowledge_base_id'] + missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] + if missing_fields: + raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") + + debug( + f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}') result = await engine.handle_connection("delete_document", { - "db_type": db_type, "userid": userid, - "filename": filename + "filename": filename, + "db_type": db_type, + "knowledge_base_id": knowledge_base_id }) - debug(f'{result=}') - return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) + debug(f'Delete result: {result=}') + status = 200 if result.get("status") == "success" else 400 + return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status) except Exception as e: - debug(f'删除文件失败: {str(e)}') + error(f'删除文件失败: {str(e)}') return web.json_response({ "status": "error", - "collection_name": f"ragdb_{db_type}", - "message": str(e) + "collection_name": collection_name, + "document_id": "", + "message": str(e), + "status_code": 400 + }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + + +async def delete_knowledge_base(request, params_kw, *params, **kw): + debug(f'Received delete_knowledge_base params: {params_kw=}') + se = ServerEnv() + engine = se.engine + userid = params_kw.get('userid', '') + knowledge_base_id = params_kw.get('knowledge_base_id', '') + db_type = params_kw.get('db_type', '') + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + try: + required_fields = ['userid', 'knowledge_base_id'] + missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] + if missing_fields: + raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") + + debug( + f'Calling delete_knowledge_base with: userid={userid}, knowledge_base_id={knowledge_base_id}, db_type={db_type}') + result = await engine.handle_connection("delete_knowledge_base", { + "userid": userid, + "knowledge_base_id": knowledge_base_id, + "db_type": db_type + }) + debug(f'Delete knowledge base result: {result=}') + status = 200 if result.get("status") == "success" else 400 + return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status) + except Exception as e: + error(f'删除知识库失败: {str(e)}') + return web.json_response({ + "status": "error", + "collection_name": collection_name, + "document_id": "", + "message": str(e), + "status_code": 400 }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) async def fused_search_query(request, params_kw, *params, **kw): @@ -299,16 +359,16 @@ async def fused_search_query(request, params_kw, *params, **kw): engine = se.engine query = params_kw.get('query') userid = params_kw.get('userid') - db_type = params_kw.get('db_type') + db_type = params_kw.get('db_type', '') file_paths = params_kw.get('file_paths') limit = params_kw.get('limit', 5) offset = params_kw.get('offset', 0) use_rerank = params_kw.get('use_rerank', True) - if not all([query, userid, db_type, file_paths]): - debug(f'query, userid, db_type 或 file_paths 未提供') + if not all([query, userid, file_paths]): + debug(f'query, userid 或 file_paths 未提供') return web.json_response({ "status": "error", - "message": "query, userid, db_type 或 file_paths 未提供" + "message": "query, userid 或 file_paths 未提供" }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) try: result = await engine.handle_connection("fused_search", { @@ -323,7 +383,7 @@ async def fused_search_query(request, params_kw, *params, **kw): debug(f'{result=}') return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: - debug(f'融合搜索失败: {str(e)}') + error(f'融合搜索失败: {str(e)}') return web.json_response({ "status": "error", "message": str(e) @@ -335,16 +395,16 @@ async def search_query(request, params_kw, *params, **kw): engine = se.engine query = params_kw.get('query') userid = params_kw.get('userid') - db_type = params_kw.get('db_type') + db_type = params_kw.get('db_type', '') file_paths = params_kw.get('file_paths') limit = params_kw.get('limit', 5) offset = params_kw.get('offset', 0) use_rerank = params_kw.get('use_rerank', True) - if not all([query, userid, db_type, file_paths]): - debug(f'query, userid, db_type 或 file_paths 未提供') + if not all([query, userid, file_paths]): + debug(f'query, userid 或 file_paths 未提供') return web.json_response({ "status": "error", - "message": "query, userid, db_type 或 file_paths 未提供" + "message": "query, userid 或 file_paths 未提供" }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) try: result = await engine.handle_connection("search_query", { @@ -359,7 +419,7 @@ async def search_query(request, params_kw, *params, **kw): debug(f'{result=}') return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: - debug(f'纯向量搜索失败: {str(e)}') + error(f'纯向量搜索失败: {str(e)}') return web.json_response({ "status": "error", "message": str(e) @@ -383,7 +443,7 @@ async def list_user_files(request, params_kw, *params, **kw): debug(f'{result=}') return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: - debug(f'查询用户文件列表失败: {str(e)}') + error(f'查询用户文件列表失败: {str(e)}') return web.json_response({ "status": "error", "message": str(e) @@ -406,7 +466,7 @@ async def handle_connection(request, params_kw, *params, **kw): debug(f'{result=}') return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False)) except Exception as e: - debug(f'处理连接操作失败: {str(e)}') + error(f'处理连接操作失败: {str(e)}') return web.json_response({ "status": "error", "message": str(e) @@ -418,7 +478,7 @@ def main(): parser.add_argument('-p', '--port', default='8888') parser.add_argument('connection_path') args = parser.parse_args() - logger.debug(f"Arguments: {args}") + debug(f"Arguments: {args}") Klass = get_connection_class(args.connection_path) se = ServerEnv() se.engine = Klass() diff --git a/llmengine/milvus_connection.py b/llmengine/milvus_connection.py index a396684..6ae04cb 100644 --- a/llmengine/milvus_connection.py +++ b/llmengine/milvus_connection.py @@ -1,10 +1,11 @@ +from appPublic.jsonConfig import getConfig import os -import logging +from appPublic.log import debug, error, info, exception import yaml from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType from threading import Lock from llmengine.base_connection import connection_register -from typing import Dict, List +from typing import Dict, List, Any import aiohttp from langchain_core.documents import Document from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -16,20 +17,6 @@ import numpy as np from py2neo import Graph from scipy.spatial.distance import cosine -logger = logging.getLogger(__name__) - -CONFIG_PATH = os.getenv('CONFIG_PATH', '/share/wangmeihua/rag/conf/milvusconfig.yaml') -try: - with open(CONFIG_PATH, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - MILVUS_DB_PATH = config['database']['milvus_db_path'] - NEO4J_URI = "bolt://10.18.34.18:7687" - NEO4J_USER = "neo4j" - NEO4J_PASSWORD = "261229..wmh" -except Exception as e: - logger.error(f"加载配置文件 {CONFIG_PATH} 失败: {str(e)}") - raise RuntimeError(f"无法加载配置文件: {str(e)}") - # 嵌入缓存 EMBED_CACHE = {} @@ -47,13 +34,18 @@ class MilvusConnection: def __init__(self): if self._initialized: return - self.db_path = MILVUS_DB_PATH - self.neo4j_uri = NEO4J_URI - self.neo4j_user = NEO4J_USER - self.neo4j_password = NEO4J_PASSWORD + try: + config = getConfig() + self.db_path = config['milvus_db'] + self.neo4j_uri = config['neo4j']['uri'] + self.neo4j_user = config['neo4j']['user'] + self.neo4j_password = config['neo4j']['password'] + except KeyError as e: + error(f"配置文件缺少必要字段: {str(e)}\n{exception()}") + raise RuntimeError(f"配置文件缺少必要字段: {str(e)}") self._initialize_connection() self._initialized = True - logger.info(f"MilvusConnection initialized with db_path: {self.db_path}") + info(f"MilvusConnection initialized with db_path: {self.db_path}") def _initialize_connection(self): """初始化 Milvus 连接,确保单一连接""" @@ -61,103 +53,151 @@ class MilvusConnection: db_dir = os.path.dirname(self.db_path) if not os.path.exists(db_dir): os.makedirs(db_dir, exist_ok=True) - logger.debug(f"创建 Milvus 目录: {db_dir}") + debug(f"创建 Milvus 目录: {db_dir}") if not os.access(db_dir, os.W_OK): raise RuntimeError(f"Milvus 目录 {db_dir} 不可写") if not connections.has_connection("default"): connections.connect("default", uri=self.db_path) - logger.debug(f"已连接到 Milvus Lite,路径: {self.db_path}") + debug(f"已连接到 Milvus Lite,路径: {self.db_path}") else: - logger.debug("已存在 Milvus 连接,跳过重复连接") + debug("已存在 Milvus 连接,跳过重复连接") except Exception as e: - logger.error(f"连接 Milvus 失败: {str(e)}") + error(f"连接 Milvus 失败: {str(e)}\n{exception()}") raise RuntimeError(f"连接 Milvus 失败: {str(e)}") async def handle_connection(self, action: str, params: Dict = None) -> Dict: """处理数据库操作""" try: + debug(f"处理操作: action={action}, params={params}") + if not params: + params = {} + # 通用 db_type 验证 + db_type = params.get("db_type", "") + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + if db_type and "_" in db_type: + return {"status": "error", "message": "db_type 不能包含下划线", "collection_name": collection_name, + "document_id": "", "status_code": 400} + if db_type and len(db_type) > 100: + return {"status": "error", "message": "db_type 的长度应小于 100", "collection_name": collection_name, + "document_id": "", "status_code": 400} + if action == "initialize": - if not connections.has_connection("default"): - self._initialize_connection() return {"status": "success", "message": f"Milvus 连接已初始化,路径: {self.db_path}"} elif action == "get_params": return {"status": "success", "params": {"uri": self.db_path}} elif action == "create_collection": - if not params or "db_type" not in params: - return {"status": "error", "message": "缺少 db_type 参数"} - return self._create_collection(params["db_type"]) + return await self._create_collection(db_type) elif action == "delete_collection": - if not params or "db_type" not in params: - return {"status": "error", "message": "缺少 db_type 参数"} - return self._delete_collection(params["db_type"]) + return await self._delete_collection(db_type) elif action == "insert_document": - if not params or "file_path" not in params or "userid" not in params or "db_type" not in params: - return {"status": "error", "message": "缺少 file_path, userid 或 db_type 参数"} - return await self._insert_document( - params["file_path"], - params["userid"], - params["db_type"] - ) + file_path = params.get("file_path", "") + userid = params.get("userid", "") + knowledge_base_id = params.get("knowledge_base_id", "") + if not file_path or not userid or not knowledge_base_id: + return {"status": "error", "message": "file_path、userid 和 knowledge_base_id 不能为空", + "collection_name": collection_name, "document_id": "", "status_code": 400} + if "_" in userid or "_" in knowledge_base_id: + return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线", + "collection_name": collection_name, "document_id": "", "status_code": 400} + if len(knowledge_base_id) > 100: + return {"status": "error", "message": "knowledge_base_id 的长度应小于 100", + "collection_name": collection_name, "document_id": "", "status_code": 400} + return await self._insert_document(file_path, userid, knowledge_base_id, db_type) elif action == "delete_document": - if not params or "db_type" not in params or "userid" not in params or "filename" not in params: - return {"status": "error", "message": "缺少 db_type, userid 或 filename 参数"} - return self._delete_document( - params["db_type"], - params["userid"], - params["filename"] - ) + userid = params.get("userid", "") + filename = params.get("filename", "") + knowledge_base_id = params.get("knowledge_base_id", "") + if not userid or not filename or not knowledge_base_id: + return {"status": "error", "message": "userid、filename 和 knowledge_base_id 不能为空", + "collection_name": collection_name, "document_id": "", "status_code": 400} + if "_" in userid or "_" in knowledge_base_id: + return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线", + "collection_name": collection_name, "document_id": "", "status_code": 400} + if len(userid) > 100 or len(filename) > 255 or len(knowledge_base_id) > 100: + return {"status": "error", "message": "userid、filename 或 knowledge_base_id 的长度超出限制", + "collection_name": collection_name, "document_id": "", "status_code": 400} + return await self._delete_document(db_type, userid, filename, knowledge_base_id) + elif action == "delete_knowledge_base": + userid = params.get("userid", "") + knowledge_base_id = params.get("knowledge_base_id", "") + if not userid or not knowledge_base_id: + return {"status": "error", "message": "userid 和 knowledge_base_id 不能为空", + "collection_name": collection_name, "document_id": "", "status_code": 400} + if "_" in userid or "_" in knowledge_base_id: + return {"status": "error", "message": "userid 和 knowledge_base_id 不能包含下划线", + "collection_name": collection_name, "document_id": "", "status_code": 400} + if len(userid) > 100 or len(knowledge_base_id) > 100: + return {"status": "error", "message": "userid 或 knowledge_base_id 的长度超出限制", + "collection_name": collection_name, "document_id": "", "status_code": 400} + return await self._delete_knowledge_base(db_type, userid, knowledge_base_id) elif action == "fused_search": - if not params or "query" not in params or "userid" not in params or "db_type" not in params or "file_paths" not in params: - return {"status": "error", "message": "缺少 query, userid, db_type 或 file_paths 参数"} + query = params.get("query", "") + userid = params.get("userid", "") + file_paths = params.get("file_paths", []) + if not query or not userid or not file_paths: + return {"status": "error", "message": "query、userid 或 file_paths 不能为空", + "collection_name": collection_name, "document_id": "", "status_code": 400} return await self._fused_search( - params["query"], - params["userid"], - params["db_type"], - params["file_paths"], + query, + userid, + db_type, + file_paths, params.get("limit", 5), params.get("offset", 0), params.get("use_rerank", True) ) elif action == "search_query": - if not params or "query" not in params or "userid" not in params or "db_type" not in params or "file_paths" not in params: - return {"status": "error", "message": "缺少 query, userid, db_type 或 file_paths 参数"} + query = params.get("query", "") + userid = params.get("userid", "") + file_paths = params.get("file_paths", []) + if not query or not userid or not file_paths: + return {"status": "error", "message": "query、userid 或 file_paths 不能为空", + "collection_name": collection_name, "document_id": "", "status_code": 400} return await self._search_query( - params["query"], - params["userid"], - params["db_type"], - params["file_paths"], + query, + userid, + db_type, + file_paths, params.get("limit", 5), params.get("offset", 0), params.get("use_rerank", True) ) elif action == "list_user_files": - if not params or "userid" not in params: - return {"status": "error", "message": "缺少 userid 参数"} - return await self.list_user_files(params["userid"]) + userid = params.get("userid", "") + if not userid: + return {"status": "error", "message": "userid 不能为空", "collection_name": collection_name, + "document_id": "", "status_code": 400} + return await self.list_user_files(userid) else: - return {"status": "error", "message": f"未知的 action: {action}"} + return {"status": "error", "message": f"未知的 action: {action}", "collection_name": collection_name, + "document_id": "", "status_code": 400} except Exception as e: - logger.error(f"处理操作失败: action={action}, 错误: {str(e)}") - return {"status": "error", "message": str(e)} + error(f"处理操作失败: action={action}, 错误: {str(e)}") + return { + "status": "error", + "message": f"服务器错误: {str(e)}", + "collection_name": params.get("db_type", "ragdb") if params else "ragdb", + "document_id": "", + "status_code": 400 + } - def _create_collection(self, db_type: str) -> Dict: + async def _create_collection(self, db_type: str = "") -> Dict: """创建 Milvus 集合""" try: - if not db_type: - raise ValueError("db_type 不能为空") - if "_" in db_type: - raise ValueError("db_type 不能包含下划线") - if len(db_type) > 100: - raise ValueError("db_type 的长度应小于 100") - - collection_name = f"ragdb_{db_type}" + # 根据 db_type 决定集合名称 + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" if len(collection_name) > 255: raise ValueError(f"集合名称 {collection_name} 超过 255 个字符") - logger.debug(f"集合名称: {collection_name}") + if db_type and "_" in db_type: + raise ValueError("db_type 不能包含下划线") + if db_type and len(db_type) > 100: + raise ValueError("db_type 的长度应小于 100") + debug(f"集合名称: {collection_name}") fields = [ FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, max_length=36, auto_id=True), FieldSchema(name="userid", dtype=DataType.VARCHAR, max_length=100), + FieldSchema(name="knowledge_base_id", dtype=DataType.VARCHAR, max_length=100), FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=36), FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024), @@ -168,7 +208,7 @@ class MilvusConnection: ] schema = CollectionSchema( fields=fields, - description=f"{db_type} 数据集合,跨用户使用,包含 document_id 和元数据字段", + description="统一数据集合,包含用户ID、知识库ID、document_id 和元数据字段", auto_id=True, primary_field="pk", ) @@ -185,26 +225,26 @@ class MilvusConnection: if expected_fields == actual_fields and vector_field is not None and vector_field.dtype == DataType.FLOAT_VECTOR: dim = vector_field.params.get('dim', None) if hasattr(vector_field, 'params') and vector_field.params else None schema_compatible = dim == 1024 - logger.debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, " - f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else '无'}, " - f"dim={dim if dim is not None else '未定义'}") + debug(f"检查集合 {collection_name} 的 schema: 字段匹配={expected_fields == actual_fields}, " + f"vector_field存在={vector_field is not None}, dtype={vector_field.dtype if vector_field else '无'}, " + f"dim={dim if dim is not None else '未定义'}") if not schema_compatible: - logger.warning(f"集合 {collection_name} 的 schema 不兼容,原因: " - f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or '无'}, " - f"vector_field: {vector_field is not None}, " - f"dtype: {vector_field.dtype if vector_field else '无'}, " - f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}") + debug(f"集合 {collection_name} 的 schema 不兼容,原因: " + f"字段不匹配: {expected_fields.symmetric_difference(actual_fields) or '无'}, " + f"vector_field: {vector_field is not None}, " + f"dtype: {vector_field.dtype if vector_field else '无'}, " + f"dim: {vector_field.params.get('dim', '未定义') if vector_field and hasattr(vector_field, 'params') and vector_field.params else '未定义'}") utility.drop_collection(collection_name) else: collection.load() - logger.debug(f"集合 {collection_name} 已存在并加载成功") + debug(f"集合 {collection_name} 已存在并加载成功") return { "status": "success", "collection_name": collection_name, "message": f"集合 {collection_name} 已存在" } except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") + error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") return { "status": "error", "collection_name": collection_name, @@ -217,50 +257,48 @@ class MilvusConnection: field_name="vector", index_params={"index_type": "AUTOINDEX", "metric_type": "COSINE"} ) - for field in ["userid", "document_id", "filename", "file_path", "upload_time", "file_type"]: + for field in ["userid", "knowledge_base_id", "document_id", "filename", "file_path", "upload_time", "file_type"]: collection.create_index( field_name=field, index_params={"index_type": "INVERTED"} ) collection.load() - logger.debug(f"成功创建并加载集合: {collection_name}") + debug(f"成功创建并加载集合: {collection_name}") return { "status": "success", "collection_name": collection_name, "message": f"集合 {collection_name} 创建成功" } except Exception as e: - logger.error(f"创建集合 {collection_name} 失败: {str(e)}") + error(f"创建集合 {collection_name} 失败: {str(e)}\n{exception()}") return { "status": "error", "collection_name": collection_name, "message": str(e) } except Exception as e: - logger.error(f"创建集合失败: {str(e)}") + error(f"创建集合失败: {str(e)}") return { "status": "error", "collection_name": collection_name, "message": str(e) } - def _delete_collection(self, db_type: str) -> Dict: + async def _delete_collection(self, db_type: str = "") -> Dict: """删除 Milvus 集合""" try: - if not db_type: - raise ValueError("db_type 不能为空") - if "_" in db_type: - raise ValueError("db_type 不能包含下划线") - if len(db_type) > 100: - raise ValueError("db_type 的长度应小于 100") - - collection_name = f"ragdb_{db_type}" + # 根据 db_type 决定集合名称 + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" if len(collection_name) > 255: raise ValueError(f"集合名称 {collection_name} 超过 255 个字符") - logger.debug(f"集合名称: {collection_name}") + if db_type and "_" in db_type: + raise ValueError("db_type 不能包含下划线") + if db_type and len(db_type) > 100: + raise ValueError("db_type 的长度应小于 100") + debug(f"集合名称: {collection_name}") if not utility.has_collection(collection_name): - logger.debug(f"集合 {collection_name} 不存在") + debug(f"集合 {collection_name} 不存在") return { "status": "success", "collection_name": collection_name, @@ -269,49 +307,46 @@ class MilvusConnection: try: utility.drop_collection(collection_name) - logger.debug(f"成功删除集合: {collection_name}") + debug(f"成功删除集合: {collection_name}") return { "status": "success", "collection_name": collection_name, "message": f"集合 {collection_name} 删除成功" } except Exception as e: - logger.error(f"删除集合 {collection_name} 失败: {str(e)}") + error(f"删除集合 {collection_name} 失败: {str(e)}") return { "status": "error", "collection_name": collection_name, "message": str(e) } except Exception as e: - logger.error(f"删除集合失败: {str(e)}") + error(f"删除集合失败: {str(e)}") return { "status": "error", "collection_name": collection_name, "message": str(e) } - async def _insert_document(self, file_path: str, userid: str, db_type: str) -> Dict: + async def _insert_document(self, file_path: str, userid: str, knowledge_base_id: str, db_type: str = "") -> Dict[ + str, Any]: """将文档插入 Milvus 并抽取三元组到 Neo4j""" document_id = str(uuid.uuid4()) - collection_name = f"ragdb_{db_type}" + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + debug( + f'Inserting document: file_path={file_path}, userid={userid}, db_type={db_type}, knowledge_base_id={knowledge_base_id}, document_id={document_id}') try: - if not userid or not db_type: - raise ValueError("userid 和 db_type 不能为空") - if "_" in userid or "_" in db_type: - raise ValueError("userid 和 db_type 不能包含下划线") if not os.path.exists(file_path): raise ValueError(f"文件 {file_path} 不存在") - if len(db_type) > 100: - raise ValueError("db_type 的长度应小于 100") supported_formats = {'pdf', 'doc', 'docx', 'xlsx', 'xls', 'ppt', 'pptx', 'csv', 'txt'} ext = file_path.rsplit('.', 1)[1].lower() if '.' in file_path else '' if ext not in supported_formats: raise ValueError(f"不支持的文件格式: {ext}, 支持的格式: {', '.join(supported_formats)}") - logger.info(f"生成 document_id: {document_id} for file: {file_path}") + info(f"生成 document_id: {document_id} for file: {file_path}") - logger.debug(f"加载文件: {file_path}") + debug(f"加载文件: {file_path}") text = fileloader(file_path) if not text or not text.strip(): raise ValueError(f"文件 {file_path} 加载为空") @@ -322,11 +357,11 @@ class MilvusConnection: chunk_overlap=200, length_function=len, ) - logger.debug("开始分片文件内容") + debug("开始分片文件内容") chunks = text_splitter.split_documents([document]) if not chunks: raise ValueError(f"文件 {file_path} 未生成任何文档块") - logger.debug(f"文件 {file_path} 分割为 {len(chunks)} 个文档块") + debug(f"文件 {file_path} 分割为 {len(chunks)} 个文档块") filename = os.path.basename(file_path).rsplit('.', 1)[0] upload_time = datetime.now().isoformat() @@ -334,6 +369,7 @@ class MilvusConnection: for i, chunk in enumerate(chunks): chunk.metadata.update({ 'userid': userid, + 'knowledge_base_id': knowledge_base_id, 'document_id': document_id, 'filename': filename + '.' + ext, 'file_path': file_path, @@ -341,48 +377,57 @@ class MilvusConnection: 'file_type': ext, }) documents.append(chunk) - logger.debug(f"文档块 {i} 元数据: {chunk.metadata}") + debug(f"文档块 {i} 元数据: {chunk.metadata}") - logger.debug(f"确保集合 {collection_name} 存在") - create_result = self._create_collection(db_type) + debug(f"确保集合 {collection_name} 存在") + create_result = await self._create_collection(db_type) if create_result["status"] == "error": raise RuntimeError(f"集合创建失败: {create_result['message']}") - logger.debug("调用嵌入服务生成向量") + debug("调用嵌入服务生成向量") texts = [doc.page_content for doc in documents] embeddings = await self._get_embeddings(texts) await self._insert_to_milvus(collection_name, documents, embeddings) - logger.info(f"成功插入 {len(documents)} 个文档块到 {collection_name}") + info(f"成功插入 {len(documents)} 个文档块到 {collection_name}") - logger.debug("调用三元组抽取服务") + debug("调用三元组抽取服务") try: triples = await self._extract_triples(text) if triples: - logger.debug(f"抽取到 {len(triples)} 个三元组,插入 Neo4j") - kg = KnowledgeGraph(triples=triples, document_id=document_id) + debug(f"抽取到 {len(triples)} 个三元组,插入 Neo4j") + kg = KnowledgeGraph(triples=triples, document_id=document_id, knowledge_base_id=knowledge_base_id, userid=userid) kg.create_graphnodes() kg.create_graphrels() kg.export_data() - logger.info(f"文件 {file_path} 三元组成功插入 Neo4j") + info(f"文件 {file_path} 三元组成功插入 Neo4j") else: - logger.warning(f"文件 {file_path} 未抽取到三元组") + debug(f"文件 {file_path} 未抽取到三元组") except Exception as e: - logger.warning(f"处理三元组失败: {str(e)}, 但不影响 Milvus 插入") + debug(f"处理三元组失败: {str(e)}") + return { + "status": "success", + "document_id": document_id, + "collection_name": collection_name, + "message": f"文件 {file_path} 成功嵌入,但三元组处理失败: {str(e)}", + "status_code": 200 + } return { "status": "success", "document_id": document_id, "collection_name": collection_name, - "message": f"文件 {file_path} 成功嵌入并处理三元组" + "message": f"文件 {file_path} 成功嵌入并处理三元组", + "status_code": 200 } except Exception as e: - logger.error(f"插入文档失败: {str(e)}") + error(f"插入文档失败: {str(e)}") return { "status": "error", "document_id": document_id, "collection_name": collection_name, - "message": str(e) + "message": str(e), + "status_code": 400 } async def _get_embeddings(self, texts: List[str]) -> List[List[float]]: @@ -398,20 +443,20 @@ class MilvusConnection: json={"input": uncached_texts} ) as response: if response.status != 200: - logger.error(f"嵌入服务调用失败,状态码: {response.status}") + error(f"嵌入服务调用失败,状态码: {response.status}") raise RuntimeError(f"嵌入服务调用失败: {response.status}") result = await response.json() if result.get("object") != "list" or not result.get("data"): - logger.error(f"嵌入服务响应格式错误: {result}") + error(f"嵌入服务响应格式错误: {result}") raise RuntimeError("嵌入服务响应格式错误") embeddings = [item["embedding"] for item in result["data"]] for text, embedding in zip(uncached_texts, embeddings): EMBED_CACHE[text] = np.array(embedding) / np.linalg.norm(embedding) - logger.debug(f"成功获取 {len(embeddings)} 个新嵌入向量,缓存大小: {len(EMBED_CACHE)}") + debug(f"成功获取 {len(embeddings)} 个新嵌入向量,缓存大小: {len(EMBED_CACHE)}") # 返回缓存中的嵌入 return [EMBED_CACHE[text] for text in texts] except Exception as e: - logger.error(f"嵌入服务调用失败: {str(e)}") + error(f"嵌入服务调用失败: {str(e)}\n{exception()}") raise RuntimeError(f"嵌入服务调用失败: {str(e)}") async def _extract_triples(self, text: str) -> List[Dict]: @@ -424,20 +469,21 @@ class MilvusConnection: json={"text": text} ) as response: if response.status != 200: - logger.error(f"三元组抽取服务调用失败,状态码: {response.status}") + error(f"三元组抽取服务调用失败,状态码: {response.status}") raise RuntimeError(f"三元组抽取服务调用失败: {response.status}") result = await response.json() if result.get("object") != "list" or not result.get("data"): - logger.error(f"三元组抽取服务响应格式错误: {result}") + error(f"三元组抽取服务响应格式错误: {result}") raise RuntimeError("三元组抽取服务响应格式错误") triples = result["data"] - logger.debug(f"成功抽取 {len(triples)} 个三元组") + debug(f"成功抽取 {len(triples)} 个三元组") return triples except Exception as e: - logger.error(f"三元组抽取服务调用失败: {str(e)}") + error(f"三元组抽取服务调用失败: {str(e)}\n{exception()}") raise RuntimeError(f"三元组抽取服务调用失败: {str(e)}") - async def _insert_to_milvus(self, collection_name: str, documents: List[Document], embeddings: List[List[float]]) -> None: + async def _insert_to_milvus(self, collection_name: str, documents: List[Document], + embeddings: List[List[float]]) -> None: """将文档和嵌入向量插入 Milvus 集合""" try: if not connections.has_connection("default"): @@ -446,6 +492,7 @@ class MilvusConnection: collection.load() data = { "userid": [doc.metadata["userid"] for doc in documents], + "knowledge_base_id": [doc.metadata["knowledge_base_id"] for doc in documents], "document_id": [doc.metadata["document_id"] for doc in documents], "text": [doc.page_content for doc in documents], "vector": embeddings, @@ -456,46 +503,42 @@ class MilvusConnection: } collection.insert([data[field.name] for field in collection.schema.fields if field.name != "pk"]) collection.flush() - logger.debug(f"成功插入 {len(documents)} 个文档到集合 {collection_name}") + debug(f"成功插入 {len(documents)} 个文档到集合 {collection_name}") except Exception as e: - logger.error(f"插入 Milvus 失败: {str(e)}") + error(f"插入 Milvus 失败: {str(e)}") raise RuntimeError(f"插入 Milvus 失败: {str(e)}") - def _delete_document(self, db_type: str, userid: str, filename: str) -> Dict: - """删除用户指定文件数据""" - collection_name = f"ragdb_{db_type}" + async def _delete_document(self, db_type: str, userid: str, filename: str, knowledge_base_id: str) -> Dict[ + str, Any]: + """删除用户指定文件数据,包括 Milvus 和 Neo4j 中的记录""" + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - if not db_type or "_" in db_type: - raise ValueError("db_type 不能为空且不能包含下划线") - if not userid or "_" in userid: - raise ValueError("userid 不能为空且不能包含下划线") - if not filename: - raise ValueError("filename 不能为空") - if len(db_type) > 100 or len(userid) > 100 or len(filename) > 255: - raise ValueError("db_type、userid 或 filename 的长度超出限制") - if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") + debug(f"集合 {collection_name} 不存在") return { - "status": "error", + "status": "success", "collection_name": collection_name, - "message": f"集合 {collection_name} 不存在" + "document_id": "", + "message": f"集合 {collection_name} 不存在,无需删除", + "status_code": 200 } try: collection = Collection(collection_name) collection.load() - logger.debug(f"加载集合: {collection_name}") + debug(f"加载集合: {collection_name}") except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") + error(f"加载集合 {collection_name} 失败: {str(e)}") return { "status": "error", "collection_name": collection_name, - "message": f"加载集合失败: {str(e)}" + "document_id": "", + "message": f"加载集合失败: {str(e)}", + "status_code": 400 } - expr = f"userid == '{userid}' and filename == '{filename}'" - logger.debug(f"查询表达式: {expr}") + expr = f"userid == '{userid}' and filename == '{filename}' and knowledge_base_id == '{knowledge_base_id}'" + debug(f"查询表达式: {expr}") try: results = collection.query( expr=expr, @@ -503,56 +546,232 @@ class MilvusConnection: limit=1000 ) if not results: - logger.warning(f"没有找到 userid={userid}, filename={filename} 的记录") + debug( + f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录") return { - "status": "error", + "status": "success", "collection_name": collection_name, - "message": f"没有找到 userid={userid}, filename={filename} 的记录" + "document_id": "", + "message": f"没有找到 userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id} 的记录,无需删除", + "status_code": 200 } document_ids = list(set(result["document_id"] for result in results if "document_id" in result)) - logger.debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}") + debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}") except Exception as e: - logger.error(f"查询 document_id 失败: {str(e)}") + error(f"查询 document_id 失败: {str(e)}") return { "status": "error", "collection_name": collection_name, - "message": f"查询失败: {str(e)}" + "document_id": "", + "message": f"查询失败: {str(e)}", + "status_code": 400 } total_deleted = 0 + neo4j_deleted_nodes = 0 + neo4j_deleted_rels = 0 for doc_id in document_ids: try: - delete_expr = f"userid == '{userid}' and document_id == '{doc_id}'" - logger.debug(f"删除表达式: {delete_expr}") + # 删除 Milvus 记录 + delete_expr = f"document_id == '{doc_id}'" + debug(f"删除表达式: {delete_expr}") delete_result = collection.delete(delete_expr) deleted_count = delete_result.delete_count total_deleted += deleted_count - logger.info(f"成功删除 document_id={doc_id} 的 {deleted_count} 条记录") + info(f"成功删除 document_id={doc_id} 的 {deleted_count} 条 Milvus 记录") + + # 删除 Neo4j 三元组 + try: + graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password)) + query = """ + MATCH (n {document_id: $document_id}) + OPTIONAL MATCH (n)-[r {document_id: $document_id}]->() + WITH collect(r) AS rels, collect(n) AS nodes + FOREACH (r IN rels | DELETE r) + FOREACH (n IN nodes | DELETE n) + RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types + """ + result = graph.run(query, document_id=doc_id).data() + nodes_deleted = result[0]['node_count'] if result else 0 + rels_deleted = result[0]['rel_count'] if result else 0 + rel_types = result[0]['rel_types'] if result else [] + info( + f"成功删除 document_id={doc_id} 的 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}") + neo4j_deleted_nodes += nodes_deleted + neo4j_deleted_rels += rels_deleted + except Exception as e: + error(f"删除 document_id={doc_id} 的 Neo4j 三元组失败: {str(e)}") + continue except Exception as e: - logger.error(f"删除 document_id={doc_id} 失败: {str(e)}") + error(f"删除 document_id={doc_id} 的 Milvus 记录失败: {str(e)}") continue if total_deleted == 0: - logger.warning(f"没有删除任何记录,userid={userid}, filename={filename}") + debug( + f"没有删除任何 Milvus 记录,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}") return { - "status": "error", + "status": "success", "collection_name": collection_name, - "message": f"没有删除任何记录,userid={userid}, filename={filename}" + "document_id": "", + "message": f"没有删除任何记录,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}", + "status_code": 200 } - logger.info(f"总计删除 {total_deleted} 条记录,userid={userid}, filename={filename}") + info( + f"总计删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}") return { "status": "success", "collection_name": collection_name, - "message": f"成功删除 {total_deleted} 条记录,userid={userid}, filename={filename}" + "document_id": ",".join(document_ids), + "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,userid={userid}, filename={filename}, knowledge_base_id={knowledge_base_id}", + "status_code": 200 } except Exception as e: - logger.error(f"删除文档失败: {str(e)}") + error(f"删除文档失败: {str(e)}") return { "status": "error", "collection_name": collection_name, - "message": f"删除文档失败: {str(e)}" + "document_id": "", + "message": f"删除文档失败: {str(e)}", + "status_code": 400 + } + + async def _delete_knowledge_base(self, db_type: str, userid: str, knowledge_base_id: str) -> Dict[str, Any]: + """删除用户的整个知识库,包括 Milvus 和 Neo4j 中的记录""" + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" + try: + if not utility.has_collection(collection_name): + debug(f"集合 {collection_name} 不存在") + return { + "status": "success", + "collection_name": collection_name, + "document_id": "", + "message": f"集合 {collection_name} 不存在,无需删除", + "status_code": 200 + } + + try: + collection = Collection(collection_name) + await asyncio.wait_for(collection.load_async(), timeout=10.0) + debug(f"加载集合: {collection_name}") + except Exception as e: + error(f"加载集合 {collection_name} 失败: {str(e)}") + return { + "status": "error", + "collection_name": collection_name, + "document_id": "", + "message": f"加载集合失败: {str(e)}", + "status_code": 400 + } + + # 查询 Milvus 中的 document_id 列表 + expr = f"userid == '{userid}' and knowledge_base_id == '{knowledge_base_id}'" + debug(f"查询表达式: {expr}") + try: + results = collection.query( + expr=expr, + output_fields=["document_id"], + limit=1000 + ) + if not results: + debug(f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录") + # 即使 Milvus 没有记录,仍尝试删除 Neo4j 数据 + else: + document_ids = list(set(result["document_id"] for result in results if "document_id" in result)) + debug(f"找到 {len(document_ids)} 个 document_id: {document_ids}") + except Exception as e: + error(f"查询 document_id 失败: {str(e)}") + return { + "status": "error", + "collection_name": collection_name, + "document_id": "", + "message": f"查询失败: {str(e)}", + "status_code": 400 + } + + # 删除 Milvus 记录 + total_deleted = 0 + document_ids = [] + if results: + try: + delete_expr = f"userid == '{userid}' and knowledge_base_id == '{knowledge_base_id}'" + debug(f"删除表达式: {delete_expr}") + delete_result = collection.delete(delete_expr) + total_deleted = delete_result.delete_count + document_ids = [result["document_id"] for result in results if "document_id" in result] + info(f"成功删除 {total_deleted} 条 Milvus 记录") + except Exception as e: + error(f"删除 Milvus 记录失败: {str(e)}") + return { + "status": "error", + "collection_name": collection_name, + "document_id": "", + "message": f"删除 Milvus 记录失败: {str(e)}", + "status_code": 400 + } + + # 删除 Neo4j 数据 + neo4j_deleted_nodes = 0 + neo4j_deleted_rels = 0 + try: + debug(f"尝试连接 Neo4j: uri={self.neo4j_uri}, user={self.neo4j_user}") + graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password)) + debug("Neo4j 连接成功") + query = """ + MATCH (n {userid: $userid, knowledge_base_id: $knowledge_base_id}) + OPTIONAL MATCH (n)-[r {userid: $userid, knowledge_base_id: $knowledge_base_id}]->() + WITH collect(r) AS rels, collect(n) AS nodes + FOREACH (r IN rels | DELETE r) + FOREACH (n IN nodes | DELETE n) + RETURN size(nodes) AS node_count, size(rels) AS rel_count, [r IN rels | type(r)] AS rel_types + """ + result = graph.run(query, userid=userid, knowledge_base_id=knowledge_base_id).data() + nodes_deleted = result[0]['node_count'] if result else 0 + rels_deleted = result[0]['rel_count'] if result else 0 + rel_types = result[0]['rel_types'] if result else [] + neo4j_deleted_nodes += nodes_deleted + neo4j_deleted_rels += rels_deleted + info(f"成功删除 {nodes_deleted} 个 Neo4j 节点和 {rels_deleted} 个关系,关系类型: {rel_types}") + except Exception as e: + error(f"删除 Neo4j 数据失败: {str(e)}") + # 继续返回结果,即使 Neo4j 删除失败 + return { + "status": "success", + "collection_name": collection_name, + "document_id": ",".join(document_ids) if document_ids else "", + "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,但 Neo4j 删除失败: {str(e)}", + "status_code": 200 + } + + if total_deleted == 0 and neo4j_deleted_nodes == 0 and neo4j_deleted_rels == 0: + debug(f"没有删除任何记录,userid={userid}, knowledge_base_id={knowledge_base_id}") + return { + "status": "success", + "collection_name": collection_name, + "document_id": "", + "message": f"没有找到 userid={userid}, knowledge_base_id={knowledge_base_id} 的记录,无需删除", + "status_code": 200 + } + + info( + f"总计删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,userid={userid}, knowledge_base_id={knowledge_base_id}") + return { + "status": "success", + "collection_name": collection_name, + "document_id": ",".join(document_ids) if document_ids else "", + "message": f"成功删除 {total_deleted} 条 Milvus 记录,{neo4j_deleted_nodes} 个 Neo4j 节点,{neo4j_deleted_rels} 个 Neo4j 关系,userid={userid}, knowledge_base_id={knowledge_base_id}", + "status_code": 200 + } + + except Exception as e: + error(f"删除知识库失败: {str(e)}") + return { + "status": "error", + "collection_name": collection_name, + "document_id": "", + "message": f"删除知识库失败: {str(e)}", + "status_code": 400 } async def _extract_entities(self, query: str) -> List[str]: @@ -567,18 +786,18 @@ class MilvusConnection: json={"query": query} ) as response: if response.status != 200: - logger.error(f"实体识别服务调用失败,状态码: {response.status}") + error(f"实体识别服务调用失败,状态码: {response.status}") raise RuntimeError(f"实体识别服务调用失败: {response.status}") result = await response.json() if result.get("object") != "list" or not result.get("data"): - logger.error(f"实体识别服务响应格式错误: {result}") + error(f"实体识别服务响应格式错误: {result}") raise RuntimeError("实体识别服务响应格式错误") entities = result["data"] unique_entities = list(dict.fromkeys(entities)) # 去重 - logger.debug(f"成功提取 {len(unique_entities)} 个唯一实体: {unique_entities}") + debug(f"成功提取 {len(unique_entities)} 个唯一实体: {unique_entities}") return unique_entities except Exception as e: - logger.error(f"实体识别服务调用失败: {str(e)}") + error(f"实体识别服务调用失败: {str(e)}\n{exception()}") return [] async def _match_triplets(self, query: str, query_entities: List[str], userid: str, document_id: str) -> List[Dict]: @@ -588,7 +807,7 @@ class MilvusConnection: try: graph = Graph(self.neo4j_uri, auth=(self.neo4j_user, self.neo4j_password)) - logger.debug(f"已连接到 Neo4j: {self.neo4j_uri}") + debug(f"已连接到 Neo4j: {self.neo4j_uri}") matched_names = set() for entity in query_entities: @@ -605,9 +824,9 @@ class MilvusConnection: results = graph.run(query, document_id=document_id, entity=normalized_entity).data() for record in results: matched_names.add(record['n.name']) - logger.debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})") + debug(f"实体 {entity} 匹配节点: {record['n.name']} (Levenshtein 相似度: {record['sim']:.2f})") except Exception as e: - logger.warning(f"模糊匹配实体 {entity} 失败: {str(e)}") + debug(f"模糊匹配实体 {entity} 失败: {str(e)}\n{exception()}") continue triplets = [] @@ -633,13 +852,13 @@ class MilvusConnection: 'head_type': '', 'tail_type': '' }) - logger.debug(f"从 Neo4j 加载三元组: document_id={document_id}, 数量={len(triplets)}") + debug(f"从 Neo4j 加载三元组: document_id={document_id}, 数量={len(triplets)}") except Exception as e: - logger.error(f"检索三元组失败: document_id={document_id}, 错误: {str(e)}") + error(f"检索三元组失败: document_id={document_id}, 错误: {str(e)}\n{exception()}") return [] if not triplets: - logger.debug(f"文档 document_id={document_id} 无匹配三元组") + debug(f"文档 document_id={document_id} 无匹配三元组") return [] texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets] @@ -647,7 +866,7 @@ class MilvusConnection: entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)} head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)} tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)} - logger.debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)") + debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)") for entity in query_entities: entity_vec = entity_vectors[entity] @@ -659,8 +878,8 @@ class MilvusConnection: if head_similarity >= ENTITY_SIMILARITY_THRESHOLD or tail_similarity >= ENTITY_SIMILARITY_THRESHOLD: matched_triplets.append(d_triplet) - logger.debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} " - f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})") + debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} " + f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})") unique_matched = [] seen = set() @@ -670,71 +889,26 @@ class MilvusConnection: seen.add(identifier) unique_matched.append(t) - logger.info(f"找到 {len(unique_matched)} 个匹配的三元组") + info(f"找到 {len(unique_matched)} 个匹配的三元组") return unique_matched except Exception as e: - logger.error(f"匹配三元组失败: {str(e)}") - return [] - - texts_to_embed = query_entities + [t['head'] for t in triplets] + [t['tail'] for t in triplets] - embeddings = await self._get_embeddings(texts_to_embed) - entity_vectors = {entity: embeddings[i] for i, entity in enumerate(query_entities)} - head_vectors = {t['head']: embeddings[len(query_entities) + i] for i, t in enumerate(triplets)} - tail_vectors = {t['tail']: embeddings[len(query_entities) + len(triplets) + i] for i, t in enumerate(triplets)} - logger.debug(f"成功获取 {len(embeddings)} 个嵌入向量({len(query_entities)} entities + {len(triplets)} heads + {len(triplets)} tails)") - - for entity in query_entities: - entity_vec = entity_vectors[entity] - for d_triplet in triplets: - d_head_vec = head_vectors[d_triplet['head']] - d_tail_vec = tail_vectors[d_triplet['tail']] - head_similarity = 1 - cosine(entity_vec, d_head_vec) - tail_similarity = 1 - cosine(entity_vec, d_tail_vec) - - if head_similarity >= ENTITY_SIMILARITY_THRESHOLD or tail_similarity >= ENTITY_SIMILARITY_THRESHOLD: - matched_triplets.append(d_triplet) - logger.debug(f"匹配三元组: {d_triplet['head']} - {d_triplet['type']} - {d_triplet['tail']} " - f"(entity={entity}, head_sim={head_similarity:.2f}, tail_sim={tail_similarity:.2f})") - - unique_matched = [] - seen = set() - for t in matched_triplets: - identifier = (t['head'].lower(), t['type'].lower(), t['tail'].lower()) - if identifier not in seen: - seen.add(identifier) - unique_matched.append(t) - - logger.info(f"找到 {len(unique_matched)} 个匹配的三元组") - return unique_matched - - except Exception as e: - logger.error(f"匹配三元组失败: {str(e)}") + error(f"匹配三元组失败: {str(e)}\n{exception()}") return [] async def _rerank_results(self, query: str, results: List[Dict], top_n: int) -> List[Dict]: - """调用重排序服务 - - Args: - query (str): 查询字符串 - results (List[Dict]): 原始搜索结果列表 - top_n (int): 返回的重排序结果数量 - - Returns: - List[Dict]: 重排序后的结果列表 - """ + """调用重排序服务""" try: if not results: - logger.debug("无结果需要重排序") + debug("无结果需要重排序") return results - # 验证 top_n if not isinstance(top_n, int) or top_n < 1: - logger.warning(f"无效的 top_n 参数: {top_n},使用 len(results)={len(results)}") + debug(f"无效的 top_n 参数: {top_n},使用 len(results)={len(results)}") top_n = len(results) else: - top_n = min(top_n, len(results)) # 确保不超过结果数量 - logger.debug(f"重排序 top_n={top_n}, 原始结果数={len(results)}") + top_n = min(top_n, len(results)) + debug(f"重排序 top_n={top_n}, 原始结果数={len(results)}") documents = [result["text"] for result in results] async with aiohttp.ClientSession() as session: @@ -749,11 +923,11 @@ class MilvusConnection: } ) as response: if response.status != 200: - logger.error(f"重排序服务调用失败,状态码: {response.status}") + error(f"重排序服务调用失败,状态码: {response.status}") raise RuntimeError(f"重排序服务调用失败: {response.status}") result = await response.json() if result.get("object") != "rerank.result" or not result.get("data"): - logger.error(f"重排序服务响应格式错误: {result}") + error(f"重排序服务响应格式错误: {result}") raise RuntimeError("重排序服务响应格式错误") rerank_data = result["data"] reranked_results = [] @@ -762,79 +936,71 @@ class MilvusConnection: if index < len(results): results[index]["rerank_score"] = item["relevance_score"] reranked_results.append(results[index]) - logger.debug(f"成功重排序 {len(reranked_results)} 条结果") - return reranked_results[:top_n] # 确保返回不超过 top_n 条 + debug(f"成功重排序 {len(reranked_results)} 条结果") + return reranked_results[:top_n] except Exception as e: - logger.error(f"重排序服务调用失败: {str(e)}") - return results # 出错时返回原始结果 + error(f"重排序服务调用失败: {str(e)}\n{exception()}") + return results async def _fused_search(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0, use_rerank: bool = True) -> List[Dict]: """融合搜索,将查询与所有三元组拼接后向量化搜索""" + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - logger.info( + info( f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}") - collection_name = f"ragdb_{db_type}" - # 参数验证 - if not query or not userid or not db_type or not file_paths: - raise ValueError("query、userid、db_type 和 file_paths 不能为空") - if "_" in userid or "_" in db_type: + if not query or not userid or not file_paths: + raise ValueError("query、userid 和 file_paths 不能为空") + if "_" in userid or (db_type and "_" in db_type): raise ValueError("userid 和 db_type 不能包含下划线") - if len(db_type) > 100 or len(userid) > 100: + if (db_type and len(db_type) > 100) or len(userid) > 100: raise ValueError("db_type 或 userid 的长度超出限制") if limit < 1 or offset < 0: raise ValueError("limit 必须大于 0,offset 必须大于或等于 0") - # 检查集合是否存在 if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") + debug(f"集合 {collection_name} 不存在") return [] - # 加载集合 try: collection = Collection(collection_name) collection.load() - logger.debug(f"加载集合: {collection_name}") + debug(f"加载集合: {collection_name}") except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") + error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") return [] - # 提取实体 query_entities = await self._extract_entities(query) - logger.debug(f"提取实体: {query_entities}") + debug(f"提取实体: {query_entities}") - # 收集 document_id 和三元组 documents = [] all_triplets = [] for file_path in file_paths: filename = os.path.basename(file_path) - logger.debug(f"处理文件: {filename}") + debug(f"处理文件: {filename}") - # 获取 document_id results = collection.query( expr=f"userid == '{userid}' and filename == '{filename}'", output_fields=["document_id", "filename"], limit=1 ) if not results: - logger.warning(f"未找到 userid {userid} 和 filename {filename} 对应的文档") + debug(f"未找到 userid {userid} 和 filename {filename} 对应的文档") continue documents.append(results[0]) - # 获取三元组 document_id = results[0]["document_id"] matched_triplets = await self._match_triplets(query, query_entities, userid, document_id) - logger.debug(f"文件 {filename} 匹配三元组: {len(matched_triplets)} 条") + debug(f"文件 {filename} 匹配三元组: {len(matched_triplets)} 条") all_triplets.extend(matched_triplets) if not documents: - logger.warning("未找到任何有效文档") + debug("未找到任何有效文档") return [] - logger.info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}") + info(f"找到 {len(documents)} 个文档: {[doc['filename'] for doc in documents]}") - # 拼接查询和三元组 triplet_texts = [] for triplet in all_triplets: head = triplet.get('head', '') @@ -843,28 +1009,24 @@ class MilvusConnection: if head and type_ and tail: triplet_texts.append(f"{head} {type_} {tail}") else: - logger.debug(f"无效三元组: {triplet}") + debug(f"无效三元组: {triplet}") combined_text = query if triplet_texts: combined_text += " [三元组] " + "; ".join(triplet_texts) - logger.debug( + debug( f"拼接文本: {combined_text[:200]}... (总长度: {len(combined_text)}, 三元组数量: {len(triplet_texts)})") - # 生成拼接文本的嵌入向量 embeddings = await self._get_embeddings([combined_text]) query_vector = embeddings[0] - logger.debug(f"拼接文本向量维度: {len(query_vector)}") + debug(f"拼接文本向量维度: {len(query_vector)}") - # 构造搜索参数 search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - # 构造过滤表达式 filenames = [os.path.basename(file_path) for file_path in file_paths] filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames]) expr = f"userid == '{userid}' and ({filename_expr})" - logger.debug(f"搜索表达式: {expr}") + debug(f"搜索表达式: {expr}") - # 执行向量搜索 try: results = collection.search( data=[query_vector], @@ -877,10 +1039,9 @@ class MilvusConnection: offset=offset ) except Exception as e: - logger.error(f"向量搜索失败: {str(e)}") + error(f"向量搜索失败: {str(e)}\n{exception()}") return [] - # 处理搜索结果 search_results = [] for hits in results: for hit in hits: @@ -899,50 +1060,47 @@ class MilvusConnection: "metadata": metadata } search_results.append(result) - logger.debug( + debug( f"搜索命中: text={result['text'][:100]}..., distance={hit.distance}, source={result['source']}") - # 去重 unique_results = [] seen_texts = set() for result in sorted(search_results, key=lambda x: x['distance'], reverse=True): if result['text'] not in seen_texts: unique_results.append(result) seen_texts.add(result['text']) - logger.info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") + info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") - # 重排序(可选) if use_rerank and unique_results: - logger.debug("开始重排序") + debug("开始重排序") unique_results = await self._rerank_results(combined_text, unique_results, 5) unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) - logger.debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") + debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") else: unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results] - logger.info(f"融合搜索完成,返回 {len(unique_results)} 条结果") + info(f"融合搜索完成,返回 {len(unique_results)} 条结果") return unique_results[:limit] except Exception as e: - logger.error(f"融合搜索失败: {str(e)}") + error(f"融合搜索失败: {str(e)}\n{exception()}") return [] async def _search_query(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, offset: int = 0, use_rerank: bool = True) -> List[Dict]: """纯向量搜索,基于查询文本在指定文档中搜索相关文本块""" + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - logger.info( + info( f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}") - collection_name = f"ragdb_{db_type}" - # 参数验证 if not query: raise ValueError("查询文本不能为空") - if not userid or not db_type: - raise ValueError("userid 和 db_type 不能为空") - if "_" in userid or "_" in db_type: + if not userid: + raise ValueError("userid 不能为空") + if "_" in userid or (db_type and "_" in db_type): raise ValueError("userid 和 db_type 不能包含下划线") - if len(userid) > 100 or len(db_type) > 100: + if (db_type and len(db_type) > 100) or len(userid) > 100: raise ValueError("userid 或 db_type 的长度超出限制") if limit <= 0 or limit > 16384: raise ValueError("limit 必须在 1 到 16384 之间") @@ -960,35 +1118,29 @@ class MilvusConnection: if "_" in os.path.basename(file_path): raise ValueError(f"文件名 {file_path} 不能包含下划线") - # 检查集合是否存在 if not utility.has_collection(collection_name): - logger.warning(f"集合 {collection_name} 不存在") + debug(f"集合 {collection_name} 不存在") return [] - # 加载集合 try: collection = Collection(collection_name) collection.load() - logger.debug(f"加载集合: {collection_name}") + debug(f"加载集合: {collection_name}") except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") + error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") raise RuntimeError(f"加载集合失败: {str(e)}") - # 生成查询向量 embeddings = await self._get_embeddings([query]) query_vector = embeddings[0] - logger.debug(f"查询向量维度: {len(query_vector)}") + debug(f"查询向量维度: {len(query_vector)}") - # 构造搜索参数 search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - # 构造过滤表达式 filenames = [os.path.basename(file_path) for file_path in file_paths] filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames]) expr = f"userid == '{userid}' and ({filename_expr})" - logger.debug(f"搜索表达式: {expr}") + debug(f"搜索表达式: {expr}") - # 执行搜索 try: results = collection.search( data=[query_vector], @@ -1001,10 +1153,9 @@ class MilvusConnection: offset=offset ) except Exception as e: - logger.error(f"搜索失败: {str(e)}") + error(f"搜索失败: {str(e)}\n{exception()}") raise RuntimeError(f"搜索失败: {str(e)}") - # 处理搜索结果 search_results = [] for hits in results: for hit in hits: @@ -1023,48 +1174,37 @@ class MilvusConnection: "metadata": metadata } search_results.append(result) - logger.debug( + debug( f"命中: text={result['text'][:100]}..., distance={hit.distance}, filename={metadata['filename']}") - # 去重 unique_results = [] seen_texts = set() for result in sorted(search_results, key=lambda x: x['distance'], reverse=True): if result['text'] not in seen_texts: unique_results.append(result) seen_texts.add(result['text']) - logger.info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") + info(f"去重后结果数量: {len(unique_results)} (原始数量: {len(search_results)})") - # 重排序(可选) if use_rerank and unique_results: - logger.debug("开始重排序") + debug("开始重排序") unique_results = await self._rerank_results(query, unique_results, 5) unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) - logger.debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") + debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") else: - # 未启用重排序,确保不包含 rerank_score unique_results = [{k: v for k, v in r.items() if k != 'rerank_score'} for r in unique_results] - logger.info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果") + info(f"纯向量搜索完成,返回 {len(unique_results)} 条结果") return unique_results[:limit] except Exception as e: - logger.error(f"纯向量搜索失败: {str(e)}") + error(f"纯向量搜索失败: {str(e)}\n{exception()}") return [] async def list_user_files(self, userid: str) -> List[Dict]: - """根据 userid 返回用户的所有文件列表,从所有 ragdb_ 开头的集合中查询 - - Args: - userid (str): 用户 ID - - Returns: - List[Dict]: 文件列表,每个文件包含 filename, file_path, db_type, upload_time, file_type - """ + """根据 userid 返回用户的所有文件列表,从所有 ragdb_ 开头的集合中查询""" try: - logger.info(f"开始查询用户文件列表: userid={userid}") + info(f"开始查询用户文件列表: userid={userid}") - # 参数验证 if not userid: raise ValueError("userid 不能为空") if "_" in userid: @@ -1072,43 +1212,38 @@ class MilvusConnection: if len(userid) > 100: raise ValueError("userid 长度超出限制") - # 获取所有 ragdb_ 开头的集合 - collections = [c for c in utility.list_collections() if c.startswith("ragdb_")] + collections = utility.list_collections() + collections = [c for c in collections if c.startswith("ragdb")] if not collections: - logger.warning("未找到任何 ragdb_ 开头的集合") + debug("未找到任何 ragdb 开头的集合") return [] - logger.debug(f"找到集合: {collections}") + debug(f"找到集合: {collections}") - # 收集文件信息 file_list = [] - seen_files = set() # 用于去重 (filename, file_path) + seen_files = set() for collection_name in collections: - # 提取 db_type - db_type = collection_name.replace("ragdb_", "") - logger.debug(f"处理集合: {collection_name}, db_type={db_type}") + db_type = collection_name.replace("ragdb_", "") if collection_name != "ragdb" else "" + debug(f"处理集合: {collection_name}, db_type={db_type}") - # 加载集合 try: collection = Collection(collection_name) collection.load() - logger.debug(f"加载集合: {collection_name}") + debug(f"加载集合: {collection_name}") except Exception as e: - logger.error(f"加载集合 {collection_name} 失败: {str(e)}") + error(f"加载集合 {collection_name} 失败: {str(e)}\n{exception()}") continue - # 查询文本块 try: results = collection.query( expr=f"userid == '{userid}'", output_fields=["filename", "file_path", "upload_time", "file_type"], - limit=1000 # 假设最大返回 1000 个文本块 + limit=1000 ) - logger.debug(f"集合 {collection_name} 查询到 {len(results)} 个文本块") + debug(f"集合 {collection_name} 查询到 {len(results)} 个文本块") except Exception as e: - logger.error(f"查询集合 {collection_name} 失败: userid={userid}, 错误: {str(e)}") + error(f"查询集合 {collection_name} 失败: userid={userid}, 错误: {str(e)}\n{exception()}") continue - # 处理查询结果 for result in results: filename = result.get("filename") file_path = result.get("file_path") @@ -1123,14 +1258,15 @@ class MilvusConnection: "upload_time": upload_time, "file_type": file_type }) - logger.debug( + debug( f"文件: filename={filename}, file_path={file_path}, db_type={db_type}, upload_time={upload_time}, file_type={file_type}") - logger.info(f"返回 {len(file_list)} 个文件") - return sorted(file_list, key=lambda x: x["upload_time"], reverse=True) # 按上传时间降序排序 + info(f"返回 {len(file_list)} 个文件") + return sorted(file_list, key=lambda x: x["upload_time"], reverse=True) except Exception as e: - logger.error(f"查询用户文件列表失败: userid={userid}, 错误: {str(e)}") + error(f"查询用户文件列表失败: userid={userid}, 错误: {str(e)}\n{exception()}") return [] + connection_register('Milvus', MilvusConnection) -logger.info("MilvusConnection registered") \ No newline at end of file +info("MilvusConnection registered") \ No newline at end of file