Add support for deleting a user's knowledge base and recalling text chunks by knowledge base
commit 313645afee
parent 5fcc519720
@@ -1,8 +1,6 @@
 from abc import ABC, abstractmethod
 from typing import Dict
-import logging
+from appPublic.log import debug, error, info, exception
 
-logger = logging.getLogger(__name__)
-
 connection_pathMap = {}
 
@@ -10,15 +8,15 @@ def connection_register(connection_key, Klass):
     """为给定的连接键注册一个连接类"""
     global connection_pathMap
     connection_pathMap[connection_key] = Klass
-    logger.info(f"Registered {connection_key} with class {Klass}")
+    info(f"Registered {connection_key} with class {Klass}")
 
 def get_connection_class(connection_path):
     """根据连接路径查找对应的连接类"""
     global connection_pathMap
-    logger.debug(f"connection_pathMap: {connection_pathMap}")
+    debug(f"connection_pathMap: {connection_pathMap}")
     klass = connection_pathMap.get(connection_path)
     if klass is None:
-        logger.error(f"{connection_path} has not mapping to a connection class")
+        error(f"{connection_path} has not mapping to a connection class")
         raise Exception(f"{connection_path} has not mapping to a connection class")
     return klass
 
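For context, the registry above maps a connection key to a class, and this commit only swaps its logging over to appPublic.log. A minimal usage sketch (DummyConnection is a stand-in; the real code registers concrete connection classes such as MilvusConnection):

# Usage sketch, assuming the registry functions above are importable.
class DummyConnection:
    pass

connection_register("dummy", DummyConnection)     # now logged via info()
assert get_connection_class("dummy") is DummyConnection
# get_connection_class("missing") would log via error() and raise Exception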
@@ -74,7 +74,7 @@ data: {
     "query": "苹果公司在北京开设新店",
     "userid": "user1",
     "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb
-    "file_paths": ["/path/to/file.txt"],
+    "knowledge_base_ids": ["kb123"],
    "limit": 5,
    "offset": 0,
    "use_rerank": true
@@ -89,15 +89,15 @@ response:
         "metadata": {
             "userid": "user1",
             "document_id": "<uuid>",
-            "filename": "test.txt",
+            "filename": "file.txt",
             "file_path": "/path/to/file.txt",
-            "upload_time": "2025-06-27T15:58:00",
+            "upload_time": "<iso_timestamp>",
             "file_type": "txt"
         }
     },
     ...
   ]
-  - Error: HTTP 400, {"status": "error", "message": "<error message>"}
+  - Error: HTTP 400, {"status": "error", "message": "<error message>", "collection_name": "<collection_name>"}
 
 6. Search Query Endpoint:
    path: /v1/searchquery
@@ -107,7 +107,7 @@ data: {
     "query": "知识图谱的知识融合是什么?",
     "userid": "user1",
     "db_type": "textdb", // 可选,若不提供则使用默认集合 ragdb
-    "file_paths": ["/path/to/file.txt"],
+    "knowledge_base_ids": ["kb123"],
     "limit": 5,
     "offset": 0,
     "use_rerank": true
@@ -122,9 +122,9 @@ response:
         "metadata": {
             "userid": "user1",
             "document_id": "<uuid>",
-            "filename": "test.txt",
+            "filename": "file.txt",
             "file_path": "/path/to/file.txt",
-            "upload_time": "2025-06-27T15:58:00",
+            "upload_time": "<iso_timestamp>",
             "file_type": "txt"
         }
     },
@@ -142,10 +142,10 @@ data: {
 response:
   - Success: HTTP 200, [
     {
-      "filename": "test.txt",
+      "filename": "file.txt",
       "file_path": "/path/to/file.txt",
       "db_type": "textdb",
-      "upload_time": "2025-06-27T15:58:00",
+      "upload_time": "<iso_timestamp>",
       "file_type": "txt"
     },
     ...
@@ -165,11 +165,11 @@ response:
   - Error: HTTP 400, {"status": "error", "message": "<error message>"}
 
 9. Docs Endpoint:
-   path: /v1/docs
+   path: /docs
    method: GET
    response: This help text
 
-10.Delete Knowledge Base Endpoint:
+10. Delete Knowledge Base Endpoint:
    path: /v1/deleteknowledgebase
    method: POST
    headers: {"Content-Type": "application/json"}
@@ -179,10 +179,10 @@ data: {
     "db_type": "textdb" // 可选,若不提供则使用默认集合 ragdb
 }
 response:
-  - Success: HTTP 200, {"status": "success", "document_id": "<uuid1,uuid2>", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 <count> 条 Milvus 记录,<nodes> 个 Neo4j 节点,<rels> 个 Neo4j 关系,userid=<userid>, knowledge_base_id=<knowledge_base_id>", "status_code": 200}
+  - Success: HTTP 200, {"status": "success", "document_id": "<uuid1,uuid2>", "filename": "<filename1,filename2>", "collection_name": "ragdb" or "ragdb_textdb", "message": "成功删除 <count> 条 Milvus 记录,<nodes> 个 Neo4j 节点,<rels> 个 Neo4j 关系,userid=<userid>, knowledge_base_id=<knowledge_base_id>", "status_code": 200}
-  - Success (no records): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=<userid>, knowledge_base_id=<knowledge_base_id> 的记录,无需删除", "status_code": 200}
+  - Success (no records): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "没有找到 userid=<userid>, knowledge_base_id=<knowledge_base_id> 的记录,无需删除", "status_code": 200}
-  - Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 <collection_name> 不存在,无需删除", "status_code": 200}
+  - Success (collection missing): HTTP 200, {"status": "success", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "集合 <collection_name> 不存在,无需删除", "status_code": 200}
-  - Error: HTTP 400, {"status": "error", "document_id": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
+  - Error: HTTP 400, {"status": "error", "document_id": "", "filename": "", "collection_name": "ragdb" or "ragdb_textdb", "message": "<error message>", "status_code": 400}
 """
 
 def init():
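As a quick illustration of the request shape documented above, here is a minimal client sketch for the updated /v1/searchquery endpoint; the base URL and port are assumptions, and only fields shown in the docstring are used.

import requests  # assumed HTTP client; any client works

payload = {
    "query": "知识图谱的知识融合是什么?",
    "userid": "user1",
    "db_type": "textdb",              # optional; omit to use the default "ragdb" collection
    "knowledge_base_ids": ["kb123"],  # replaces the old file_paths parameter
    "limit": 5,
    "offset": 0,
    "use_rerank": True,
}
resp = requests.post(
    "http://localhost:8080/v1/searchquery",  # assumed host and port
    json=payload,
    headers={"Content-Type": "application/json"},
    timeout=30,
)
print(resp.status_code, resp.json())

On success this returns the matched chunks with their metadata; on error the body now also carries collection_name, as the updated docstring notes.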
@@ -243,7 +243,6 @@ async def delete_collection(request, params_kw, *params, **kw):
             "message": str(e)
         }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
 
-
 async def insert_file(request, params_kw, *params, **kw):
     debug(f'Received params: {params_kw=}')
     se = ServerEnv()
@@ -254,7 +253,6 @@ async def insert_file(request, params_kw, *params, **kw):
     knowledge_base_id = params_kw.get('knowledge_base_id', '')
     collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
     try:
-        # 仅检查必填字段是否存在
         required_fields = ['file_path', 'userid', 'knowledge_base_id']
         missing_fields = [field for field in required_fields if field not in params_kw or not params_kw[field]]
         if missing_fields:
@@ -269,7 +267,6 @@ async def insert_file(request, params_kw, *params, **kw):
             "knowledge_base_id": knowledge_base_id
         })
         debug(f'Insert result: {result=}')
-        # 根据 result 的 status 设置 HTTP 状态码
         status = 200 if result.get("status") == "success" else 400
         return web.json_response(result, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=status)
     except Exception as e:
@@ -281,7 +278,6 @@ async def insert_file(request, params_kw, *params, **kw):
             "message": str(e)
         }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
 
-
 async def delete_file(request, params_kw, *params, **kw):
     debug(f'Received delete_file params: {params_kw=}')
     se = ServerEnv()
@@ -297,8 +293,7 @@ async def delete_file(request, params_kw, *params, **kw):
         if missing_fields:
             raise ValueError(f"缺少必填字段: {', '.join(missing_fields)}")
 
-        debug(
-            f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}')
+        debug(f'Calling delete_document with: userid={userid}, filename={filename}, db_type={db_type}, knowledge_base_id={knowledge_base_id}')
         result = await engine.handle_connection("delete_document", {
             "userid": userid,
             "filename": filename,
@@ -318,7 +313,6 @@ async def delete_file(request, params_kw, *params, **kw):
             "status_code": 400
         }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
 
-
 async def delete_knowledge_base(request, params_kw, *params, **kw):
     debug(f'Received delete_knowledge_base params: {params_kw=}')
     se = ServerEnv()
@@ -349,6 +343,7 @@ async def delete_knowledge_base(request, params_kw, *params, **kw):
             "status": "error",
             "collection_name": collection_name,
             "document_id": "",
+            "filename": "",
             "message": str(e),
             "status_code": 400
         }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
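The delete_knowledge_base handler above backs the /v1/deleteknowledgebase endpoint documented earlier. A minimal call sketch, assuming the request body carries userid and knowledge_base_id alongside the optional db_type (those field names are implied by the response message but not shown in this hunk, and the host and port are assumptions):

import requests  # assumed HTTP client

resp = requests.post(
    "http://localhost:8080/v1/deleteknowledgebase",  # assumed host and port
    json={
        "userid": "user1",             # assumed field, implied by the response message
        "knowledge_base_id": "kb123",  # assumed field, implied by the response message
        "db_type": "textdb",           # optional; omit to target the default "ragdb" collection
    },
    headers={"Content-Type": "application/json"},
    timeout=30,
)
# The response now includes "filename" in addition to "document_id" and "collection_name".
print(resp.json())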
@@ -360,22 +355,24 @@ async def fused_search_query(request, params_kw, *params, **kw):
     query = params_kw.get('query')
     userid = params_kw.get('userid')
     db_type = params_kw.get('db_type', '')
-    file_paths = params_kw.get('file_paths')
-    limit = params_kw.get('limit', 5)
+    knowledge_base_ids = params_kw.get('knowledge_base_ids')
+    limit = params_kw.get('limit')
     offset = params_kw.get('offset', 0)
     use_rerank = params_kw.get('use_rerank', True)
-    if not all([query, userid, file_paths]):
-        debug(f'query, userid 或 file_paths 未提供')
-        return web.json_response({
-            "status": "error",
-            "message": "query, userid 或 file_paths 未提供"
-        }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
+    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
     try:
+        if not all([query, userid, knowledge_base_ids]):
+            debug(f'query, userid 或 knowledge_base_ids 未提供')
+            return web.json_response({
+                "status": "error",
+                "message": "query, userid 或 knowledge_base_ids 未提供",
+                "collection_name": collection_name
+            }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
         result = await engine.handle_connection("fused_search", {
             "query": query,
             "userid": userid,
             "db_type": db_type,
-            "file_paths": file_paths,
+            "knowledge_base_ids": knowledge_base_ids,
             "limit": limit,
             "offset": offset,
             "use_rerank": use_rerank
@@ -386,7 +383,8 @@ async def fused_search_query(request, params_kw, *params, **kw):
         error(f'融合搜索失败: {str(e)}')
         return web.json_response({
             "status": "error",
-            "message": str(e)
+            "message": str(e),
+            "collection_name": collection_name
         }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
 
 async def search_query(request, params_kw, *params, **kw):
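With collection_name computed before the try block, both the validation branch and the exception handler in fused_search_query now report which collection was targeted. For a request with db_type omitted and knowledge_base_ids missing, the returned HTTP 400 body is equivalent to this Python literal:

# Shape of the 400 body produced by the validation branch above (db_type omitted)
error_body = {
    "status": "error",
    "message": "query, userid 或 knowledge_base_ids 未提供",
    "collection_name": "ragdb",  # becomes "ragdb_<db_type>" when db_type is supplied
}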
@@ -396,22 +394,24 @@ async def search_query(request, params_kw, *params, **kw):
     query = params_kw.get('query')
     userid = params_kw.get('userid')
     db_type = params_kw.get('db_type', '')
-    file_paths = params_kw.get('file_paths')
-    limit = params_kw.get('limit', 5)
-    offset = params_kw.get('offset', 0)
+    knowledge_base_ids = params_kw.get('knowledge_base_ids')
+    limit = params_kw.get('limit')
+    offset = params_kw.get('offset')
     use_rerank = params_kw.get('use_rerank', True)
-    if not all([query, userid, file_paths]):
-        debug(f'query, userid 或 file_paths 未提供')
-        return web.json_response({
-            "status": "error",
-            "message": "query, userid 或 file_paths 未提供"
-        }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
+    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
     try:
+        if not all([query, userid, knowledge_base_ids]):
+            debug(f'query, userid 或 knowledge_base_ids 未提供')
+            return web.json_response({
+                "status": "error",
+                "message": "query, userid 或 knowledge_base_ids 未提供",
+                "collection_name": collection_name
+            }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
         result = await engine.handle_connection("search_query", {
             "query": query,
             "userid": userid,
             "db_type": db_type,
-            "file_paths": file_paths,
+            "knowledge_base_ids": knowledge_base_ids,
             "limit": limit,
             "offset": offset,
             "use_rerank": use_rerank
@@ -422,7 +422,8 @@ async def search_query(request, params_kw, *params, **kw):
         error(f'纯向量搜索失败: {str(e)}')
         return web.json_response({
             "status": "error",
-            "message": str(e)
+            "message": str(e),
+            "collection_name": collection_name
         }, dumps=lambda obj: json.dumps(obj, ensure_ascii=False), status=400)
 
 async def list_user_files(request, params_kw, *params, **kw):
@@ -133,32 +133,47 @@ class MilvusConnection:
         elif action == "fused_search":
             query = params.get("query", "")
             userid = params.get("userid", "")
-            file_paths = params.get("file_paths", [])
-            if not query or not userid or not file_paths:
-                return {"status": "error", "message": "query、userid 或 file_paths 不能为空",
-                        "collection_name": collection_name, "document_id": "", "status_code": 400}
+            knowledge_base_ids = params.get("knowledge_base_ids", [])
+            limit = params.get("limit", 5)
+            if not query or not userid or not knowledge_base_ids:
+                return {
+                    "status": "error",
+                    "message": "query、userid 或 knowledge_base_ids 不能为空",
+                    "collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}",
+                    "document_id": "",
+                    "status_code": 400
+                }
+            if limit < 1 or limit > 16384:
+                return {
+                    "status": "error",
+                    "message": "limit 必须在 1 到 16384 之间",
+                    "collection_name": "ragdb" if not params.get("db_type","") else f"ragdb_{params.get('db_type')}",
+                    "document_id": "",
+                    "status_code": 400
+                }
             return await self._fused_search(
                 query,
                 userid,
-                db_type,
-                file_paths,
-                params.get("limit", 5),
+                params.get("db_type", ""),
+                knowledge_base_ids,
+                limit,
                 params.get("offset", 0),
                 params.get("use_rerank", True)
             )
         elif action == "search_query":
             query = params.get("query", "")
             userid = params.get("userid", "")
-            file_paths = params.get("file_paths", [])
-            if not query or not userid or not file_paths:
-                return {"status": "error", "message": "query、userid 或 file_paths 不能为空",
+            limit = params.get("limit", "")
+            knowledge_base_ids = params.get("knowledge_base_ids", [])
+            if not query or not userid or not knowledge_base_ids:
+                return {"status": "error", "message": "query、userid 或 knowledge_base_ids 不能为空",
                         "collection_name": collection_name, "document_id": "", "status_code": 400}
             return await self._search_query(
                 query,
                 userid,
                 db_type,
-                file_paths,
-                params.get("limit", 5),
+                knowledge_base_ids,
+                limit,
                 params.get("offset", 0),
                 params.get("use_rerank", True)
             )
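The dispatcher above now builds the same error dictionary twice (missing parameters and out-of-range limit). A possible refactoring sketch, not part of this patch, that would keep both branches consistent:

def _error_result(message: str, db_type: str = "") -> dict:
    # Hypothetical helper; mirrors the inline error dictionaries above.
    collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
    return {
        "status": "error",
        "message": message,
        "collection_name": collection_name,
        "document_id": "",
        "status_code": 400,
    }

The 1 to 16384 bound on limit matches the constraint re-checked inside _fused_search and the "limit + offset 不能超过 16384" ceiling enforced in _search_query.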
@@ -280,7 +295,7 @@ class MilvusConnection:
             error(f"创建集合失败: {str(e)}")
             return {
                 "status": "error",
-                "collection_name": collection_name,
+                "collection_name":collection_name,
                 "message": str(e)
             }
 
@@ -653,7 +668,6 @@ class MilvusConnection:
 
         try:
             collection = Collection(collection_name)
-            await asyncio.wait_for(collection.load_async(), timeout=10.0)
             debug(f"加载集合: {collection_name}")
         except Exception as e:
             error(f"加载集合 {collection_name} 失败: {str(e)}")
@@ -942,22 +956,22 @@ class MilvusConnection:
             error(f"重排序服务调用失败: {str(e)}\n{exception()}")
             return results
 
-    async def _fused_search(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5,
+    async def _fused_search(self, query: str, userid: str, db_type: str, knowledge_base_ids: List[str], limit: int = 5,
                             offset: int = 0, use_rerank: bool = True) -> List[Dict]:
         """融合搜索,将查询与所有三元组拼接后向量化搜索"""
         collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
         try:
             info(
-                f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
+                f"开始融合搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
 
-            if not query or not userid or not file_paths:
-                raise ValueError("query、userid 和 file_paths 不能为空")
+            if not query or not userid or not knowledge_base_ids:
+                raise ValueError("query、userid 和 knowledge_base_ids 不能为空")
             if "_" in userid or (db_type and "_" in db_type):
                 raise ValueError("userid 和 db_type 不能包含下划线")
             if (db_type and len(db_type) > 100) or len(userid) > 100:
                 raise ValueError("db_type 或 userid 的长度超出限制")
-            if limit < 1 or offset < 0:
-                raise ValueError("limit 必须大于 0,offset 必须大于或等于 0")
+            if limit < 1 or limit > 16384 or offset < 0:
+                raise ValueError("limit 必须在 1 到 16384 之间,offset 必须大于或等于 0")
 
             if not utility.has_collection(collection_name):
                 debug(f"集合 {collection_name} 不存在")
@@ -976,24 +990,24 @@ class MilvusConnection:
 
             documents = []
             all_triplets = []
-            for file_path in file_paths:
-                filename = os.path.basename(file_path)
-                debug(f"处理文件: {filename}")
+            for kb_id in knowledge_base_ids:
+                debug(f"处理知识库: {kb_id}")
 
                 results = collection.query(
-                    expr=f"userid == '{userid}' and filename == '{filename}'",
-                    output_fields=["document_id", "filename"],
-                    limit=1
+                    expr=f"userid == '{userid}' and knowledge_base_id == '{kb_id}'",
+                    output_fields=["document_id", "filename", "knowledge_base_id"],
+                    limit=100  # 查询足够多的文档以支持后续过滤
                 )
                 if not results:
-                    debug(f"未找到 userid {userid} 和 filename {filename} 对应的文档")
+                    debug(f"未找到 userid {userid} 和 knowledge_base_id {kb_id} 对应的文档")
                     continue
-                documents.append(results[0])
+                documents.extend(results)
 
-                document_id = results[0]["document_id"]
-                matched_triplets = await self._match_triplets(query, query_entities, userid, document_id)
-                debug(f"文件 {filename} 匹配三元组: {len(matched_triplets)} 条")
-                all_triplets.extend(matched_triplets)
+                for doc in results:
+                    document_id = doc["document_id"]
+                    matched_triplets = await self._match_triplets(query, query_entities, userid, document_id)
+                    debug(f"知识库 {kb_id} 文档 {doc['filename']} 匹配三元组: {len(matched_triplets)} 条")
+                    all_triplets.extend(matched_triplets)
 
             if not documents:
                 debug("未找到任何有效文档")
@@ -1022,9 +1036,8 @@ class MilvusConnection:
 
             search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
 
-            filenames = [os.path.basename(file_path) for file_path in file_paths]
-            filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames])
-            expr = f"userid == '{userid}' and ({filename_expr})"
+            kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
+            expr = f"userid == '{userid}' and ({kb_expr})"
             debug(f"搜索表达式: {expr}")
 
             try:
@@ -1032,7 +1045,7 @@ class MilvusConnection:
                     data=[query_vector],
                     anns_field="vector",
                     param=search_params,
-                    limit=limit,
+                    limit=100,
                     expr=expr,
                     output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
                                    "file_type"],
@@ -1073,7 +1086,7 @@ class MilvusConnection:
 
             if use_rerank and unique_results:
                 debug("开始重排序")
-                unique_results = await self._rerank_results(combined_text, unique_results, 5)
+                unique_results = await self._rerank_results(combined_text, unique_results, limit)  # 使用传入的 limit
                 unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
                 debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
             else:
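The knowledge-base filter is a plain Milvus boolean expression built from the caller's IDs; candidates are over-fetched with limit=100 and, when use_rerank is set, trimmed back to the caller's limit by _rerank_results. A standalone sketch of the string construction used above (pure string building, no Milvus connection required):

# Mirrors the kb_expr / expr construction in _fused_search
userid = "user1"
knowledge_base_ids = ["kb123", "kb456"]

kb_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
expr = f"userid == '{userid}' and ({kb_expr})"
print(expr)
# userid == 'user1' and (knowledge_base_id == 'kb123' or knowledge_base_id == 'kb456')

_search_query builds the same kind of expression and additionally validates each kb_id (string, at most 100 characters, no underscore) before interpolating it.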
@@ -1086,13 +1099,13 @@ class MilvusConnection:
             error(f"融合搜索失败: {str(e)}\n{exception()}")
             return []
 
-    async def _search_query(self, query: str, userid: str, db_type: str, file_paths: List[str], limit: int = 5,
+    async def _search_query(self, query: str, userid: str, db_type: str = "", knowledge_base_ids: List[str] = [], limit: int = 5,
                             offset: int = 0, use_rerank: bool = True) -> List[Dict]:
-        """纯向量搜索,基于查询文本在指定文档中搜索相关文本块"""
+        """纯向量搜索,基于查询文本在指定知识库中搜索相关文本块"""
         collection_name = "ragdb" if not db_type else f"ragdb_{db_type}"
         try:
             info(
-                f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, file_paths={file_paths}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
+                f"开始纯向量搜索: query={query}, userid={userid}, db_type={db_type}, knowledge_base_ids={knowledge_base_ids}, limit={limit}, offset={offset}, use_rerank={use_rerank}")
 
             if not query:
                 raise ValueError("查询文本不能为空")
@@ -1108,15 +1121,15 @@ class MilvusConnection:
                 raise ValueError("offset 不能为负数")
             if limit + offset > 16384:
                 raise ValueError("limit + offset 不能超过 16384")
-            if not file_paths:
-                raise ValueError("file_paths 不能为空")
-            for file_path in file_paths:
-                if not isinstance(file_path, str):
-                    raise ValueError(f"file_path 必须是字符串: {file_path}")
-                if len(os.path.basename(file_path)) > 255:
-                    raise ValueError(f"文件名长度超出 255 个字符: {file_path}")
-                if "_" in os.path.basename(file_path):
-                    raise ValueError(f"文件名 {file_path} 不能包含下划线")
+            if not knowledge_base_ids:
+                raise ValueError("knowledge_base_ids 不能为空")
+            for kb_id in knowledge_base_ids:
+                if not isinstance(kb_id, str):
+                    raise ValueError(f"knowledge_base_id 必须是字符串: {kb_id}")
+                if len(kb_id) > 100:
+                    raise ValueError(f"knowledge_base_id 长度超出 100 个字符: {kb_id}")
+                if "_" in kb_id:
+                    raise ValueError(f"knowledge_base_id 不能包含下划线: {kb_id}")
 
             if not utility.has_collection(collection_name):
                 debug(f"集合 {collection_name} 不存在")
@@ -1136,9 +1149,8 @@ class MilvusConnection:
 
             search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
 
-            filenames = [os.path.basename(file_path) for file_path in file_paths]
-            filename_expr = " or ".join([f"filename == '{filename}'" for filename in filenames])
-            expr = f"userid == '{userid}' and ({filename_expr})"
+            kb_id_expr = " or ".join([f"knowledge_base_id == '{kb_id}'" for kb_id in knowledge_base_ids])
+            expr = f"userid == '{userid}' and ({kb_id_expr})"
             debug(f"搜索表达式: {expr}")
 
             try:
@@ -1146,7 +1158,7 @@ class MilvusConnection:
                     data=[query_vector],
                     anns_field="vector",
                     param=search_params,
-                    limit=limit,
+                    limit=100,
                     expr=expr,
                     output_fields=["text", "userid", "document_id", "filename", "file_path", "upload_time",
                                    "file_type"],
@@ -1187,7 +1199,7 @@ class MilvusConnection:
 
             if use_rerank and unique_results:
                 debug("开始重排序")
-                unique_results = await self._rerank_results(query, unique_results, 5)
+                unique_results = await self._rerank_results(query, unique_results, limit)
                 unique_results = sorted(unique_results, key=lambda x: x.get('rerank_score', 0), reverse=True)
                 debug(f"重排序分数分布: {[round(r.get('rerank_score', 0), 3) for r in unique_results]}")
             else: