Compare commits

..

No commits in common. "1c783461bd63a6bc0541b9501d9db918d10aa699" and "ce12a20c7db2fb4cf7cca1c50dfb150deb6def5a" have entirely different histories.

3 changed files with 104 additions and 115 deletions

View File

@ -1,6 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
from appPublic.log import debug, error, info, exception import logging
logger = logging.getLogger(__name__)
connection_pathMap = {} connection_pathMap = {}
@ -8,15 +10,15 @@ def connection_register(connection_key, Klass):
"""为给定的连接键注册一个连接类""" """为给定的连接键注册一个连接类"""
global connection_pathMap global connection_pathMap
connection_pathMap[connection_key] = Klass connection_pathMap[connection_key] = Klass
info(f"Registered {connection_key} with class {Klass}") logger.info(f"Registered {connection_key} with class {Klass}")
def get_connection_class(connection_path): def get_connection_class(connection_path):
"""根据连接路径查找对应的连接类""" """根据连接路径查找对应的连接类"""
global connection_pathMap global connection_pathMap
debug(f"connection_pathMap: {connection_pathMap}") logger.debug(f"connection_pathMap: {connection_pathMap}")
klass = connection_pathMap.get(connection_path) klass = connection_pathMap.get(connection_path)
if klass is None: if klass is None:
error(f"{connection_path} has not mapping to a connection class") logger.error(f"{connection_path} has not mapping to a connection class")
raise Exception(f"{connection_path} has not mapping to a connection class") raise Exception(f"{connection_path} has not mapping to a connection class")
return klass return klass

View File

