llmengine/llmengine/collection.py

77 lines
2.0 KiB
Python

from traceback import format_exc
import os
import sys
import argparse
from llmengine.milvus_collection import *
from llmengine.base_collection import get_collection_class
from typing import Dict
from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver
# Help text returned verbatim by the `docs` handler; describes the HTTP
# endpoints this service exposes. The sample response "message" is Chinese
# for "collection ragdb_textdb created successfully".
helptext = """Milvus Collection Creation API:
1. Create Collection Endpoint:
path: /v1/collections
headers: {
"Content-Type": "application/json"
}
data: {
"db_type": "textdb"
}
response: {
"status": "success",
"collection_name": "ragdb_textdb",
"message": "集合 ragdb_textdb 创建成功"
}
2. Docs Endpoint:
path: /v1/docs
response: This help text
"""
def init():
    """Register this module's request handlers with the function registry.

    Binds the registry name ``'collections'`` to ``create_collection`` and
    ``'docs'`` to ``docs``, in that order.
    """
    registry = RegisterFunction()
    handlers = (
        ('collections', create_collection),
        ('docs', docs),
    )
    for handler_name, handler in handlers:
        registry.register(handler_name, handler)
async def docs(request, params_kw, *params, **kw):
    """Serve the module-level ``helptext`` string.

    The request, parameters and any extra arguments are accepted only so the
    handler matches the common handler signature; none of them are used.
    """
    text = helptext
    return text
async def create_collection(request, params_kw, *params, **kw):
    """Create a collection using the engine stored on ``ServerEnv``.

    Args:
        request: The incoming request object (not used here).
        params_kw: Mapping of request parameters. It must contain a
            non-empty ``'db_type'`` value, which is forwarded to
            ``engine.create_collection``.

    Returns:
        Whatever ``engine.create_collection(db_type)`` returns.

    Raises:
        ValueError: If ``db_type`` is missing or empty.
    """
    debug(f'{params_kw=}')
    db_type = params_kw.get('db_type')
    # Validate input before touching the engine so bad requests fail fast.
    if not db_type:
        msg = 'db_type is required'
        # `exception` is a logging helper: raising its return value would not
        # raise the intended error, so log first and raise a real exception.
        exception(msg)
        raise ValueError(msg)
    engine = ServerEnv().engine
    # awaitify wraps the engine method so it can be awaited from this handler.
    create = awaitify(engine.create_collection)
    result = await create(db_type)
    debug(f'{result=}')
    return result
def main():
    """Parse CLI arguments, build the collection engine and start the web server.

    The positional ``model_path`` selects the engine class via
    ``get_collection_class``. The resulting instance is stored on
    ``ServerEnv().engine``, where request handlers read it.

    Raises:
        Exception: If no collection class is mapped to ``model_path``.
    """
    parser = argparse.ArgumentParser(prog="Milvus Collection Service")
    parser.add_argument(
        '-w', '--workdir',
        help='working directory for the web server (default: current directory)',
    )
    parser.add_argument('-p', '--port', help='port for the web server to listen on')
    parser.add_argument(
        'model_path',
        help='model path used to select the collection class',
    )
    args = parser.parse_args()
    Klass = get_collection_class(args.model_path)
    if Klass is None:
        e = Exception(f'{args.model_path} has not mapping to a model class')
        # No exception is being handled at this point, so format_exc() would
        # only produce "NoneType: None"; log the error message alone.
        exception(f'{e}')
        raise e
    se = ServerEnv()
    se.engine = Klass(args.model_path)
    workdir = args.workdir or os.getcwd()
    # argparse has no type= here, so port is a str (or None if omitted);
    # it is passed to webserver unchanged.
    port = args.port
    debug(f'{args=}')
    webserver(init, workdir, port)
if __name__ == "__main__":
    # Start the service only when this file is executed directly, not on import.
    main()