bugfix
This commit is contained in:
parent
3d876594d1
commit
52e9ea7b34
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
model_pathMap = {
|
model_pathMap = {
|
||||||
}
|
}
|
||||||
@ -27,10 +28,12 @@ class BaseEmbedding:
|
|||||||
es = self.encode(input)
|
es = self.encode(input)
|
||||||
data = []
|
data = []
|
||||||
for i, e in enumerate(es):
|
for i, e in enumerate(es):
|
||||||
|
if isinstance(e, np.ndarray):
|
||||||
|
r = e.tolist()
|
||||||
d = {
|
d = {
|
||||||
"object": "embedding",
|
"object": "embedding",
|
||||||
"index": i,
|
"index": i,
|
||||||
"embedding": e.tolist()
|
"embedding": e
|
||||||
}
|
}
|
||||||
data.append(d)
|
data.append(d)
|
||||||
return {
|
return {
|
||||||
|
@ -1,24 +1,27 @@
|
|||||||
|
|
||||||
|
import torch
|
||||||
from FlagEmbedding import BGEM3FlagModel
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
from llmengine.base_embedding import BaseEmbedding, llm_register
|
from llmengine.base_embedding import BaseEmbedding, llm_register
|
||||||
|
|
||||||
class BgeEmbedding(BaseEmbedding):
|
class BgeEmbedding(BaseEmbedding):
|
||||||
def __init__(self, model_id):
|
def __init__(self, model_id):
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.model_name = model_id.split('/')[-1]
|
self.model_name = model_id.split('/')[-1]
|
||||||
self.model = BGEM3FlagModel(model_id, use_fp16=True)
|
self.model = HuggingFaceEmbeddings(
|
||||||
# Setting use_fp16 to True speeds up computation with a slight performance degradation
|
model_name=model_id,
|
||||||
self.kwargs = {
|
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
|
||||||
"batch_size": 12,
|
encode_kwargs={
|
||||||
"max_length": 8192
|
"batch_size": 12,
|
||||||
}
|
"max_length": 8192,
|
||||||
|
'normalize_embeddings': True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def encode(self, input):
|
def encode(self, input):
|
||||||
ret = []
|
ret = []
|
||||||
for t in input:
|
for t in input:
|
||||||
embedding = model.encode(sentences_1, **self.kwargs)['dense_vecs']
|
embedding = self.model.embed_query(t)
|
||||||
ret.append(embedding
|
ret.append(embedding)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
llm_register('bge-m3', BgeEmbedding)
|
llm_register('bge-m3', BgeEmbedding)
|
||||||
|
Loading…
Reference in New Issue
Block a user