31 lines
651 B
Python
31 lines
651 B
Python
import torch
|
|
|
|
# Registry mapping a key (matched as a substring of a model path) to the
# class that handles it. Filled by llm_register(), read by get_llm_class().
model_pathMap = dict()
|
|
def llm_register(model_key, Klass):
    """Register ``Klass`` under ``model_key`` in the module-level registry.

    Registering the same key again replaces the previously stored class.
    """
    # Item assignment mutates the existing dict in place, so no ``global``
    # declaration is needed (the name is never rebound here).
    model_pathMap[model_key] = Klass
|
|
|
|
def get_llm_class(model_path):
    """Return the registered class whose key occurs in ``model_path``.

    Keys are checked in registration (insertion) order and the first key
    that appears as a substring of ``model_path`` wins. A short key
    registered before a longer, more specific one will therefore shadow it.

    Args:
        model_path: String path or identifier of the model to resolve.

    Returns:
        The matching registered class, or ``None`` if no key matches.
    """
    for key, klass in model_pathMap.items():
        # Skip empty keys: the previous ``str.split(key)`` check raised
        # ValueError for them, and a plain substring test would match
        # every path.
        if key and key in model_path:
            return klass

    # Report the miss via logging rather than dumping the registry to stdout.
    import logging

    logging.getLogger(__name__).debug(
        "no LLM class registered for %r; registry=%r", model_path, model_pathMap
    )
    return None
|
|
|
|
class BaseEmbedding:
    """Shared helpers for embedding wrappers that hold a model in ``self.model``.

    This class defines no ``__init__``. Subclasses (or callers) must set
    ``self.model`` before using these methods.

    NOTE(review): ``self.model`` is assumed to expose ``encode``,
    ``similarity`` and ``to`` (it looks like a sentence-transformers model).
    Confirm against the subclasses.
    """

    def use_mps_if_prosible(self):
        """Move ``self.model`` to the ``mps`` device if torch reports it available.

        Does nothing when MPS is unavailable. The misspelled method name is
        kept so existing callers keep working.
        """
        if not torch.backends.mps.is_available():
            return
        self.model = self.model.to(torch.device("mps"))

    def embedding(self, doc):
        """Encode a single document and return its vector as a plain list.

        ``doc`` is wrapped in a one-element batch for ``self.model.encode``.
        The first (only) result is converted with ``tolist()``.
        """
        vectors = self.model.encode([doc])
        return vectors[0].tolist()

    def similarity(self, qvector, dcovectors):
        """Score one query vector against a collection of document vectors.

        Args:
            qvector: The query embedding.
            dcovectors: The document embeddings. The parameter name, typo
                included, is kept so keyword callers are not broken.

        Returns:
            Row 0 of ``self.model.similarity([qvector], dcovectors)``, i.e.
            the scores for the single query against every document.
        """
        # Previously this referenced an undefined name ``docvectors``, which
        # raised NameError on every call; use the declared parameter.
        scores = self.model.similarity([qvector], dcovectors)
        return scores[0]
|
|
|