@ -74,7 +74,7 @@ data: {
"query": "苹果公司在北京开设新店", "query": "苹果公司在北京开设新店",
"userid": "user1", "userid": "user1",
"db_type": "textdb", // 可选若不提供则使用默认集合 ragdb "db_type": "textdb", // 可选若不提供则使用默认集合 ragdb
"knowledge_base_ids": ["kb123"], "file_paths": ["/path/to/file.txt"],
"limit": 5, "limit": 5,
"offset": 0, "offset": 0,
"use_rerank": true "use_rerank": true
@ -89,15 +89,15 @@ response:
"metadata": { "metadata": {
"userid": "user1", "userid": "user1",
"document_id": "<uuid>", "document_id": "<uuid>",
"filename": "file.txt", "filename": "test.txt",
"file_path": "/path/to/file.txt", "file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>", "upload_time": "2025-06-27T15:58:00",
"file_type": "txt" "file_type": "txt"
} }
}, },
... ...
] ]
- Error: HTTP 400, {"status": "error", "message": "<error message>", "collection_name": "<collection_name>"} - Error: HTTP 400, {"status": "error", "message": "<error message>"}
6. Search Query Endpoint: 6. Search Query Endpoint:
path: /v1/searchquery path: /v1/searchquery
@ -107,7 +107,7 @@ data: {
"query": "知识图谱的知识融合是什么?", "query": "知识图谱的知识融合是什么?",
"userid": "user1", "userid": "user1",
"db_type": "textdb", // 可选若不提供则使用默认集合 ragdb "db_type": "textdb", // 可选若不提供则使用默认集合 ragdb
"knowledge_base_ids": ["kb123"], "file_paths": ["/path/to/file.txt"],
"limit": 5, "limit": 5,
"offset": 0, "offset": 0,
"use_rerank": true "use_rerank": true
@ -122,9 +122,9 @@ response:
"metadata": { "metadata": {
"userid": "user1", "userid": "user1",
"document_id": "<uuid>", "document_id": "<uuid>",
"filename": "file.txt", "filename": "test.txt",
"file_path": "/path/to/file.txt", "file_path": "/path/to/file.txt",
"upload_time": "<iso_timestamp>", "upload_time": "2025-06-27T15:58:00",
"file_type": "txt" "file_type": "txt"
} }
}, },
@ -142,10 +142,10 @@ data: {
response: response:
- Success: HTTP 200, [ - Success: HTTP 200, [
{ {
"filename": "file.txt", "filename": "test.txt",
"file_path": "/path/to/file.txt", "file_path": "/path/to/file.txt",
"db_type": "textdb", "db_type": "textdb",
"upload_time": "<iso_timestamp>", "upload_time": "2025-06-27T15:58:00",
"file_type": "txt" "file_type": "txt"
}, },
... ...
@ -165,11 +165,11 @@ response:
- Error: HTTP 400, {"status": "error", "message": "<error message>"} - Error: HTTP 400, {"status": "error", "message": "<error message>"}
9. Docs Endpoint: 9. Docs Endpoint:
path: /docs path: /v1/docs
method: GET method: GET
response: This help text response: This help text
10. Delete Knowledge Base Endpoint: 10.Delete Knowledge Base Endpoint:
path: /v1/deleteknowledgebase path: /v1/deleteknowledgebase
method: POST method: POST
headers: {"Content-Type": "application/json"} headers: {"Content-Type": "application/json"}
@ -179,10 +179,10 @@ data: {
"db_type": "textdb" // 可选若不提供则使用默认集合 ragdb "db_type": "textdb" // 可选若不提供则使用默认集合 ragdb
} }
response: response:
- Success: HTTP 200, {"status": "success", "document_id": "<uuid1,uuid2>", "filename": "<filename1,filename2>", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 <count> 条 Milvus 记录,<nodes> 个 Neo4j 节点,<rels> 个 Neo4j 关系userid=<userid>, knowledge_base_id=<knowledge_base_id>", "status_code": 200} - Success: HTTP 200, {"status": "success", "document_id": "<uuid1,uuid2>", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 <count> 条 Milvus 记录,<nodes> 个 Neo4j 节点,<rels> 个 Neo4j 关系userid=<userid>, knowledge_base_id=<knowledge_base_id>", "status_code": 200}
- Success (no records): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=<userid>, knowledge_base_id=<knowledge_base_id> 的记录,无需删除", "status_code": 200} - Success (no records): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=<userid>, knowledge_base_id=<knowledge_base_id> 的记录,无需删除", "status_code": 200}
- Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 <collection_name> 不存在,无需删除", "status_code": 200} - Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 <collection_name> 不存在,无需删除", "status_code": 200}
- Error: HTTP 400, {"status": "error", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400} - Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
""" """
def init(): def init():
@ -243,6 +243,7 @@ async def delete_collection(request, params_kw, *params, **kw):
"message": str(e) "message": str(e)
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def insert_file(request, params_kw, *params, **kw): async def insert_file(request, params_kw, *params, **kw):
debug(f'Received params: {params_kw=}') debug(f'Received params: {params_kw=}')
se = ServerEnv() se = ServerEnv()
@ -253,6 +254,7 @@ async def insert_file(request, params_kw, *params, **kw):
knowledge_base_id = params_kw.get('knowledge_base_id', '') knowledge_base_id = params_kw.get('knowledge_base_id', '')
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
# 仅检查必填字段是否存在
required_fields = ['file_path', 'userid', 'knowledge_base_id'] required_fields = ['file_path', 'userid', 'knowledge_base_id']
missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]] missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
if missing_fields: if missing_fields:
@ -267,6 +269,7 @@ async def insert_file(request, params_kw, *params, **kw):
"knowledge_base_id": knowledge_base_id "knowledge_base_id": knowledge_base_id
}) })
debug(f'Insert result: {result=}') debug(f'Insert result: {result=}')
# 根据 result 的 status 设置 HTTP 状态码
status = 200 if result.get("status") == "success" else 400 status = 200 if result.get("status") == "success" else 400
return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status) return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
except Exception as e: except Exception as e:
@ -278,6 +281,7 @@ async def insert_file(request, params_kw, *params, **kw):
"message": str(e) "message": str(e)
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def delete_file(request, params_kw, *params, **kw): async def delete_file(request, params_kw, *params, **kw):
debug(f'Received delete_file params: {params_kw=}') debug(f'Received delete_file params: {params_kw=}')
se = ServerEnv() se = ServerEnv()
@ -293,7 +297,8 @@ async def delete_file(request, params_kw, *params, **kw):
if missing_fields: if missing_fields:
raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}") raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
debug(f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}') debug(
f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}')
result = await engine.handle_connection("delete_document", { result = await engine.handle_connection("delete_document", {
"userid": userid, "userid": userid,
"filename": filename, "filename": filename,
@ -313,6 +318,7 @@ async def delete_file(request, params_kw, *params, **kw):
"status_code": 400 "status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def delete_knowledge_base(request, params_kw, *params, **kw): async def delete_knowledge_base(request, params_kw, *params, **kw):
debug(f'Received delete_knowledge_base params: {params_kw=}') debug(f'Received delete_knowledge_base params: {params_kw=}')
se = ServerEnv() se = ServerEnv()
@ -343,7 +349,6 @@ async def delete_knowledge_base(request, params_kw, *params, **kw):
"status": "error", "status": "error",
"collection_name": collection_name, "collection_name": collection_name,
"document_id": "", "document_id": "",
"filename": "",
"message": str(e), "message": str(e),
"status_code": 400 "status_code": 400
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
@ -355,24 +360,22 @@ async def fused_search_query(request, params_kw, *params, **kw):
query = params_kw.get('query') query = params_kw.get('query')
userid = params_kw.get('userid') userid = params_kw.get('userid')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
knowledge_base_ids = params_kw.get('knowledge_base_ids') file_paths = params_kw.get('file_paths')
limit = params_kw.get('limit') limit = params_kw.get('limit', 5)
offset = params_kw.get('offset', 0) offset = params_kw.get('offset', 0)
use_rerank = params_kw.get('use_rerank', True) use_rerank = params_kw.get('use_rerank', True)
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" if not all([query, userid, file_paths]):
debug(f'query, userid 或 file_paths 未提供')
return web.json_response({
"status": "error",
"message": "query, userid 或 file_paths 未提供"
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
try: try:
if not all([query, userid, knowledge_base_ids]):
debug(f'query, userid 或 knowledge_base_ids 未提供')
return web.json_response({
"status": "error",
"message": "query, userid 或 knowledge_base_ids 未提供",
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
result = await engine.handle_connection("fused_search", { result = await engine.handle_connection("fused_search", {
"query": query, "query": query,
"userid": userid, "userid": userid,
"db_type": db_type, "db_type": db_type,
"knowledge_base_ids": knowledge_base_ids, "file_paths": file_paths,
"limit": limit, "limit": limit,
"offset": offset, "offset": offset,
"use_rerank": use_rerank "use_rerank": use_rerank
@ -383,8 +386,7 @@ async def fused_search_query(request, params_kw, *params, **kw):
error(f'融合搜索失败: {str(e)}') error(f'融合搜索失败: {str(e)}')
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"message": str(e), "message": str(e)
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def search_query(request, params_kw, *params, **kw): async def search_query(request, params_kw, *params, **kw):
@ -394,24 +396,22 @@ async def search_query(request, params_kw, *params, **kw):
query = params_kw.get('query') query = params_kw.get('query')
userid = params_kw.get('userid') userid = params_kw.get('userid')
db_type = params_kw.get('db_type', '') db_type = params_kw.get('db_type', '')
knowledge_base_ids = params_kw.get('knowledge_base_ids') file_paths = params_kw.get('file_paths')
limit = params_kw.get('limit') limit = params_kw.get('limit', 5)
offset = params_kw.get('offset') offset = params_kw.get('offset', 0)
use_rerank = params_kw.get('use_rerank', True) use_rerank = params_kw.get('use_rerank', True)
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" if not all([query, userid, file_paths]):
debug(f'query, userid 或 file_paths 未提供')
return web.json_response({
"status": "error",
"message": "query, userid 或 file_paths 未提供"
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
try: try:
if not all([query, userid, knowledge_base_ids]):
debug(f'query, userid 或 knowledge_base_ids 未提供')
return web.json_response({
"status": "error",
"message": "query, userid 或 knowledge_base_ids 未提供",
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
result = await engine.handle_connection("search_query", { result = await engine.handle_connection("search_query", {
"query": query, "query": query,
"userid": userid, "userid": userid,
"db_type": db_type, "db_type": db_type,
"knowledge_base_ids": knowledge_base_ids, "file_paths": file_paths,
"limit": limit, "limit": limit,
"offset": offset, "offset": offset,
"use_rerank": use_rerank "use_rerank": use_rerank
@ -422,8 +422,7 @@ async def search_query(request, params_kw, *params, **kw):
error(f'纯向量搜索失败: {str(e)}') error(f'纯向量搜索失败: {str(e)}')
return web.json_response({ return web.json_response({
"status": "error", "status": "error",
"message": str(e), "message": str(e)
"collection_name": collection_name
}, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400) }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
async def list_user_files(request, params_kw, *params, **kw): async def list_user_files(request, params_kw, *params, **kw):

View File

@ -133,47 +133,32 @@ class MilvusConnection:
elif action == "fused_search": elif action == "fused_search":
query = params.get("query", "") query = params.get("query", "")
userid = params.get("userid", "") userid = params.get("userid", "")
knowledge_base_ids = params.get("knowledge_base_ids", []) file_paths = params.get("file_paths", [])
limit = params.get("limit", 5) if not query or not userid or not file_paths:
if not query or not userid or not knowledge_base_ids: return {"status": "error", "message": "query、userid 或 file_paths 不能为空",
return { "collection_name": collection_name, "document_id": "", "status_code": 400}
"status": "error",
"message": "query、userid 或 knowledge_base_ids 不能为空",
"collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}",
"document_id": "",
"status_code": 400
}
if limit < 1 or limit > 16384:
return {
"status": "error",
"message": "limit 必须在 1 到 16384 之间",
"collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}",
"document_id": "",
"status_code": 400
}
return await self._fused_search( return await self._fused_search(
query, query,
userid, userid,
params.get("db_type", ""), db_type,
knowledge_base_ids, file_paths,
limit, params.get("limit", 5),
params.get("offset", 0), params.get("offset", 0),
params.get("use_rerank", True) params.get("use_rerank", True)
) )
elif action == "search_query": elif action == "search_query":
query = params.get("query", "") query = params.get("query", "")
userid = params.get("userid", "") userid = params.get("userid", "")
limit = params.get("limit", "") file_paths = params.get("file_paths", [])
knowledge_base_ids = params.get("knowledge_base_ids", []) if not query or not userid or not file_paths:
if not query or not userid or not knowledge_base_ids: return {"status": "error", "message": "query、userid 或 file_paths 不能为空",
return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空",
"collection_name": collection_name, "document_id": "", "status_code": 400} "collection_name": collection_name, "document_id": "", "status_code": 400}
return await self._search_query( return await self._search_query(
query, query,
userid, userid,
db_type, db_type,
knowledge_base_ids, file_paths,
limit, params.get("limit", 5),
params.get("offset", 0), params.get("offset", 0),
params.get("use_rerank", True) params.get("use_rerank", True)
) )
@ -295,7 +280,7 @@ class MilvusConnection:
error(f"创建集合失败: {str(e)}") error(f"创建集合失败: {str(e)}")
return { return {
"status": "error", "status": "error",
"collection_name":collection_name, "collection_name": collection_name,
"message": str(e) "message": str(e)
} }
@ -668,6 +653,7 @@ class MilvusConnection:
try: try:
collection = Collection(collection_name) collection = Collection(collection_name)
await asyncio.wait_for(collection.load_async(), timeout=10.0)
debug(f"加载集合: {collection_name}") debug(f"加载集合: {collection_name}")
except Exception as e: except Exception as e:
error(f"加载集合 {collection_name} 失败: {str(e)}") error(f"加载集合 {collection_name} 失败: {str(e)}")
@ -956,22 +942,22 @@ class MilvusConnection:
error(f"重排序服务调用失败: {str(e)}\n{exception()}") error(f"重排序服务调用失败: {str(e)}\n{exception()}")
return results return results
async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5, async def _fused_search(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5,
offset: int = 0, use_rerank: bool = True) -> List[Dict]: offset: int = 0, use_rerank: bool = True) -> List[Dict]:
"""融合搜索,将查询与所有三元组拼接后向量化搜索""" """融合搜索,将查询与所有三元组拼接后向量化搜索"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
info( info(
f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}") f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
if not query or not userid or not knowledge_base_ids: if not query or not userid or not file_paths:
raise ValueError("query、userid 和 knowledge_base_ids 不能为空") raise ValueError("query、userid 和 file_paths 不能为空")
if "_" in userid or (db_type and "_" in db_type): if "_" in userid or (db_type and "_" in db_type):
raise ValueError("userid 和 db_type 不能包含下划线") raise ValueError("userid 和 db_type 不能包含下划线")
if (db_type and len(db_type) > 100) or len(userid) > 100: if (db_type and len(db_type) > 100) or len(userid) > 100:
raise ValueError("db_type 或 userid 的长度超出限制") raise ValueError("db_type 或 userid 的长度超出限制")
if limit < 1 or limit > 16384 or offset < 0: if limit < 1 or offset < 0:
raise ValueError("limit 必须在 1 到 16384 之间offset 必须大于或等于 0") raise ValueError("limit 必须大于 0offset 必须大于或等于 0")
if not utility.has_collection(collection_name): if not utility.has_collection(collection_name):
debug(f"集合 {collection_name} 不存在") debug(f"集合 {collection_name} 不存在")
@ -990,24 +976,24 @@ class MilvusConnection:
documents = [] documents = []
all_triplets = [] all_triplets = []
for kb_id in knowledge_base_ids: for file_path in file_paths:
debug(f"处理知识库: {kb_id}") filename = os.path.basename(file_path)
debug(f"处理文件: {filename}")
results = collection.query( results = collection.query(
expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'", expr=f"userid == '{userid}' and filename == '{filename}'",
output_fields=["document_id", "filename", "knowledge_base_id"], output_fields=["document_id", "filename"],
limit=100 # 查询足够多的文档以支持后续过滤 limit=1
) )
if not results: if not results:
debug(f"未找到 userid {userid}knowledge_base_id {kb_id} 对应的文档") debug(f"未找到 userid {userid}filename {filename} 对应的文档")
continue continue
documents.extend(results) documents.append(results[0])
for doc in results: document_id = results[0]["document_id"]
document_id = doc["document_id"] matched_triplets = await self._match_triplets(query, query_entities, userid, document_id)
matched_triplets = await self._match_triplets(query, query_entities, userid, document_id) debug(f"文件 {filename} 匹配三元组: {len(matched_triplets)}")
debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)}") all_triplets.extend(matched_triplets)
all_triplets.extend(matched_triplets)
if not documents: if not documents:
debug("未找到任何有效文档") debug("未找到任何有效文档")
@ -1036,8 +1022,9 @@ class MilvusConnection:
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) filenames = [os.path.basename(file_path) for file_path in file_paths]
expr = f"userid == '{userid}' and ({kb_expr})" filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames])
expr = f"userid == '{userid}' and ({filename_expr})"
debug(f"搜索表达式: {expr}") debug(f"搜索表达式: {expr}")
try: try:
@ -1045,7 +1032,7 @@ class MilvusConnection:
data=[query_vector], data=[query_vector],
anns_field="vector", anns_field="vector",
param=search_params, param=search_params,
limit=100, limit=limit,
expr=expr, expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
"file_type"], "file_type"],
@ -1086,7 +1073,7 @@ class MilvusConnection:
if use_rerank and unique_results: if use_rerank and unique_results:
debug("开始重排序") debug("开始重排序")
unique_results = await self._rerank_results(combined_text, unique_results, limit) # 使用传入的 limit unique_results = await self._rerank_results(combined_text, unique_results, 5)
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
else: else:
@ -1099,13 +1086,13 @@ class MilvusConnection:
error(f"融合搜索失败: {str(e)}\n{exception()}") error(f"融合搜索失败: {str(e)}\n{exception()}")
return [] return []
async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5, async def _search_query(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5,
offset: int = 0, use_rerank: bool = True) -> List[Dict]: offset: int = 0, use_rerank: bool = True) -> List[Dict]:
"""纯向量搜索,基于查询文本在指定知识库中搜索相关文本块""" """纯向量搜索,基于查询文本在指定文档中搜索相关文本块"""
collection_name = "ragdb" if not db_type else f"ragdb_{db_type}" collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
try: try:
info( info(
f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}") f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
if not query: if not query:
raise ValueError("查询文本不能为空") raise ValueError("查询文本不能为空")
@ -1121,15 +1108,15 @@ class MilvusConnection:
raise ValueError("offset 不能为负数") raise ValueError("offset 不能为负数")
if limit + offset > 16384: if limit + offset > 16384:
raise ValueError("limit + offset 不能超过 16384") raise ValueError("limit + offset 不能超过 16384")
if not knowledge_base_ids: if not file_paths:
raise ValueError("knowledge_base_ids 不能为空") raise ValueError("file_paths 不能为空")
for kb_id in knowledge_base_ids: for file_path in file_paths:
if not isinstance(kb_id, str): if not isinstance(file_path, str):
raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}") raise ValueError(f"file_path 必须是字符串: {file_path}")
if len(kb_id) > 100: if len(os.path.basename(file_path)) > 255:
raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}") raise ValueError(f"文件名长度超出 255 个字符: {file_path}")
if "_" in kb_id: if "_" in os.path.basename(file_path):
raise ValueError(f"knowledge_base_id 不能包含下划线: {kb_id}") raise ValueError(f"文件名 {file_path} 不能包含下划线")
if not utility.has_collection(collection_name): if not utility.has_collection(collection_name):
debug(f"集合 {collection_name} 不存在") debug(f"集合 {collection_name} 不存在")
@ -1149,8 +1136,9 @@ class MilvusConnection:
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids]) filenames = [os.path.basename(file_path) for file_path in file_paths]
expr = f"userid == '{userid}' and ({kb_id_expr})" filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames])
expr = f"userid == '{userid}' and ({filename_expr})"
debug(f"搜索表达式: {expr}") debug(f"搜索表达式: {expr}")
try: try:
@ -1158,7 +1146,7 @@ class MilvusConnection:
data=[query_vector], data=[query_vector],
anns_field="vector", anns_field="vector",
param=search_params, param=search_params,
limit=100, limit=limit,
expr=expr, expr=expr,
output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time", output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
"file_type"], "file_type"],
@ -1199,7 +1187,7 @@ class MilvusConnection:
if use_rerank and unique_results: if use_rerank and unique_results:
debug("开始重排序") debug("开始重排序")
unique_results = await self._rerank_results(query, unique_results, limit) unique_results = await self._rerank_results(query, unique_results, 5)
unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True) unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}") debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
else: else: