This commit is contained in:
ymq1 2025-06-24 18:29:29 +08:00
parent 3d876594d1
commit 52e9ea7b34
3 changed files with 18 additions and 12 deletions

View File

@ -1,4 +1,5 @@
import torch
import numpy as np
model_pathMap = {
}
@ -27,10 +28,12 @@ class BaseEmbedding:
es = self.encode(input)
data = []
for i, e in enumerate(es):
if isinstance(e, np.ndarray):
r = e.tolist()
d = {
"object": "embedding",
"index": i,
"embedding": e.tolist()
"embedding": e
}
data.append(d)
return {

View File

@ -1,24 +1,27 @@
from FlagEmbedding import BGEM3FlagModel
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from llmengine.base_embedding import BaseEmbedding, llm_register
class BgeEmbedding(BaseEmbedding):
def __init__(self, model_id):
self.model_id = model_id
self.model_name = model_id.split('/')[-1]
self.model = BGEM3FlagModel(model_id, use_fp16=True)
# Setting use_fp16 to True speeds up computation with a slight performance degradation
self.kwargs = {
"batch_size": 12,
"max_length": 8192
}
self.model = HuggingFaceEmbeddings(
model_name=model_id,
model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
encode_kwargs={
"batch_size": 12,
"max_length": 8192,
'normalize_embeddings': True
}
)
def encode(self, input):
ret = []
for t in input:
embedding = model.encode(sentences_1, **self.kwargs)['dense_vecs']
ret.append(embedding
embedding = self.model.embed_query(t)
ret.append(embedding)
return ret
llm_register('bge-m3', BgeEmbedding)

View File

@ -16,4 +16,4 @@ install_requires =
# flash_attention_2
mistral-common
accelerate
FlagEmbedding
langchain_huggingface