From 313645afee98517089036024933613586b307fd8 Mon Sep 17 00:00:00 2001 From: wangmeihua <13383952685@163.com> Date: Fri, 4 Jul 2025 18:08:08 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=88=A0=E9=99=A4=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E6=9F=90=E4=B8=AA=E7=9F=A5=E8=AF=86=E5=BA=93=E3=80=81?= =?UTF-8?q?=E6=A0=B9=E6=8D=AE=E7=9F=A5=E8=AF=86=E5=BA=93=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E6=96=87=E6=9C=AC=E5=9D=97=E5=8F=AC=E5=9B=9E=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llmengine/base_connection.py | 10 ++- llmengine/connection.py | 87 +++++++++++------------ llmengine/milvus_connection.py | 122 ++++++++++++++++++--------------- 3 files changed, 115 insertions(+), 104 deletions(-) diff --git a/llmengine/base_connection.py b/llmengine/base_connection.py index a79b841..d15b5dd 100644 --- a/llmengine/base_connection.py +++ b/llmengine/base_connection.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod from typing import Dict -import logging - -logger = logging.getLogger(__name__) +from appPublic.log import debug, error, info, exception connection_pathMap = {} @@ -10,15 +8,15 @@ def connection_register(connection_key, Klass): """为给定的连接键注册一个连接类""" global connection_pathMap connection_pathMap[connection_key] = Klass - logger.info(f"Registered {connection_key} with class {Klass}") + info(f"Registered {connection_key} with class {Klass}") def get_connection_class(connection_path): """根据连接路径查找对应的连接类""" global connection_pathMap - logger.debug(f"connection_pathMap: {connection_pathMap}") + debug(f"connection_pathMap: {connection_pathMap}") klass = connection_pathMap.get(connection_path) if klass is None: - logger.error(f"{connection_path} has not mapping to a connection class") + error(f"{connection_path} has not mapping to a connection class") raise Exception(f"{connection_path} has not mapping to a connection class") return klass diff --git a/llmengine/connection.py b/llmengine/connection.py index 2ead782..fbe9eb6 100644 --- a/llmengine/connection.py +++ b/llmengine/connection.py @@ -74,7 +74,7 @@ data: { "query": "苹果公司在北京开设新店", "userid": "user1", "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb - "file_paths": ["/path/to/file.txt"], + "knowledge_base_ids": ["kb123"], "limit": 5, "offset": 0, "use_rerank": true @@ -89,15 +89,15 @@ response: "metadata": { "userid": "user1", "document_id": "", - "filename": "test.txt", + "filename": "file.txt", "file_path": "/path/to/file.txt", - "upload_time": "2025-06-27T15:58:00", + "upload_time": "", "file_type": "txt" } }, ... ] -- Error: HTTP 400, {"status": "error", "message": ""} +- Error: HTTP 400, {"status": "error", "message": "", "collection_name": ""} 6. Search Query Endpoint: path: /v1/searchquery @@ -107,7 +107,7 @@ data: { "query": "知识图谱的知识融合是什么?", "userid": "user1", "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb - "file_paths": ["/path/to/file.txt"], + "knowledge_base_ids": ["kb123"], "limit": 5, "offset": 0, "use_rerank": true @@ -122,9 +122,9 @@ response: "metadata": { "userid": "user1", "document_id": "", - "filename": "test.txt", + "filename": "file.txt", "file_path": "/path/to/file.txt", - "upload_time": "2025-06-27T15:58:00", + "upload_time": "", "file_type": "txt" } }, @@ -142,10 +142,10 @@ data: { response: - Success: HTTP 200, [ { - "filename": "test.txt", + "filename": "file.txt", "file_path": "/path/to/file.txt", "db_type": "textdb", - "upload_time": "2025-06-27T15:58:00", + "upload_time": "", "file_type": "txt" }, ... @@ -165,11 +165,11 @@ response: - Error: HTTP 400, {"status": "error", "message": ""} 9. Docs Endpoint: -path: /v1/docs +path: /docs method: GET response: This help text -10.Delete Knowledge Base Endpoint: +10. Delete Knowledge Base Endpoint: path: /v1/deleteknowledgebase method: POST headers: {"Content-Type": "application/json"} @@ -179,10 +179,10 @@ data: { "db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb } response: -- Success: HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 条 Milvus 记录, 个 Neo4j 节点, 个 Neo4j 关系,userid=, knowledge_base_id=", "status_code": 200} -- Success (no records): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=, knowledge_base_id= 的记录,无需删除", "status_code": 200} -- Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 不存在,无需删除", "status_code": 200} -- Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "", "status_code": 400} +- Success: HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 条 Milvus 记录, 个 Neo4j 节点, 个 Neo4j 关系,userid=, knowledge_base_id=", "status_code": 200} +- Success (no records): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=, knowledge_base_id= 的记录,无需删除", "status_code": 200} +- Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 不存在,无需删除", "status_code": 200} +- Error: HTTP 400, {"status": "error", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "", "status_code": 400} """ def init(): @@ -243,7 +243,6 @@ async def delete_collection(request, params_kw, *params, **kw): "message": str(e) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) - async def insert_file(request, params_kw, *params, **kw): debug(f'Received params: {params_kw=}') se = ServerEnv() @@ -254,7 +253,6 @@ async def insert_file(request, params_kw, *params, **kw): knowledge_base_id = params_kw.get('knowledge_base_id', '') collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: - # 仅检查必填字段是否存在 required_fields = ['file_path', 'userid', 'knowledge_base_id'] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] if missing_fields: @@ -269,7 +267,6 @@ async def insert_file(request, params_kw, *params, **kw): "knowledge_base_id": knowledge_base_id }) debug(f'Insert result: {result=}') - # 根据 result 的 status 设置 HTTP 状态码 status = 200 if result.get("status") == "success" else 400 return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status) except Exception as e: @@ -281,7 +278,6 @@ async def insert_file(request, params_kw, *params, **kw): "message": str(e) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) - async def delete_file(request, params_kw, *params, **kw): debug(f'Received delete_file params: {params_kw=}') se = ServerEnv() @@ -297,8 +293,7 @@ async def delete_file(request, params_kw, *params, **kw): if missing_fields: raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") - debug( - f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}') + debug(f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}') result = await engine.handle_connection("delete_document", { "userid": userid, "filename": filename, @@ -318,7 +313,6 @@ async def delete_file(request, params_kw, *params, **kw): "status_code": 400 }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) - async def delete_knowledge_base(request, params_kw, *params, **kw): debug(f'Received delete_knowledge_base params: {params_kw=}') se = ServerEnv() @@ -349,6 +343,7 @@ async def delete_knowledge_base(request, params_kw, *params, **kw): "status": "error", "collection_name": collection_name, "document_id": "", + "filename": "", "message": str(e), "status_code": 400 }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) @@ -360,22 +355,24 @@ async def fused_search_query(request, params_kw, *params, **kw): query = params_kw.get('query') userid = params_kw.get('userid') db_type = params_kw.get('db_type', '') - file_paths = params_kw.get('file_paths') - limit = params_kw.get('limit', 5) + knowledge_base_ids = params_kw.get('knowledge_base_ids') + limit = params_kw.get('limit') offset = params_kw.get('offset', 0) use_rerank = params_kw.get('use_rerank', True) - if not all([query, userid, file_paths]): - debug(f'query, userid 或 file_paths 未提供') - return web.json_response({ - "status": "error", - "message": "query, userid 或 file_paths 未提供" - }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: + if not all([query, userid, knowledge_base_ids]): + debug(f'query, userid 或 knowledge_base_ids 未提供') + return web.json_response({ + "status": "error", + "message": "query, userid 或 knowledge_base_ids 未提供", + "collection_name": collection_name + }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) result = await engine.handle_connection("fused_search", { "query": query, "userid": userid, "db_type": db_type, - "file_paths": file_paths, + "knowledge_base_ids": knowledge_base_ids, "limit": limit, "offset": offset, "use_rerank": use_rerank @@ -386,7 +383,8 @@ async def fused_search_query(request, params_kw, *params, **kw): error(f'融合搜索失败: {str(e)}') return web.json_response({ "status": "error", - "message": str(e) + "message": str(e), + "collection_name": collection_name }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) async def search_query(request, params_kw, *params, **kw): @@ -396,22 +394,24 @@ async def search_query(request, params_kw, *params, **kw): query = params_kw.get('query') userid = params_kw.get('userid') db_type = params_kw.get('db_type', '') - file_paths = params_kw.get('file_paths') - limit = params_kw.get('limit', 5) - offset = params_kw.get('offset', 0) + knowledge_base_ids = params_kw.get('knowledge_base_ids') + limit = params_kw.get('limit') + offset = params_kw.get('offset') use_rerank = params_kw.get('use_rerank', True) - if not all([query, userid, file_paths]): - debug(f'query, userid 或 file_paths 未提供') - return web.json_response({ - "status": "error", - "message": "query, userid 或 file_paths 未提供" - }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) + collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: + if not all([query, userid, knowledge_base_ids]): + debug(f'query, userid 或 knowledge_base_ids 未提供') + return web.json_response({ + "status": "error", + "message": "query, userid 或 knowledge_base_ids 未提供", + "collection_name": collection_name + }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) result = await engine.handle_connection("search_query", { "query": query, "userid": userid, "db_type": db_type, - "file_paths": file_paths, + "knowledge_base_ids": knowledge_base_ids, "limit": limit, "offset": offset, "use_rerank": use_rerank @@ -422,7 +422,8 @@ async def search_query(request, params_kw, *params, **kw): error(f'纯向量搜索失败: {str(e)}') return web.json_response({ "status": "error", - "message": str(e) + "message": str(e), + "collection_name": collection_name }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) async def list_user_files(request, params_kw, *params, **kw): diff --git a/llmengine/milvus_connection.py b/llmengine/milvus_connection.py index 6ae04cb..5ba3303 100644 --- a/llmengine/milvus_connection.py +++ b/llmengine/milvus_connection.py @@ -133,32 +133,47 @@ class MilvusConnection: elif action == "fused_search": query = params.get("query", "") userid = params.get("userid", "") - file_paths = params.get("file_paths", []) - if not query or not userid or not file_paths: - return {"status": "error", "message": "query、userid 或 file_paths 不能为空", - "collection_name": collection_name, "document_id": "", "status_code": 400} + knowledge_base_ids = params.get("knowledge_base_ids", []) + limit = params.get("limit", 5) + if not query or not userid or not knowledge_base_ids: + return { + "status": "error", + "message": "query、userid 或 knowledge_base_ids 不能为空", + "collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}", + "document_id": "", + "status_code": 400 + } + if limit < 1 or limit > 16384: + return { + "status": "error", + "message": "limit 必须在 1 到 16384 之间", + "collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}", + "document_id": "", + "status_code": 400 + } return await self._fused_search( query, userid, - db_type, - file_paths, - params.get("limit", 5), + params.get("db_type", ""), + knowledge_base_ids, + limit, params.get("offset", 0), params.get("use_rerank", True) ) elif action == "search_query": query = params.get("query", "") userid = params.get("userid", "") - file_paths = params.get("file_paths", []) - if not query or not userid or not file_paths: - return {"status": "error", "message": "query、userid 或 file_paths 不能为空", + limit = params.get("limit", "") + knowledge_base_ids = params.get("knowledge_base_ids", []) + if not query or not userid or not knowledge_base_ids: + return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空", "collection_name": collection_name, "document_id": "", "status_code": 400} return await self._search_query( query, userid, db_type, - file_paths, - params.get("limit", 5), + knowledge_base_ids, + limit, params.get("offset", 0), params.get("use_rerank", True) ) @@ -280,7 +295,7 @@ class MilvusConnection: error(f"创建集合失败: {str(e)}") return { "status": "error", - "collection_name": collection_name, + "collection_name":collection_name, "message": str(e) } @@ -653,7 +668,6 @@ class MilvusConnection: try: collection = Collection(collection_name) - await asyncio.wait_for(collection.load_async(), timeout=10.0) debug(f"加载集合: {collection_name}") except Exception as e: error(f"加载集合 {collection_name} 失败: {str(e)}") @@ -942,22 +956,22 @@ class MilvusConnection: error(f"重排序服务调用失败: {str(e)}\n{exception()}") return results - async def _fused_search(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, + async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5, offset: int = 0, use_rerank: bool = True) -> List[Dict]: """融合搜索,将查询与所有三元组拼接后向量化搜索""" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: info( - f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}") + f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}") - if not query or not userid or not file_paths: - raise ValueError("query、userid 和 file_paths 不能为空") + if not query or not userid or not knowledge_base_ids: + raise ValueError("query、userid 和 knowledge_base_ids 不能为空") if "_" in userid or (db_type and "_" in db_type): raise ValueError("userid 和 db_type 不能包含下划线") if (db_type and len(db_type) > 100) or len(userid) > 100: raise ValueError("db_type 或 userid 的长度超出限制") - if limit < 1 or offset < 0: - raise ValueError("limit 必须大于 0,offset 必须大于或等于 0") + if limit < 1 or limit > 16384 or offset < 0: + raise ValueError("limit 必须在 1 到 16384 之间,offset 必须大于或等于 0") if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") @@ -976,24 +990,24 @@ class MilvusConnection: documents = [] all_triplets = [] - for file_path in file_paths: - filename = os.path.basename(file_path) - debug(f"处理文件: {filename}") + for kb_id in knowledge_base_ids: + debug(f"处理知识库: {kb_id}") results = collection.query( - expr=f"userid == '{userid}' and filename == '{filename}'", - output_fields=["document_id", "filename"], - limit=1 + expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'", + output_fields=["document_id", "filename", "knowledge_base_id"], + limit=100 # 查询足够多的文档以支持后续过滤 ) if not results: - debug(f"未找到 userid {userid} 和 filename {filename} 对应的文档") + debug(f"未找到 userid {userid} 和 knowledge_base_id {kb_id} 对应的文档") continue - documents.append(results[0]) + documents.extend(results) - document_id = results[0]["document_id"] - matched_triplets = await self._match_triplets(query, query_entities, userid, document_id) - debug(f"文件 {filename} 匹配三元组: {len(matched_triplets)} 条") - all_triplets.extend(matched_triplets) + for doc in results: + document_id = doc["document_id"] + matched_triplets = await self._match_triplets(query, query_entities, userid, document_id) + debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)} 条") + all_triplets.extend(matched_triplets) if not documents: debug("未找到任何有效文档") @@ -1022,9 +1036,8 @@ class MilvusConnection: search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - filenames = [os.path.basename(file_path) for file_path in file_paths] - filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames]) - expr = f"userid == '{userid}' and ({filename_expr})" + kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) + expr = f"userid == '{userid}' and ({kb_expr})" debug(f"搜索表达式: {expr}") try: @@ -1032,7 +1045,7 @@ class MilvusConnection: data=[query_vector], anns_field="vector", param=search_params, - limit=limit, + limit=100, expr=expr, output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], @@ -1073,7 +1086,7 @@ class MilvusConnection: if use_rerank and unique_results: debug("开始重排序") - unique_results = await self._rerank_results(combined_text, unique_results, 5) + unique_results = await self._rerank_results(combined_text, unique_results, limit) # 使用传入的 limit unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") else: @@ -1086,13 +1099,13 @@ class MilvusConnection: error(f"融合搜索失败: {str(e)}\n{exception()}") return [] - async def _search_query(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5, - offset: int = 0, use_rerank: bool = True) -> List[Dict]: - """纯向量搜索,基于查询文本在指定文档中搜索相关文本块""" + async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5, + offset: int = 0, use_rerank: bool = True) -> List[Dict]: + """纯向量搜索,基于查询文本在指定知识库中搜索相关文本块""" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" try: info( - f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}") + f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}") if not query: raise ValueError("查询文本不能为空") @@ -1108,15 +1121,15 @@ class MilvusConnection: raise ValueError("offset 不能为负数") if limit + offset > 16384: raise ValueError("limit + offset 不能超过 16384") - if not file_paths: - raise ValueError("file_paths 不能为空") - for file_path in file_paths: - if not isinstance(file_path, str): - raise ValueError(f"file_path 必须是字符串: {file_path}") - if len(os.path.basename(file_path)) > 255: - raise ValueError(f"文件名长度超出 255 个字符: {file_path}") - if "_" in os.path.basename(file_path): - raise ValueError(f"文件名 {file_path} 不能包含下划线") + if not knowledge_base_ids: + raise ValueError("knowledge_base_ids 不能为空") + for kb_id in knowledge_base_ids: + if not isinstance(kb_id, str): + raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}") + if len(kb_id) > 100: + raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}") + if "_" in kb_id: + raise ValueError(f"knowledge_base_id 不能包含下划线: {kb_id}") if not utility.has_collection(collection_name): debug(f"集合 {collection_name} 不存在") @@ -1136,9 +1149,8 @@ class MilvusConnection: search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - filenames = [os.path.basename(file_path) for file_path in file_paths] - filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames]) - expr = f"userid == '{userid}' and ({filename_expr})" + kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) + expr = f"userid == '{userid}' and ({kb_id_expr})" debug(f"搜索表达式: {expr}") try: @@ -1146,7 +1158,7 @@ class MilvusConnection: data=[query_vector], anns_field="vector", param=search_params, - limit=limit, + limit=100, expr=expr, output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", "file_type"], @@ -1187,7 +1199,7 @@ class MilvusConnection: if use_rerank and unique_results: debug("开始重排序") - unique_results = await self._rerank_results(query, unique_results, 5) + unique_results = await self._rerank_results(query, unique_results, limit) unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") else: