bugfix
This commit is contained in:
parent
8f8502cb9e
commit
58c4af5c8b
11
base_embedding.py
Normal file
11
base_embedding.py
Normal file
@ -0,0 +1,11 @@
|
||||
|
||||
class BaseEmbedding:
|
||||
|
||||
def embedding(self, doc):
|
||||
es = self.model.encode([doc])
|
||||
return es[0]
|
||||
|
||||
def similarity(self, qvector, dcovectors):
|
||||
s = self.model.similarity([qvector], docvectors)
|
||||
return s[0]
|
||||
|
40
base_reranker.py
Normal file
40
base_reranker.py
Normal file
@ -0,0 +1,40 @@
|
||||
|
||||
import torch
|
||||
|
||||
classs BaseReranker:
|
||||
|
||||
def process_input(self, pairs):
|
||||
inputs = self.tokenizer(
|
||||
pairs, padding=False, truncation='longest_first',
|
||||
return_attention_mask=False, max_length=self.max_length
|
||||
)
|
||||
inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(self.model.device)
|
||||
return inputs
|
||||
|
||||
def build_sys_prompt(self, sys_prompt):
|
||||
return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"
|
||||
|
||||
def build_user_prompt(self, query, docs, instruct=''):
|
||||
return f'<|im_start|>user\n<Instruct>: {instruct}\n<Query>:{query}\n<Document>:\n{doc}<|im_end|>'
|
||||
|
||||
def build_assistant_prompt(self):
|
||||
return "<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
|
||||
def compute_logits(self, inputs, **kwargs):
|
||||
batch_scores = self.model(**inputs).logits[:, -1, :]
|
||||
# true_vector = batch_scores[:, token_true_id]
|
||||
# false_vector = batch_scores[:, token_false_id]
|
||||
# batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
||||
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
||||
scores = batch_scores[:, 1].exp().tolist()
|
||||
return scores
|
||||
|
||||
def rerank(self, query, docs, sys_prompt="", task=""):
|
||||
sys_str = self.build_sys_prompt(sys_prompt)
|
||||
ass_str = self.build_assistant_prompt()
|
||||
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
|
||||
inputs = self.process_inputs(pairs)
|
||||
scores = self.compute_logits(inputs)
|
||||
|
10
qwen3_reranker.py
Normal file
10
qwen3_reranker.py
Normal file
@ -0,0 +1,10 @@
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
||||
from llmengine.base_reranker import BaseReranker
|
||||
|
||||
class Qwen3Reranker(BaseReranker):
|
||||
def __init__(self, model_id, max_length=8096):
|
||||
self.odel_id = model_id
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
|
||||
self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
|
||||
|
19
qwen3embedding.py
Normal file
19
qwen3embedding.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Requires transformers>=4.51.0
|
||||
# Requires sentence-transformers>=2.7.0
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from llmengine.base_embedding import BaseEmbedding
|
||||
|
||||
class Qwen3Embedding(BaseEmbedding):
|
||||
def __init__(self, model_id, max_length=8096):
|
||||
# Load the model
|
||||
self.model = SentenceTransformer(model_id)
|
||||
# We recommend enabling flash_attention_2 for better acceleration and memory saving,
|
||||
# together with setting `padding_side` to "left":
|
||||
# model = SentenceTransformer(
|
||||
# "Qwen/Qwen3-Embedding-0.6B",
|
||||
# model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto"},
|
||||
# tokenizer_kwargs={"padding_side": "left"},
|
||||
# )
|
||||
self.max_length = max_length
|
||||
|
Loading…
Reference in New Issue
Block a user