From 52e9ea7b34b04d69447bd3ee8659350197c504eb Mon Sep 17 00:00:00 2001
From: ymq1
Date: Tue, 24 Jun 2025 18:29:29 +0800
Subject: [PATCH] bugfix

---
 llmengine/base_embedding.py |  5 ++++-
 llmengine/bgeembedding.py   | 23 +++++++++++++----------
 setup.cfg                   |  2 +-
 3 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/llmengine/base_embedding.py b/llmengine/base_embedding.py
index 938b8a8..6977407 100644
--- a/llmengine/base_embedding.py
+++ b/llmengine/base_embedding.py
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 
 model_pathMap = {
 }
@@ -27,10 +28,12 @@ class BaseEmbedding:
         es = self.encode(input)
         data = []
         for i, e in enumerate(es):
+            if isinstance(e, np.ndarray):
+                e = e.tolist()
             d = {
                 "object": "embedding",
                 "index": i,
-                "embedding": e.tolist()
+                "embedding": e
             }
             data.append(d)
         return {
diff --git a/llmengine/bgeembedding.py b/llmengine/bgeembedding.py
index 1e80189..a4f278d 100644
--- a/llmengine/bgeembedding.py
+++ b/llmengine/bgeembedding.py
@@ -1,24 +1,27 @@
-
-from FlagEmbedding import BGEM3FlagModel
+import torch
+from langchain_huggingface import HuggingFaceEmbeddings
 from llmengine.base_embedding import BaseEmbedding, llm_register
 
 
 class BgeEmbedding(BaseEmbedding):
     def __init__(self, model_id):
         self.model_id = model_id
         self.model_name = model_id.split('/')[-1]
-        self.model = BGEM3FlagModel(model_id, use_fp16=True)
-        # Setting use_fp16 to True speeds up computation with a slight performance degradation
-        self.kwargs = {
-            "batch_size": 12,
-            "max_length": 8192
-        }
+        self.model = HuggingFaceEmbeddings(
+            model_name=model_id,
+            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
+            encode_kwargs={
+                'batch_size': 12,
+                'max_length': 8192,
+                'normalize_embeddings': True
+            }
+        )
 
     def encode(self, input):
         ret = []
         for t in input:
-            embedding = model.encode(sentences_1, **self.kwargs)['dense_vecs']
-            ret.append(embedding
+            embedding = self.model.embed_query(t)
+            ret.append(embedding)
         return ret
 
 llm_register('bge-m3', BgeEmbedding)
diff --git a/setup.cfg b/setup.cfg
index 1fa22ba..ec48d1b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -16,4 +16,4 @@ install_requires =
     # flash_attention_2
     mistral-common
     accelerate
-    FlagEmbedding
+    langchain_huggingface
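
Note on the base_embedding.py change: FlagEmbedding's BGEM3FlagModel returned
dense vectors as numpy arrays, so the old unconditional e.tolist() worked, but
langchain_huggingface's embed_query() returns a plain Python list, which has no
.tolist(). A minimal standalone sketch of the patched branch (the normalize()
helper name is illustrative, not part of llmengine):

    import numpy as np

    def normalize(e):
        # Mirrors the patched branch in BaseEmbedding: numpy arrays
        # are converted, plain lists pass through unchanged.
        if isinstance(e, np.ndarray):
            e = e.tolist()
        return e

    assert normalize(np.array([0.1, 0.2])) == [0.1, 0.2]
    assert normalize([0.1, 0.2]) == [0.1, 0.2]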
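
For reference, a hypothetical smoke test of the new HuggingFaceEmbeddings
backend (the 'BAAI/bge-m3' model id and locally available weights are
assumptions; bge-m3 dense vectors are 1024-dimensional):

    from llmengine.bgeembedding import BgeEmbedding

    emb = BgeEmbedding('BAAI/bge-m3')
    vecs = emb.encode(['hello world'])
    assert isinstance(vecs[0], list)  # embed_query returns plain lists, not numpy arrays
    print(len(vecs[0]))               # expected: 1024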