From e5dae403648a238339807a7d8e3d4d7479101d2b Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 24 Jun 2025 17:50:49 +0800
Subject: [PATCH] bugfix

---
 llmengine/base_embedding.py |  7 +++++--
 llmengine/bgeembedding.py   | 24 ++++++++++++++++++++++++
 llmengine/embedding.py      |  1 +
 pyproject.toml              | 20 --------------------
 setup.cfg                   | 19 +++++++++++++++++++
 5 files changed, 49 insertions(+), 22 deletions(-)
 create mode 100644 llmengine/bgeembedding.py
 create mode 100644 setup.cfg

diff --git a/llmengine/base_embedding.py b/llmengine/base_embedding.py
index 68c6ba6..938b8a8 100644
--- a/llmengine/base_embedding.py
+++ b/llmengine/base_embedding.py
@@ -20,8 +20,11 @@ class BaseEmbedding:
             device = torch.device("mps")
         self.model = self.model.to(device)
 
-    def embeddings(self, input):
-        es = self.model.encode(input)
+    def encode(self, input):
+        return self.model.encode(input)
+
+    def embeddings(self, input):
+        es = self.encode(input)
         data = []
         for i, e in enumerate(es):
             d = {
diff --git a/llmengine/bgeembedding.py b/llmengine/bgeembedding.py
new file mode 100644
index 0000000..1e80189
--- /dev/null
+++ b/llmengine/bgeembedding.py
@@ -0,0 +1,24 @@
+
+
+from FlagEmbedding import BGEM3FlagModel
+from llmengine.base_embedding import BaseEmbedding, llm_register
+
+class BgeEmbedding(BaseEmbedding):
+    def __init__(self, model_id):
+        self.model_id = model_id
+        self.model_name = model_id.split('/')[-1]
+        self.model = BGEM3FlagModel(model_id, use_fp16=True)
+        # Setting use_fp16 to True speeds up computation with a slight performance degradation
+        self.kwargs = {
+            "batch_size": 12,
+            "max_length": 8192
+        }
+
+    def encode(self, input):
+        ret = []
+        for t in input:
+            embedding = self.model.encode(t, **self.kwargs)['dense_vecs']
+            ret.append(embedding)
+        return ret
+
+llm_register('bge-m3', BgeEmbedding)
diff --git a/llmengine/embedding.py b/llmengine/embedding.py
index a3cf731..87a2be5 100644
--- a/llmengine/embedding.py
+++ b/llmengine/embedding.py
@@ -3,6 +3,7 @@ import os
 import sys
 import argparse
 from llmengine.qwen3embedding import *
+from llmengine.bgeembedding import *
 from llmengine.base_embedding import get_llm_class
 from appPublic.registerfunction import RegisterFunction
 
diff --git a/pyproject.toml b/pyproject.toml
index 73d37f7..59514a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,23 +1,3 @@
-[project]
-name="llmengine"
-version = "0.0.1"
-description = "Your project description"
-authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
-readme = "README.md"
-requires-python = ">=3.8"
-license = {text = "MIT"}
-dependencies = [
-    "torch",
-    "transformers",
-    "sentence-transformers>=2.7.0",
-    # "flash_attention_2",
-    "mistral-common",
-    "accelerate"
-]
-
-[project.optional-dependencies]
-dev = ["pytest", "black", "mypy"]
-
 [build-system]
 requires = ["setuptools>=61", "wheel"]
 build-backend = "setuptools.build_meta"
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..1fa22ba
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,19 @@
+[metadata]
+name = llmengine
+version = 0.0.2
+description = A transformers-based inference engine
+author = yu moqing
+author_email = yumoqing@gmail.com
+long_description = file: README.md
+license = MIT
+[options]
+packages = find:
+python_requires = >=3.8
+install_requires =
+    torch
+    transformers
+    sentence-transformers>=2.7.0
+    # flash_attention_2
+    mistral-common
+    accelerate
+    FlagEmbedding
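
Reviewer note, not part of the patch: a minimal usage sketch of the new
BgeEmbedding class added above. It assumes FlagEmbedding is installed and
that 'BAAI/bge-m3' (the conventional hub id for BGE-M3; a local checkpoint
path works the same way) is reachable:

    from llmengine.bgeembedding import BgeEmbedding

    emb = BgeEmbedding('BAAI/bge-m3')
    # encode() returns one dense vector per input string; BGE-M3
    # dense embeddings are expected to be 1024-dimensional
    vectors = emb.encode(['hello world', 'embedding smoke test'])
    print(len(vectors), vectors[0].shape)  # expected: 2 (1024,)

Because model_id.split('/')[-1] yields 'bge-m3', this id also matches the
key passed to llm_register, so get_llm_class-based lookup in embedding.py
should resolve the same class.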