bugfix
This commit is contained in:
parent
fcaee5e657
commit
e5dae40364
@ -20,8 +20,11 @@ class BaseEmbedding:
|
|||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
def embeddings(self, input):
|
def encode(self, input):
|
||||||
es = self.model.encode(input)
|
es = self.model.encode(input)
|
||||||
|
|
||||||
|
def embeddings(self, input):
|
||||||
|
es = self.encode(input)
|
||||||
data = []
|
data = []
|
||||||
for i, e in enumerate(es):
|
for i, e in enumerate(es):
|
||||||
d = {
|
d = {
|
||||||
|
24
llmengine/bgeembedding.py
Normal file
24
llmengine/bgeembedding.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from FlagEmbedding import BGEM3FlagModel
|
||||||
|
from llmengine.base_embedding import BaseEmbedding, llm_register
|
||||||
|
|
||||||
|
class BgeEmbedding(BaseEmbedding):
|
||||||
|
def __init__(self, model_id):
|
||||||
|
self.model_id = model_id
|
||||||
|
self.model_name = model_id.split('/')[-1]
|
||||||
|
self.model = BGEM3FlagModel(model_id, use_fp16=True)
|
||||||
|
# Setting use_fp16 to True speeds up computation with a slight performance degradation
|
||||||
|
self.kwargs = {
|
||||||
|
"batch_size": 12,
|
||||||
|
"max_length": 8192
|
||||||
|
}
|
||||||
|
|
||||||
|
def encode(self, input):
|
||||||
|
ret = []
|
||||||
|
for t in input:
|
||||||
|
embedding = model.encode(sentences_1, **self.kwargs)['dense_vecs']
|
||||||
|
ret.append(embedding
|
||||||
|
return ret
|
||||||
|
|
||||||
|
llm_register('bge-m3', BgeEmbedding)
|
@ -3,6 +3,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
from llmengine.qwen3embedding import *
|
from llmengine.qwen3embedding import *
|
||||||
|
from llmengine.bgeembedding import *
|
||||||
from llmengine.base_embedding import get_llm_class
|
from llmengine.base_embedding import get_llm_class
|
||||||
|
|
||||||
from appPublic.registerfunction import RegisterFunction
|
from appPublic.registerfunction import RegisterFunction
|
||||||
|
@ -1,23 +1,3 @@
|
|||||||
[project]
|
|
||||||
name="llmengine"
|
|
||||||
version = "0.0.1"
|
|
||||||
description = "Your project description"
|
|
||||||
authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
|
|
||||||
readme = "README.md"
|
|
||||||
requires-python = ">=3.8"
|
|
||||||
license = {text = "MIT"}
|
|
||||||
dependencies = [
|
|
||||||
"torch",
|
|
||||||
"transformers",
|
|
||||||
"sentence-transformers>=2.7.0",
|
|
||||||
# "flash_attention_2",
|
|
||||||
"mistral-common",
|
|
||||||
"accelerate"
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
dev = ["pytest", "black", "mypy"]
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=61", "wheel"]
|
requires = ["setuptools>=61", "wheel"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
19
setup.cfg
Normal file
19
setup.cfg
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
[metadata]
|
||||||
|
name=llmengine
|
||||||
|
version = 0.0.2
|
||||||
|
description = A transformers base reference engine
|
||||||
|
author = "yu moqing"
|
||||||
|
author_email = "yumoqing@gmail.com"
|
||||||
|
readme = "README.md"
|
||||||
|
license = "MIT"
|
||||||
|
[options]
|
||||||
|
packages = find:
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
install_requires =
|
||||||
|
torch
|
||||||
|
transformers
|
||||||
|
sentence-transformers>=2.7.0
|
||||||
|
# flash_attention_2
|
||||||
|
mistral-common
|
||||||
|
accelerate
|
||||||
|
FlagEmbedding
|
Loading…
Reference in New Issue
Block a user