ymq1 2025-06-24 18:29:29 +08:00
parent 3d876594d1
commit 52e9ea7b34
3 changed files with 18 additions and 12 deletions

View File

@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 model_pathMap = {
 }
@@ -27,10 +28,12 @@ class BaseEmbedding:
         es = self.encode(input)
         data = []
         for i, e in enumerate(es):
+            if isinstance(e, np.ndarray):
+                e = e.tolist()
             d = {
                 "object": "embedding",
                 "index": i,
-                "embedding": e.tolist()
+                "embedding": e
             }
             data.append(d)
         return {
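For context (this note is not part of the commit): the `isinstance` check above is what keeps the response JSON-serializable. FlagEmbedding's encoder returns `numpy.ndarray` vectors, while `HuggingFaceEmbeddings.embed_query` already returns plain Python lists, so only the former needs `tolist()`. A minimal sketch of the idea, with hypothetical inputs and an OpenAI-style envelope assumed for the return value:

import json
import numpy as np

def to_jsonable(vec):
    # Convert numpy vectors to plain lists; leave list results untouched.
    return vec.tolist() if isinstance(vec, np.ndarray) else vec

# Hypothetical embeddings: one numpy array, one plain Python list.
embeddings = [np.array([0.1, 0.2]), [0.3, 0.4]]
data = [
    {"object": "embedding", "index": i, "embedding": to_jsonable(e)}
    for i, e in enumerate(embeddings)
]
print(json.dumps({"object": "list", "data": data}))  # serializes without a TypeError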

View File

@@ -1,24 +1,27 @@
-from FlagEmbedding import BGEM3FlagModel
+import torch
+from langchain_huggingface import HuggingFaceEmbeddings
 from llmengine.base_embedding import BaseEmbedding, llm_register
 class BgeEmbedding(BaseEmbedding):
     def __init__(self, model_id):
         self.model_id = model_id
         self.model_name = model_id.split('/')[-1]
-        self.model = BGEM3FlagModel(model_id, use_fp16=True)
-        # Setting use_fp16 to True speeds up computation with a slight performance degradation
-        self.kwargs = {
-            "batch_size": 12,
-            "max_length": 8192
-        }
+        self.model = HuggingFaceEmbeddings(
+            model_name=model_id,
+            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
+            encode_kwargs={
+                "batch_size": 12,
+                "max_length": 8192,
+                'normalize_embeddings': True
+            }
+        )
     def encode(self, input):
         ret = []
         for t in input:
-            embedding = model.encode(sentences_1, **self.kwargs)['dense_vecs']
-            ret.append(embedding
+            embedding = self.model.embed_query(t)
+            ret.append(embedding)
         return ret
 llm_register('bge-m3', BgeEmbedding)
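A usage note (not part of the diff): `HuggingFaceEmbeddings` wraps a `sentence-transformers` model, and alongside `embed_query` it also exposes `embed_documents`, which encodes a whole batch in one call, so the per-item loop above could become a single batched call. Also, `encode_kwargs` are forwarded to `SentenceTransformer.encode`, which may not recognize a `max_length` key, so the sketch below omits it. A rough sketch, assuming the `BAAI/bge-m3` checkpoint and that `sentence-transformers` is installed:

import torch
from langchain_huggingface import HuggingFaceEmbeddings

model = HuggingFaceEmbeddings(
    model_name="BAAI/bge-m3",  # assumed checkpoint; any sentence-transformers model id works
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    encode_kwargs={"batch_size": 12, "normalize_embeddings": True},
)

# One query text -> one vector (a plain list of floats)
q = model.embed_query("what is a dense embedding?")

# A batch of texts -> a list of vectors, encoded together rather than one by one
docs = model.embed_documents(["first passage", "second passage"])
print(len(q), len(docs), len(docs[0]))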

View File

@@ -16,4 +16,4 @@ install_requires =
     # flash_attention_2
     mistral-common
     accelerate
-    FlagEmbedding
+    langchain_huggingface
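A quick way to sanity-check the dependency swap after installing (a sketch; `sentence-transformers` is the actual model backend behind `HuggingFaceEmbeddings` and may need to be added to `install_requires` as well if it is not pulled in automatically):

from langchain_huggingface import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")  # assumed model id, matching the 'bge-m3' registration above
vec = emb.embed_query("hello")
print(type(vec), len(vec))  # expect a plain list of floats (e.g. 1024 dimensions for bge-m3)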