llmengine/llmengine/base_embedding.py
2025-06-20 18:56:30 +08:00

31 lines
651 B
Python

import logging

import torch
# Registry mapping a model-path substring (key) to the class that handles
# matching models; filled by llm_register and searched by get_llm_class.
model_pathMap = dict()
def llm_register(model_key, Klass):
    """Register ``Klass`` for model paths that contain ``model_key``.

    Registering an existing ``model_key`` again replaces the earlier class.

    Args:
        model_key: substring that ``get_llm_class`` looks for in a model path.
        Klass: the class to return from ``get_llm_class`` on a match.
    """
    # The registry dict is only mutated here, never rebound, so no
    # ``global`` declaration is required.
    model_pathMap[model_key] = Klass
def get_llm_class(model_path):
    """Return the registered class whose key occurs in ``model_path``.

    Keys are tried in registration (dict insertion) order and the first key
    found as a substring of ``model_path`` wins. If one key is a substring
    of another, the one registered earlier takes precedence.

    Args:
        model_path: model path or name to match against the registered keys.

    Returns:
        The class registered via ``llm_register`` for the first matching
        key, or ``None`` when no key matches.
    """
    for key, klass in model_pathMap.items():
        # Plain substring test: clearer than splitting on the key, and it
        # does not raise ValueError for an empty key.
        if key in model_path:
            return klass
    # Report the miss through logging instead of printing to stdout.
    logging.getLogger(__name__).warning(
        "no registered class matches model_path=%r; registered keys: %r",
        model_path,
        list(model_pathMap),
    )
    return None
class BaseEmbedding:
    """Base class for embedding backends.

    This class never assigns ``self.model``; subclasses are expected to set
    it before any method here is called. The methods call ``encode``,
    ``similarity`` and ``to`` on it, which matches the interface of a
    sentence-transformers ``SentenceTransformer``
    (NOTE(review): confirm against the concrete subclasses).
    """

    def use_mps_if_possible(self):
        """Move ``self.model`` to the MPS device if torch reports it available.

        When MPS is not available the model is left unchanged.
        """
        if torch.backends.mps.is_available():
            self.model = self.model.to(torch.device("mps"))

    # Backward-compatible alias for the original, misspelled method name.
    use_mps_if_prosible = use_mps_if_possible

    def embedding(self, doc):
        """Return the embedding of a single document as a plain Python list.

        ``doc`` is wrapped in a one-element batch for ``self.model.encode``;
        row 0 of the result is converted with ``tolist()``.
        """
        return self.model.encode([doc])[0].tolist()

    def similarity(self, qvector, docvectors):
        """Score one query vector against a collection of document vectors.

        Args:
            qvector: a single query embedding (e.g. as returned by
                ``embedding``).
            docvectors: the document embeddings to compare against.

        Returns:
            Row 0 of ``self.model.similarity([qvector], docvectors)`` --
            presumably one score per entry of ``docvectors``.
        """
        # The query is wrapped in a one-element batch, so the result has a
        # single row, which is what we return.
        scores = self.model.similarity([qvector], docvectors)
        return scores[0]