diff --git a/llmengine/base_reranker.py b/llmengine/base_reranker.py index 5938faf..892ec91 100644 --- a/llmengine/base_reranker.py +++ b/llmengine/base_reranker.py @@ -1,5 +1,17 @@ import torch +model_pathMap = { +} +def llm_register(model_key, Klass): + model_pathMap[model_key] = Klass + +def get_llm_class(model_path): + for k,klass in model_pathMap.items(): + if len(model_path.split(k)) > 1: + return klass + print(f'{model_pathMap=}') + return None + class BaseReranker: def __init__(self, model_id, **kw): self.model_id = model_id