commit e5dae40364 (parent fcaee5e657)
Author: root
Date: 2025-06-24 17:50:49 +08:00
5 changed files with 48 additions and 21 deletions

llmengine/base_embedding.py

@@ -20,8 +20,11 @@ class BaseEmbedding:
         device = torch.device("mps")
         self.model = self.model.to(device)

-    def embeddings(self, input):
+    def encode(self, input):
         es = self.model.encode(input)
+        return es
+    def embeddings(self, input):
+        es = self.encode(input)
         data = []
         for i, e in enumerate(es):
             d = {
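The refactor splits raw vector computation (encode, now the per-backend override point) from response assembly (embeddings). A minimal sketch of the resulting contract, assuming the truncated d = { ... } builds one record per vector (the field names below are an assumption, not shown in this hunk):

# Hypothetical subclass: only encode() needs to be supplied;
# BaseEmbedding.embeddings() wraps the returned vectors into records.
from llmengine.base_embedding import BaseEmbedding

class DummyEmbedding(BaseEmbedding):
    def encode(self, input):
        # one vector per input text; placeholder values for illustration
        return [[0.0] * 8 for _ in input]

# Calling embeddings(["a", "b"]) on an instance then yields records roughly like
# {"index": 0, "embedding": [...]}   # assumed shape; the hunk cuts off at d = {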

llmengine/bgeembedding.py Normal file

@@ -0,0 +1,24 @@
+from FlagEmbedding import BGEM3FlagModel
+from llmengine.base_embedding import BaseEmbedding, llm_register
+
+class BgeEmbedding(BaseEmbedding):
+    def __init__(self, model_id):
+        self.model_id = model_id
+        self.model_name = model_id.split('/')[-1]
+        self.model = BGEM3FlagModel(model_id, use_fp16=True)
+        # use_fp16=True speeds up computation with a slight performance degradation
+        self.kwargs = {
+            "batch_size": 12,
+            "max_length": 8192
+        }
+
+    # per-backend hook consumed by BaseEmbedding.embeddings()
+    def encode(self, input):
+        ret = []
+        for t in input:
+            # 'dense_vecs' holds the dense embedding for one input text
+            embedding = self.model.encode(t, **self.kwargs)['dense_vecs']
+            ret.append(embedding)
+        return ret
+
+llm_register('bge-m3', BgeEmbedding)
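A minimal usage sketch, assuming local access to the BAAI/bge-m3 weights (the model_id below is an assumption, not part of this commit):

from llmengine.bgeembedding import BgeEmbedding

emb = BgeEmbedding('BAAI/bge-m3')   # hypothetical model_id
vecs = emb.encode(["What is BGE M3?", "BM25 is a bag-of-words retrieval model"])
print(len(vecs), len(vecs[0]))      # 2 vectors; bge-m3 dense vectors are 1024-dim

BGEM3FlagModel.encode also accepts a whole list of sentences, so the per-text loop above forgoes the configured batch_size=12; a single batched call over the full input would be the faster path.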

@@ -3,6 +3,7 @@ import os
 import sys
 import argparse
 from llmengine.qwen3embedding import *
+from llmengine.bgeembedding import *
 from llmengine.base_embedding import get_llm_class
 from appPublic.registerfunction import RegisterFunction
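The wildcard import is what makes the new backend visible: importing llmengine.bgeembedding runs llm_register('bge-m3', BgeEmbedding) at module load, so get_llm_class can resolve the name afterwards. A sketch of the presumed flow (get_llm_class internals are not shown in this commit, so treat this as an assumption):

from llmengine.bgeembedding import *          # side effect: registers 'bge-m3'
from llmengine.base_embedding import get_llm_class

EngineClass = get_llm_class('bge-m3')         # presumed to return BgeEmbedding
engine = EngineClass('BAAI/bge-m3')           # hypothetical model_id
print(engine.embeddings(["hello world"]))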

pyproject.toml

@@ -1,23 +1,3 @@
-[project]
-name="llmengine"
-version = "0.0.1"
-description = "Your project description"
-authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
-readme = "README.md"
-requires-python = ">=3.8"
-license = {text = "MIT"}
-dependencies = [
-    "torch",
-    "transformers",
-    "sentence-transformers>=2.7.0",
-#    "flash_attention_2",
-    "mistral-common",
-    "accelerate"
-]
-
-[project.optional-dependencies]
-dev = ["pytest", "black", "mypy"]
-
-[build-system]
-requires = ["setuptools>=61", "wheel"]
-build-backend = "setuptools.build_meta"

setup.cfg Normal file

@@ -0,0 +1,19 @@
+[metadata]
+name = llmengine
+version = 0.0.2
+description = A transformers-based inference engine
+author = yu moqing
+author_email = yumoqing@gmail.com
+long_description = file: README.md
+license = MIT
+[options]
+packages = find:
+python_requires = >=3.8
+install_requires =
+    torch
+    transformers
+    sentence-transformers>=2.7.0
+#    flash_attention_2
+    mistral-common
+    accelerate
+    FlagEmbedding