diff --git a/base_embedding.py b/base_embedding.py
new file mode 100644
index 0000000..342945c
--- /dev/null
+++ b/base_embedding.py
@@ -0,0 +1,13 @@
+
+class BaseEmbedding:
+
+    def embedding(self, doc):
+        # Encode a single document and return its embedding vector.
+        es = self.model.encode([doc])
+        return es[0]
+
+    def similarity(self, qvector, docvectors):
+        # Similarity between one query vector and a batch of document vectors.
+        s = self.model.similarity([qvector], docvectors)
+        return s[0]
+
diff --git a/base_reranker.py b/base_reranker.py
new file mode 100644
index 0000000..6c71e7e
--- /dev/null
+++ b/base_reranker.py
@@ -0,0 +1,47 @@
+
+import torch
+
+class BaseReranker:
+
+    def process_inputs(self, pairs):
+        # Tokenize without padding first, then pad the batch to its longest sequence.
+        inputs = self.tokenizer(
+            pairs, padding=False, truncation='longest_first',
+            return_attention_mask=False, max_length=self.max_length
+        )
+        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
+        for key in inputs:
+            inputs[key] = inputs[key].to(self.model.device)
+        return inputs
+
+    def build_sys_prompt(self, sys_prompt):
+        return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"
+
+    def build_user_prompt(self, instruct, query, doc):
+        return f"<|im_start|>user\n<Instruct>: {instruct}\n<Query>: {query}\n<Document>: {doc}<|im_end|>"
+
+    def build_assistant_prompt(self):
+        # Empty think block so the model emits its yes/no judgement directly.
+        return "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+    @torch.no_grad()
+    def compute_logits(self, inputs, **kwargs):
+        batch_scores = self.model(**inputs).logits[:, -1, :]
+        # Relevance is the probability of "yes" versus "no" at the final position.
+        token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+        token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+        true_vector = batch_scores[:, token_true_id]
+        false_vector = batch_scores[:, token_false_id]
+        batch_scores = torch.stack([false_vector, true_vector], dim=1)
+        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+        scores = batch_scores[:, 1].exp().tolist()
+        return scores
+
+    def rerank(self, query, docs, sys_prompt="", task=""):
+        sys_str = self.build_sys_prompt(sys_prompt)
+        ass_str = self.build_assistant_prompt()
+        pairs = [sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs]
+        inputs = self.process_inputs(pairs)
+        scores = self.compute_logits(inputs)
+        return scores
+
diff --git a/qwen3_reranker.py b/qwen3_reranker.py
new file mode 100644
index 0000000..1232536
--- /dev/null
+++ b/qwen3_reranker.py
@@ -0,0 +1,11 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llmengine.base_reranker import BaseReranker
+
+class Qwen3Reranker(BaseReranker):
+    def __init__(self, model_id, max_length=8096):
+        self.model_id = model_id
+        self.max_length = max_length
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
+        # Load the requested checkpoint rather than a hard-coded one.
+        self.model = AutoModelForCausalLM.from_pretrained(model_id).eval()
+
diff --git a/qwen3embedding.py b/qwen3embedding.py
new file mode 100644
index 0000000..ecc978d
--- /dev/null
+++ b/qwen3embedding.py
@@ -0,0 +1,19 @@
+# Requires transformers>=4.51.0
+# Requires sentence-transformers>=2.7.0
+
+from sentence_transformers import SentenceTransformer
+from llmengine.base_embedding import BaseEmbedding
+
+class Qwen3Embedding(BaseEmbedding):
+    def __init__(self, model_id, max_length=8096):
+        # Load the model
+        self.model = SentenceTransformer(model_id)
+        # We recommend enabling flash_attention_2 for better acceleration and memory saving,
+        # together with setting `padding_side` to "left":
+        # model = SentenceTransformer(
+        #     "Qwen/Qwen3-Embedding-0.6B",
+        #     model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto"},
+        #     tokenizer_kwargs={"padding_side": "left"},
+        # )
+        self.max_length = max_length
+
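
For reviewers, a minimal usage sketch of the classes added above. It assumes the new modules live under the llmengine package alongside their base classes (matching the imports in this diff), and borrows the standard Qwen3-Reranker system/instruct strings; none of this is part of the patch itself.

    import numpy as np
    from llmengine.qwen3embedding import Qwen3Embedding
    from llmengine.qwen3_reranker import Qwen3Reranker

    docs = ["Paris is the capital of France.", "The Nile is a river in Africa."]
    query = "What is the capital of France?"

    # Embedding path: one vector per document, then query-vs-docs similarity.
    embedder = Qwen3Embedding("Qwen/Qwen3-Embedding-0.6B")
    qvec = embedder.embedding(query)
    dvecs = np.stack([embedder.embedding(d) for d in docs])
    print(embedder.similarity(qvec, dvecs))  # one similarity score per document

    # Reranking path: scores are P("yes") per document, higher = more relevant.
    reranker = Qwen3Reranker("Qwen/Qwen3-Reranker-0.6B")
    scores = reranker.rerank(
        query, docs,
        sys_prompt='Judge whether the Document meets the requirements based on the Query '
                   'and the Instruct provided. Note that the answer can only be "yes" or "no".',
        task="Given a web search query, retrieve relevant passages that answer the query",
    )
    print(scores)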