bugfix
This commit is contained in:
parent
a10de4c6c4
commit
3d876594d1
@ -1,5 +0,0 @@
|
|||||||
from ahserver.configuredServer import ConfiguredServer
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
server = ConfiguredServer()
|
|
||||||
server.run()
|
|
@ -1,246 +0,0 @@
|
|||||||
import threading
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
from time import time
|
|
||||||
from transformers import TextIteratorStreamer
|
|
||||||
from appPublic.log import debug
|
|
||||||
from appPublic.worker import awaitify
|
|
||||||
from appPublic.uniqueID import getID
|
|
||||||
|
|
||||||
model_pathMap = {
|
|
||||||
}
|
|
||||||
def llm_register(model_key, Klass):
|
|
||||||
model_pathMap[model_key] = Klass
|
|
||||||
|
|
||||||
def get_llm_class(model_path):
|
|
||||||
for k,klass in model_pathMap.items():
|
|
||||||
if len(model_path.split(k)) > 1:
|
|
||||||
return klass
|
|
||||||
print(f'{model_pathMap=}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
class BaseChatLLM:
|
|
||||||
def use_mps_if_prosible(self):
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
device = torch.device("mps")
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
|
|
||||||
def get_session_key(self):
|
|
||||||
return self.model_id + ':messages'
|
|
||||||
|
|
||||||
def _get_session_messages(self, session):
|
|
||||||
key = self.get_session_key()
|
|
||||||
messages = session.get(key) or []
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def _set_session_messages(self, session, messages):
|
|
||||||
key = self.get_session_key()
|
|
||||||
session[key] = messages
|
|
||||||
|
|
||||||
def get_streamer(self):
|
|
||||||
return TextIteratorStreamer(
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
skip_special_tokens=True,
|
|
||||||
skip_prompt=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def output_generator(self, streamer):
|
|
||||||
all_txt = ''
|
|
||||||
t1 = time()
|
|
||||||
i = 0
|
|
||||||
id = f'chatllm-{getID}'
|
|
||||||
for txt in streamer:
|
|
||||||
if txt == '':
|
|
||||||
continue
|
|
||||||
if i == 0:
|
|
||||||
t2 = time()
|
|
||||||
i += 1
|
|
||||||
all_txt += txt
|
|
||||||
yield {
|
|
||||||
"id":id,
|
|
||||||
"object":"chat.completion.chunk",
|
|
||||||
"created":time(),
|
|
||||||
"model":self.model_id,
|
|
||||||
"choices":[
|
|
||||||
{
|
|
||||||
"index":0,
|
|
||||||
"delta":{
|
|
||||||
"content":txt
|
|
||||||
},
|
|
||||||
"logprobs":None,
|
|
||||||
"finish_reason":None
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
t3 = time()
|
|
||||||
t = all_txt
|
|
||||||
unk = self.tokenizer(t, return_tensors="pt")
|
|
||||||
output_tokens = len(unk["input_ids"][0])
|
|
||||||
yield {
|
|
||||||
"id":id,
|
|
||||||
"object":"chat.completion.chunk",
|
|
||||||
"created":time(),
|
|
||||||
"model":self.model_id,
|
|
||||||
"response_time": t2 - t1,
|
|
||||||
"finish_time": t3 - t1,
|
|
||||||
"output_token": output_tokens,
|
|
||||||
"choices":[
|
|
||||||
{
|
|
||||||
"index":0,
|
|
||||||
"delta":{
|
|
||||||
"content":""
|
|
||||||
},
|
|
||||||
"logprobs":None,
|
|
||||||
"finish_reason":"stop"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
|
|
||||||
messages = self._get_session_messages(session)
|
|
||||||
if sys_prompt:
|
|
||||||
messages.append(self._build_sys_message(sys_prompt))
|
|
||||||
messages.append(self._build_user_message(prompt, image_path=image_path))
|
|
||||||
# debug(f'{messages=}')
|
|
||||||
all_txt = ''
|
|
||||||
for d in self._gen(messages):
|
|
||||||
if d['choices'][0]['finish_reason'] == 'stop':
|
|
||||||
messages.append(self._build_assistant_message(all_txt))
|
|
||||||
else:
|
|
||||||
all_txt += d['choices'][0]['delta']['content']
|
|
||||||
yield d
|
|
||||||
self._set_session_messages(session, messages)
|
|
||||||
|
|
||||||
async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
|
|
||||||
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
yield d
|
|
||||||
|
|
||||||
def generate(self, session, prompt,
|
|
||||||
image_path=None,
|
|
||||||
video_path=None,
|
|
||||||
audio_path=None,
|
|
||||||
sys_prompt=None):
|
|
||||||
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
|
||||||
if d['choices'][0]['finish_reason'] == 'stop':
|
|
||||||
return d
|
|
||||||
def stream_generate(self, session, prompt,
|
|
||||||
image_path=None,
|
|
||||||
video_path=None,
|
|
||||||
audio_path=None,
|
|
||||||
sys_prompt=None):
|
|
||||||
for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
|
||||||
s = f'data: {json.dumps(d)}\n'
|
|
||||||
yield s
|
|
||||||
|
|
||||||
async def async_generate(self, session, prompt,
|
|
||||||
image_path=None,
|
|
||||||
video_path=None,
|
|
||||||
audio_path=None,
|
|
||||||
sys_prompt=None):
|
|
||||||
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
if d['choices'][0]['finish_reason'] == 'stop':
|
|
||||||
return d
|
|
||||||
|
|
||||||
async def async_stream_generate(self, session, prompt,
|
|
||||||
image_path=None,
|
|
||||||
video_path=None,
|
|
||||||
audio_path=None,
|
|
||||||
sys_prompt=None):
|
|
||||||
async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
|
|
||||||
s = f'data: {json.dumps(d)}\n'
|
|
||||||
yield s
|
|
||||||
yield 'data: [DONE]'
|
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
|
||||||
generate_kwargs = dict(
|
|
||||||
**inputs,
|
|
||||||
streamer=streamer,
|
|
||||||
max_new_tokens=512,
|
|
||||||
do_sample=True,
|
|
||||||
eos_token_id=self.tokenizer.eos_token_id
|
|
||||||
)
|
|
||||||
return generate_kwargs
|
|
||||||
|
|
||||||
def _messages2inputs(self, messages):
|
|
||||||
return self.processor.apply_chat_template(
|
|
||||||
messages, add_generation_prompt=True,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True, return_tensors="pt"
|
|
||||||
).to(self.model.device, dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
def _gen(self, messages):
|
|
||||||
inputs = self._messages2inputs(messages)
|
|
||||||
input_len = inputs["input_ids"].shape[-1]
|
|
||||||
streamer = self.get_streamer()
|
|
||||||
kwargs = self.build_kwargs(inputs, streamer)
|
|
||||||
thread = threading.Thread(target=self.model.generate,
|
|
||||||
kwargs=kwargs)
|
|
||||||
thread.start()
|
|
||||||
for d in self.output_generator(streamer):
|
|
||||||
if d['choices'][0]['finish_reason'] == 'stop':
|
|
||||||
d['input_tokens'] = input_len
|
|
||||||
yield d
|
|
||||||
|
|
||||||
class T2TChatLLM(BaseChatLLM):
|
|
||||||
def _build_assistant_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"assistant",
|
|
||||||
"content":prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_sys_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"system",
|
|
||||||
"content": prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_user_message(self, prompt, **kw):
|
|
||||||
return {
|
|
||||||
"role":"user",
|
|
||||||
"content": prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
class MMChatLLM(BaseChatLLM):
|
|
||||||
""" multiple modal chat LLM """
|
|
||||||
def _build_assistant_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"assistant",
|
|
||||||
"content":[{"type": "text", "text": prompt}]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_sys_message(self, prompt):
|
|
||||||
return {
|
|
||||||
"role":"system",
|
|
||||||
"content":[{"type": "text", "text": prompt}]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_user_message(self, prompt, image_path=None,
|
|
||||||
video_path=None, audio_path=None):
|
|
||||||
contents = [
|
|
||||||
{
|
|
||||||
"type":"text", "text": prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
if image_path:
|
|
||||||
contents.append({
|
|
||||||
"type": "image",
|
|
||||||
"image": image_path
|
|
||||||
})
|
|
||||||
if video_path:
|
|
||||||
contents.append({
|
|
||||||
"type": "video",
|
|
||||||
"video":video_path
|
|
||||||
})
|
|
||||||
if audio_path:
|
|
||||||
contents.append({
|
|
||||||
"tyoe": "audio",
|
|
||||||
"audio": audio_path
|
|
||||||
})
|
|
||||||
return {
|
|
||||||
"role": "user",
|
|
||||||
"content": contents
|
|
||||||
}
|
|
||||||
|
|
@ -1,46 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
model_pathMap = {
|
|
||||||
}
|
|
||||||
def llm_register(model_key, Klass):
|
|
||||||
global model_pathMap
|
|
||||||
model_pathMap[model_key] = Klass
|
|
||||||
|
|
||||||
def get_llm_class(model_path):
|
|
||||||
for k,klass in model_pathMap.items():
|
|
||||||
if len(model_path.split(k)) > 1:
|
|
||||||
return klass
|
|
||||||
print(f'{model_pathMap=}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
class BaseEmbedding:
|
|
||||||
|
|
||||||
def use_mps_if_prosible(self):
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
device = torch.device("mps")
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
|
|
||||||
def embeddings(self, input):
|
|
||||||
es = self.model.encode(input)
|
|
||||||
data = []
|
|
||||||
for i, e in enumerate(es):
|
|
||||||
d = {
|
|
||||||
"object": "embedding",
|
|
||||||
"index": i,
|
|
||||||
"embedding": e.tolist()
|
|
||||||
}
|
|
||||||
data.append(d)
|
|
||||||
return {
|
|
||||||
"object": "list",
|
|
||||||
"data": data,
|
|
||||||
"model": self.model_name,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def similarity(self, qvector, dcovectors):
|
|
||||||
s = self.model.similarity([qvector], docvectors)
|
|
||||||
return s[0]
|
|
||||||
|
|
@ -1,23 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
model_pathMap = {}
|
|
||||||
|
|
||||||
def ltp_register(model_key, Klass):
|
|
||||||
"""Register a model class for a given model key."""
|
|
||||||
global model_pathMap
|
|
||||||
model_pathMap[model_key] = Klass
|
|
||||||
|
|
||||||
def get_ltp_class(model_path):
|
|
||||||
"""Find the model class for a given model path."""
|
|
||||||
for k, klass in model_pathMap.items():
|
|
||||||
if len(model_path.split(k)) > 1:
|
|
||||||
return klass
|
|
||||||
print(f'{model_pathMap=}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
class BaseLtp(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def extract_entities(self, query: str) -> List[str]:
|
|
||||||
"""Extract entities from query text."""
|
|
||||||
pass
|
|
@ -1,84 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
model_pathMap = {
|
|
||||||
}
|
|
||||||
def llm_register(model_key, Klass):
|
|
||||||
model_pathMap[model_key] = Klass
|
|
||||||
|
|
||||||
def get_llm_class(model_path):
|
|
||||||
for k,klass in model_pathMap.items():
|
|
||||||
if len(model_path.split(k)) > 1:
|
|
||||||
return klass
|
|
||||||
print(f'{model_pathMap=}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
class BaseReranker:
|
|
||||||
def __init__(self, model_id, **kw):
|
|
||||||
self.model_id = model_id
|
|
||||||
|
|
||||||
def use_mps_if_prosible(self):
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
device = torch.device("mps")
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
|
|
||||||
def process_inputs(self, pairs):
|
|
||||||
inputs = self.tokenizer(
|
|
||||||
pairs, padding=False, truncation='longest_first',
|
|
||||||
return_attention_mask=False, max_length=self.max_length
|
|
||||||
)
|
|
||||||
inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
|
|
||||||
for key in inputs:
|
|
||||||
inputs[key] = inputs[key].to(self.model.device)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
def build_sys_prompt(self, sys_prompt):
|
|
||||||
return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"
|
|
||||||
|
|
||||||
def build_user_prompt(self, query, doc, instruct=''):
|
|
||||||
return f'<|im_start|>user\n<Instruct>: {instruct}\n<Query>:{query}\n<Document>:\n{doc}<|im_end|>'
|
|
||||||
|
|
||||||
def build_assistant_prompt(self):
|
|
||||||
return "<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
|
||||||
|
|
||||||
def compute_logits(self, inputs, **kwargs):
|
|
||||||
batch_scores = self.model(**inputs).logits[:, -1, :]
|
|
||||||
# true_vector = batch_scores[:, token_true_id]
|
|
||||||
# false_vector = batch_scores[:, token_false_id]
|
|
||||||
# batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
|
||||||
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
|
||||||
scores = batch_scores[:, 1].exp().tolist()
|
|
||||||
return scores
|
|
||||||
|
|
||||||
def build_pairs(self, query, docs, sys_prompt="", task=""):
|
|
||||||
sys_str = self.build_sys_prompt(sys_prompt)
|
|
||||||
ass_str = self.build_assistant_prompt()
|
|
||||||
pairs = [ sys_str + '\n' + self.build_user_prompt(task, query, doc) + '\n' + ass_str for doc in docs ]
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
def rerank(self, query, docs, top_n, sys_prompt="", task=""):
|
|
||||||
pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
|
|
||||||
with torch.no_grad():
|
|
||||||
inputs = self.process_inputs(pairs)
|
|
||||||
scores = self.compute_logits(inputs)
|
|
||||||
data = []
|
|
||||||
for i, s in enumerate(scores):
|
|
||||||
d = {
|
|
||||||
'index':i,
|
|
||||||
'relevance_score': s
|
|
||||||
}
|
|
||||||
data.append(d)
|
|
||||||
data = sorted(data,
|
|
||||||
key=lambda x: x["relevance_score"],
|
|
||||||
reverse=True)
|
|
||||||
if len(data) > top_n:
|
|
||||||
data = data[:top_n]
|
|
||||||
ret = {
|
|
||||||
"data": data,
|
|
||||||
"object": "rerank.result",
|
|
||||||
"model": self.model_name,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ret
|
|
@ -1,31 +0,0 @@
|
|||||||
import torch
|
|
||||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
||||||
from llmengine.base_reranker import BaseReranker, llm_register
|
|
||||||
|
|
||||||
class BgeReranker(BaseReranker):
|
|
||||||
def __init__(self, model_id, max_length=8096):
|
|
||||||
if 'bge-reranker' not in model_id:
|
|
||||||
e = Exception(f'{model_id} is not a bge-reranker')
|
|
||||||
raise e
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
|
||||||
model.eval()
|
|
||||||
self.model = model
|
|
||||||
self.model_id = model_id
|
|
||||||
self.model_name = model_id.split('/')[-1]
|
|
||||||
|
|
||||||
def build_pairs(self, query, docs, **kw):
|
|
||||||
return [[query, doc] for doc in docs]
|
|
||||||
|
|
||||||
def process_inputs(self, pairs):
|
|
||||||
inputs = self.tokenizer(pairs, padding=True,
|
|
||||||
truncation=True, return_tensors='pt', max_length=512)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
def compute_logits(self, inputs):
|
|
||||||
scores = self.model(**inputs,
|
|
||||||
return_dict=True).logits.view(-1, ).float()
|
|
||||||
scores = [ s.item() for s in scores ]
|
|
||||||
return scores
|
|
||||||
|
|
||||||
llm_register('bge-reranker', BgeReranker)
|
|
@ -1,212 +0,0 @@
|
|||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
|
||||||
from time import time
|
|
||||||
import torch
|
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
def is_chat_model(model_name: str, tokenizer) -> bool:
|
|
||||||
chat_keywords = ["chat", "chatml", "phi", "llama-chat", "mistral-instruct"]
|
|
||||||
if any(k in model_name.lower() for k in chat_keywords):
|
|
||||||
return True
|
|
||||||
if tokenizer and hasattr(tokenizer, "additional_special_tokens"):
|
|
||||||
if any(tag in tokenizer.additional_special_tokens for tag in ["<|user|>", "<|system|>", "<|assistant|>"]):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def build_chat_prompt(messages):
|
|
||||||
prompt = ""
|
|
||||||
for message in messages:
|
|
||||||
role = message["role"]
|
|
||||||
content = message["content"]
|
|
||||||
prompt += f"<|{role}|>\n{content}\n"
|
|
||||||
prompt += "<|assistant|>\n" # 生成开始
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
class CountingStreamer(TextIteratorStreamer):
|
|
||||||
def __init__(self, tokenizer, skip_prompt=True, **kw):
|
|
||||||
super().__init__(tokenizer, skip_prompt=skip_prompt, **kw)
|
|
||||||
self.token_count = 0
|
|
||||||
|
|
||||||
def __next__(self, *args, **kw):
|
|
||||||
output_ids = super().__iter__(*args, **kw)
|
|
||||||
self.token_count += output_ids.sequences.shape[1]
|
|
||||||
return output_ids
|
|
||||||
|
|
||||||
class TransformersChatEngine:
|
|
||||||
def __init__(self, model_name: str, device: str = None, fp16: bool = True,
|
|
||||||
output_json=True,
|
|
||||||
gpus: int = 1):
|
|
||||||
"""
|
|
||||||
通用大模型加载器,支持 GPU 数量与编号控制
|
|
||||||
:param model_name: 模型名称或路径
|
|
||||||
:param device: 指定设备如 "cuda:0",默认自动选择
|
|
||||||
:param fp16: 是否使用 fp16 精度(适用于支持的 GPU)
|
|
||||||
:param gpus: 使用的 GPU 数量,1 表示单卡,>1 表示多卡推理(使用 device_map='auto')
|
|
||||||
"""
|
|
||||||
self.output_json = output_json
|
|
||||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
self.is_multi_gpu = gpus > 1 and torch.cuda.device_count() >= gpus
|
|
||||||
|
|
||||||
print(f"✅ Using device: {self.device}, GPUs: {gpus}, Multi-GPU: {self.is_multi_gpu}")
|
|
||||||
|
|
||||||
# Tokenizer 加载
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
||||||
|
|
||||||
# 模型加载
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch.float16 if fp16 and "cuda" in self.device else torch.float32,
|
|
||||||
device_map="auto" if self.is_multi_gpu else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.is_multi_gpu:
|
|
||||||
self.model.to(self.device)
|
|
||||||
|
|
||||||
self.model.eval()
|
|
||||||
self.is_chat = is_chat_model(model_name, self.tokenizer)
|
|
||||||
if self.is_chat:
|
|
||||||
self.messages = [ ]
|
|
||||||
|
|
||||||
print(f'{self.model.generation_config=}')
|
|
||||||
|
|
||||||
def set_system_prompt(self, prompt):
|
|
||||||
if self.is_chat:
|
|
||||||
self.messages = [{
|
|
||||||
|
|
||||||
'role': 'system',
|
|
||||||
'content': prompt
|
|
||||||
}]
|
|
||||||
def set_assistant_prompt(self, prompt):
|
|
||||||
if self.is_chat:
|
|
||||||
self.messages.append({
|
|
||||||
'role': 'assistant',
|
|
||||||
'content': prompt
|
|
||||||
})
|
|
||||||
def set_user_prompt(self, prompt):
|
|
||||||
if self.is_chat:
|
|
||||||
self.messages.append({
|
|
||||||
'role': 'user',
|
|
||||||
'content': prompt
|
|
||||||
})
|
|
||||||
return build_chat_prompt(self.messages)
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def generate(self, prompt: str):
|
|
||||||
t1 = time()
|
|
||||||
prompt = self.set_user_prompt(prompt)
|
|
||||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
|
||||||
output_ids = self.model.generate(
|
|
||||||
**inputs,
|
|
||||||
max_new_tokens=128,
|
|
||||||
generation_config=self.model.generation_config
|
|
||||||
)
|
|
||||||
output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
|
||||||
t2 = time
|
|
||||||
text = output_text[len(prompt):] if output_text.startswith(prompt) else output_text
|
|
||||||
self.set_assistant_prompt(text)
|
|
||||||
if not self.output_json:
|
|
||||||
return text
|
|
||||||
input_tokens = inputs["input_ids"].shape[1]
|
|
||||||
output_tokens = len(self.tokenizer(text, return_tensors="pt")["input_ids"][0])
|
|
||||||
return {
|
|
||||||
'content':text,
|
|
||||||
'input_tokens': input_tokens,
|
|
||||||
'output_tokens': output_tokens,
|
|
||||||
'finish_time': t2 - t1,
|
|
||||||
'response_time': t2 - t1
|
|
||||||
}
|
|
||||||
|
|
||||||
def stream_generate(self, prompt: str):
|
|
||||||
t1 = time()
|
|
||||||
prompt = self.set_user_prompt(prompt)
|
|
||||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
|
||||||
input_tokens = inputs["input_ids"].shape[1]
|
|
||||||
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
||||||
|
|
||||||
generation_kwargs = dict(
|
|
||||||
**inputs,
|
|
||||||
streamer=streamer,
|
|
||||||
max_new_tokens=16000,
|
|
||||||
generation_config=self.model.generation_config
|
|
||||||
)
|
|
||||||
|
|
||||||
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
|
||||||
thread.start()
|
|
||||||
first = True
|
|
||||||
all_txt = ''
|
|
||||||
for new_text in streamer:
|
|
||||||
all_txt += new_text
|
|
||||||
if first:
|
|
||||||
t2 = time()
|
|
||||||
first = False
|
|
||||||
if not self.output_json:
|
|
||||||
yield new_text
|
|
||||||
yield {
|
|
||||||
'content': new_text,
|
|
||||||
'done': False
|
|
||||||
}
|
|
||||||
output_tokens = len(self.tokenizer(all_txt, return_tensors="pt")["input_ids"][0])
|
|
||||||
self.set_assistant_prompt(all_txt)
|
|
||||||
t3 = time()
|
|
||||||
if self.output_json:
|
|
||||||
yield {
|
|
||||||
'done': True,
|
|
||||||
'content':'',
|
|
||||||
'response_time': t2 - t1,
|
|
||||||
'finish_time': t3 - t1,
|
|
||||||
'input_tokens': input_tokens,
|
|
||||||
'output_tokens': output_tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser(description="Transformers Chat CLI")
|
|
||||||
parser.add_argument("--model", type=str, required=True, help="模型路径或 Hugging Face 名称")
|
|
||||||
parser.add_argument("--gpus", type=int, default=1, help="使用 GPU 数量")
|
|
||||||
parser.add_argument("--stream", action="store_true", help="是否流式输出")
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
def print_content(outd):
|
|
||||||
if isinstance(outd, dict):
|
|
||||||
print(outd['content'], end="", flush=True)
|
|
||||||
else:
|
|
||||||
print(outd, end="", flush=True)
|
|
||||||
|
|
||||||
def print_info(outd):
|
|
||||||
if isinstance(outd, dict):
|
|
||||||
if outd['done']:
|
|
||||||
print(f"response_time={outd['response_time']}, finish_time={outd['finish_time']}, input_tokens={outd['input_tokens']}, output_tokens={outd['output_tokens']}\n")
|
|
||||||
else:
|
|
||||||
print('\n');
|
|
||||||
|
|
||||||
def generate(engine, stream):
|
|
||||||
while True:
|
|
||||||
print('prompt("q" to exit):')
|
|
||||||
p = input()
|
|
||||||
if p == 'q':
|
|
||||||
break
|
|
||||||
if not p:
|
|
||||||
continue
|
|
||||||
if stream:
|
|
||||||
for outd in engine.stream_generate(p):
|
|
||||||
print_content(outd)
|
|
||||||
print('\n')
|
|
||||||
print_info(outd)
|
|
||||||
else:
|
|
||||||
outd = engine.generate(p)
|
|
||||||
print_content(outd)
|
|
||||||
print('\n')
|
|
||||||
print__info(outd)
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = parse_args()
|
|
||||||
print(f'{args=}')
|
|
||||||
engine = TransformersChatEngine(
|
|
||||||
model_name=args.model,
|
|
||||||
gpus=args.gpus
|
|
||||||
)
|
|
||||||
generate(engine, args.stream)
|
|
||||||
|
|
||||||
main()
|
|
@ -1,59 +0,0 @@
|
|||||||
# for model mistralai/Devstral-Small-2505
|
|
||||||
from appPublic.worker import awaitify
|
|
||||||
from appPublic.log import debug
|
|
||||||
from ahserver.serverenv import get_serverenv
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
||||||
from mistral_common.protocol.instruct.messages import (
|
|
||||||
SystemMessage, UserMessage, AssistantMessage
|
|
||||||
)
|
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
|
|
||||||
|
|
||||||
class DevstralLLM(T2TChatLLM):
|
|
||||||
def __init__(self, model_id):
|
|
||||||
tekken_file = f'{model_id}/tekken.json'
|
|
||||||
self.tokenizer = MistralTokenizer.from_file(tekken_file)
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
torch_dtype="auto",
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
self.model_id = model_id
|
|
||||||
|
|
||||||
def _build_assistant_message(self, prompt):
|
|
||||||
return AssistantMessage(content=prompt)
|
|
||||||
|
|
||||||
def _build_sys_message(self, prompt):
|
|
||||||
return SystemMessage(content=prompt)
|
|
||||||
|
|
||||||
def _build_user_message(self, prompt, **kw):
|
|
||||||
return UserMessage(content=prompt)
|
|
||||||
|
|
||||||
def get_streamer(self):
|
|
||||||
return TextIteratorStreamer(
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
skip_prompt=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
|
||||||
generate_kwargs = dict(
|
|
||||||
**inputs,
|
|
||||||
streamer=streamer,
|
|
||||||
max_new_tokens=32768,
|
|
||||||
do_sample=True
|
|
||||||
)
|
|
||||||
return generate_kwargs
|
|
||||||
|
|
||||||
def _messages2inputs(self, messages):
|
|
||||||
tokenized = self.tokenizer.encode_chat_completion(
|
|
||||||
ChatCompletionRequest(messages=messages)
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
'input_ids': torch.tensor([tokenized.tokens])
|
|
||||||
}
|
|
||||||
|
|
||||||
llm_register('mistralai/Devstral', DevstralLLM)
|
|
||||||
|
|
@ -1,95 +0,0 @@
|
|||||||
from traceback import format_exc
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
from llmengine.qwen3embedding import *
|
|
||||||
from llmengine.base_embedding import get_llm_class
|
|
||||||
|
|
||||||
from appPublic.registerfunction import RegisterFunction
|
|
||||||
from appPublic.worker import awaitify
|
|
||||||
from appPublic.log import debug, exception
|
|
||||||
from ahserver.serverenv import ServerEnv
|
|
||||||
from ahserver.globalEnv import stream_response
|
|
||||||
from ahserver.webapp import webserver
|
|
||||||
|
|
||||||
from aiohttp_session import get_session
|
|
||||||
|
|
||||||
helptext = """embeddings api:
|
|
||||||
path: /v1/embeddings
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
data: {
|
|
||||||
"input": "this is a test"
|
|
||||||
}
|
|
||||||
or {
|
|
||||||
"input":[
|
|
||||||
"this is first sentence",
|
|
||||||
"this is second setence"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
response is a json
|
|
||||||
{
|
|
||||||
"object": "list",
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"object": "embedding",
|
|
||||||
"index": 0,
|
|
||||||
"embedding": [0.0123, -0.0456, ...]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"model": "text-embedding-3-small",
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def init():
|
|
||||||
rf = RegisterFunction()
|
|
||||||
rf.register('embeddings', embeddings)
|
|
||||||
rf.register('docs', docs)
|
|
||||||
|
|
||||||
async def docs(request, params_kw, *params, **kw):
|
|
||||||
return helptext
|
|
||||||
|
|
||||||
async def embeddings(request, params_kw, *params, **kw):
|
|
||||||
debug(f'{params_kw.input=}')
|
|
||||||
se = ServerEnv()
|
|
||||||
engine = se.engine
|
|
||||||
f = awaitify(engine.embeddings)
|
|
||||||
input = params_kw.input
|
|
||||||
if input is None:
|
|
||||||
e = exception(f'input is None')
|
|
||||||
raise e
|
|
||||||
if isinstance(input, str):
|
|
||||||
input = [input]
|
|
||||||
arr = await f(input)
|
|
||||||
debug(f'{arr=}, type(arr)')
|
|
||||||
return arr
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(prog="Embedding")
|
|
||||||
parser.add_argument('-w', '--workdir')
|
|
||||||
parser.add_argument('-p', '--port')
|
|
||||||
parser.add_argument('model_path')
|
|
||||||
args = parser.parse_args()
|
|
||||||
Klass = get_llm_class(args.model_path)
|
|
||||||
if Klass is None:
|
|
||||||
e = Exception(f'{args.model_path} has not mapping to a model class')
|
|
||||||
exception(f'{e}, {format_exc()}')
|
|
||||||
raise e
|
|
||||||
se = ServerEnv()
|
|
||||||
se.engine = Klass(args.model_path)
|
|
||||||
se.engine.use_mps_if_prosible()
|
|
||||||
workdir = args.workdir or os.getcwd()
|
|
||||||
port = args.port
|
|
||||||
debug(f'{args=}')
|
|
||||||
webserver(init, workdir, port)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
@ -1,87 +0,0 @@
|
|||||||
from traceback import format_exc
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
from llmengine.ltpentity import *
|
|
||||||
from llmengine.base_entity import get_ltp_class
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from appPublic.registerfunction import RegisterFunction
|
|
||||||
from appPublic.worker import awaitify
|
|
||||||
from appPublic.log import debug, exception
|
|
||||||
from ahserver.serverenv import ServerEnv
|
|
||||||
from ahserver.globalEnv import stream_response
|
|
||||||
from ahserver.webapp import webserver
|
|
||||||
|
|
||||||
from aiohttp_session import get_session
|
|
||||||
|
|
||||||
helptext = """LTP Entities API:
|
|
||||||
|
|
||||||
1. Entities Endpoint:
|
|
||||||
path: /v1/entities
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
data: {
|
|
||||||
"query": "苹果公司在北京开设新店"
|
|
||||||
}
|
|
||||||
response: {
|
|
||||||
"object": "list",
|
|
||||||
"data": [
|
|
||||||
"苹果公司",
|
|
||||||
"北京",
|
|
||||||
"新店",
|
|
||||||
"开设",
|
|
||||||
...
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
2. Docs Endpoint:
|
|
||||||
path: /v1/docs
|
|
||||||
response: This help text
|
|
||||||
"""
|
|
||||||
|
|
||||||
def init():
|
|
||||||
rf = RegisterFunction()
|
|
||||||
rf.register('entities', entities)
|
|
||||||
rf.register('docs', docs)
|
|
||||||
|
|
||||||
async def docs(request, params_kw, *params, **kw):
|
|
||||||
return helptext
|
|
||||||
|
|
||||||
async def entities(request, params_kw, *params, **kw):
|
|
||||||
debug(f'{params_kw.query=}')
|
|
||||||
se = ServerEnv()
|
|
||||||
engine = se.engine
|
|
||||||
f = awaitify(engine.extract_entities)
|
|
||||||
query = params_kw.query
|
|
||||||
if query is None:
|
|
||||||
e = exception(f'query is None')
|
|
||||||
raise e
|
|
||||||
entities = await f(query)
|
|
||||||
debug(f'{entities=}, type(entities)')
|
|
||||||
return {
|
|
||||||
"object": "list",
|
|
||||||
"data": entities
|
|
||||||
}
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(prog="LTP Entity Service")
|
|
||||||
parser.add_argument('-w', '--workdir')
|
|
||||||
parser.add_argument('-p', '--port')
|
|
||||||
parser.add_argument('model_path')
|
|
||||||
args = parser.parse_args()
|
|
||||||
Klass = get_ltp_class(args.model_path)
|
|
||||||
if Klass is None:
|
|
||||||
e = Exception(f'{args.model_path} has not mapping to a model class')
|
|
||||||
exception(f'{e}, {format_exc()}')
|
|
||||||
raise e
|
|
||||||
se = ServerEnv()
|
|
||||||
se.engine = Klass(args.model_path)
|
|
||||||
workdir = args.workdir or os.getcwd()
|
|
||||||
port = args.port
|
|
||||||
debug(f'{args=}')
|
|
||||||
webserver(init, workdir, port)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,44 +0,0 @@
|
|||||||
#!/share/vllm-0.8.5/bin/python
|
|
||||||
|
|
||||||
# pip install accelerate
|
|
||||||
import threading
|
|
||||||
from time import time
|
|
||||||
from appPublic.worker import awaitify
|
|
||||||
from ahserver.serverenv import get_serverenv
|
|
||||||
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
|
|
||||||
from PIL import Image
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
from llmengine.base_chat_llm import MMChatLLM, llm_register
|
|
||||||
|
|
||||||
class Gemma3LLM(MMChatLLM):
|
|
||||||
def __init__(self, model_id):
|
|
||||||
self.model = Gemma3ForConditionalGeneration.from_pretrained(
|
|
||||||
model_id, device_map="auto"
|
|
||||||
).eval()
|
|
||||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
|
||||||
self.tokenizer = self.processor.tokenizer
|
|
||||||
self.messages = []
|
|
||||||
self.model_id = model_id
|
|
||||||
|
|
||||||
llm_register("gemma-3", Gemma3LLM)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
|
|
||||||
session = {}
|
|
||||||
while True:
|
|
||||||
print('input prompt')
|
|
||||||
p = input()
|
|
||||||
if p:
|
|
||||||
if p == 'q':
|
|
||||||
break;
|
|
||||||
print('input image path')
|
|
||||||
imgpath=input()
|
|
||||||
for d in gemma3.stream_generate(session, p, image_path=imgpath):
|
|
||||||
if not d['done']:
|
|
||||||
print(d['text'], end='', flush=True)
|
|
||||||
else:
|
|
||||||
x = {k:v for k,v in d.items() if k != 'text'}
|
|
||||||
print(f'\n{x}\n')
|
|
||||||
|
|
||||||
|
|
@ -1,76 +0,0 @@
|
|||||||
# Requires ltp>=0.2.0
|
|
||||||
|
|
||||||
from ltp import LTP
|
|
||||||
from typing import List
|
|
||||||
import logging
|
|
||||||
from llmengine.base_entity import BaseLtp, ltp_register
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class LtpEntity(BaseLtp):
|
|
||||||
def __init__(self, model_id):
|
|
||||||
# Load LTP model for CWS, POS, and NER
|
|
||||||
self.ltp = LTP(model_id)
|
|
||||||
self.model_id = model_id
|
|
||||||
self.model_name = model_id.split('/')[-1]
|
|
||||||
|
|
||||||
def extract_entities(self, query: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
从查询文本中抽取实体,包括:
|
|
||||||
- LTP NER 识别的实体(所有类型)。
|
|
||||||
- LTP POS 标注为名词('n')的词。
|
|
||||||
- LTP POS 标注为动词('v')的词。
|
|
||||||
- 连续名词合并(如 '苹果 公司' -> '苹果公司'),移除子词。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not query:
|
|
||||||
raise ValueError("查询文本不能为空")
|
|
||||||
|
|
||||||
result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"])
|
|
||||||
words = result.cws[0]
|
|
||||||
pos_list = result.pos[0]
|
|
||||||
ner = result.ner[0]
|
|
||||||
|
|
||||||
entities = []
|
|
||||||
subword_set = set()
|
|
||||||
|
|
||||||
logger.debug(f"NER 结果: {ner}")
|
|
||||||
for entity_type, entity, start, end in ner:
|
|
||||||
entities.append(entity)
|
|
||||||
|
|
||||||
combined = ""
|
|
||||||
combined_words = []
|
|
||||||
for i in range(len(words)):
|
|
||||||
if pos_list[i] == 'n':
|
|
||||||
combined += words[i]
|
|
||||||
combined_words.append(words[i])
|
|
||||||
if i + 1 < len(words) and pos_list[i + 1] == 'n':
|
|
||||||
continue
|
|
||||||
if combined:
|
|
||||||
entities.append(combined)
|
|
||||||
subword_set.update(combined_words)
|
|
||||||
logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
|
|
||||||
combined = ""
|
|
||||||
combined_words = []
|
|
||||||
else:
|
|
||||||
combined = ""
|
|
||||||
combined_words = []
|
|
||||||
logger.debug(f"连续名词子词集合: {subword_set}")
|
|
||||||
|
|
||||||
for word, pos in zip(words, pos_list):
|
|
||||||
if pos == 'n' and word not in subword_set:
|
|
||||||
entities.append(word)
|
|
||||||
|
|
||||||
for word, pos in zip(words, pos_list):
|
|
||||||
if pos == 'v':
|
|
||||||
entities.append(word)
|
|
||||||
|
|
||||||
unique_entities = list(dict.fromkeys(entities))
|
|
||||||
logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
|
|
||||||
return unique_entities
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"实体抽取失败: {str(e)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
ltp_register('LTP', LtpEntity)
|
|
@ -1,53 +0,0 @@
|
|||||||
# pip install accelerate
|
|
||||||
import time
|
|
||||||
from transformers import AutoProcessor, AutoModelForImageTextToText
|
|
||||||
from PIL import Image
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
from llmengine.base_chat_llm import MMChatLLM, llm_register
|
|
||||||
|
|
||||||
model_id = "google/medgemma-4b-it"
|
|
||||||
|
|
||||||
class MedgemmaLLM(MMChatLLM):
|
|
||||||
def __init__(self, model_id):
|
|
||||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
device_map="auto",
|
|
||||||
)
|
|
||||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
|
||||||
self.tokenizer = self.processor.tokenizer
|
|
||||||
self.model_id = model_id
|
|
||||||
|
|
||||||
def _messages2inputs(self, messages):
|
|
||||||
inputs = self.processor.apply_chat_template(
|
|
||||||
messages,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True,
|
|
||||||
return_tensors="pt"
|
|
||||||
).to(self.model.device, dtype=torch.bfloat16)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
llm_register("google/medgemma", MedgemmaLLM)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
|
|
||||||
session = {}
|
|
||||||
while True:
|
|
||||||
print(f'chat with {med.model_id}')
|
|
||||||
print('input prompt')
|
|
||||||
p = input()
|
|
||||||
if p:
|
|
||||||
if p == 'q':
|
|
||||||
break;
|
|
||||||
print('input image path')
|
|
||||||
imgpath=input()
|
|
||||||
for d in med.stream_generate(session, p, image_path=imgpath):
|
|
||||||
if not d['done']:
|
|
||||||
print(d['text'], end='', flush=True)
|
|
||||||
else:
|
|
||||||
x = {k:v for k,v in d.items() if k != 'text'}
|
|
||||||
print(f'\n{x}\n')
|
|
||||||
|
|
||||||
|
|
@ -1,68 +0,0 @@
|
|||||||
#!/share/vllm-0.8.5/bin/python
|
|
||||||
|
|
||||||
# pip install accelerate
|
|
||||||
from appPublic.worker import awaitify
|
|
||||||
from appPublic.log import debug
|
|
||||||
from ahserver.serverenv import get_serverenv
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register
|
|
||||||
|
|
||||||
class Qwen3LLM(T2TChatLLM):
|
|
||||||
def __init__(self, model_id):
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
torch_dtype="auto",
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
device = torch.device("mps")
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
self.model_id = model_id
|
|
||||||
|
|
||||||
def build_kwargs(self, inputs, streamer):
|
|
||||||
generate_kwargs = dict(
|
|
||||||
**inputs,
|
|
||||||
streamer=streamer,
|
|
||||||
max_new_tokens=32768,
|
|
||||||
do_sample=True,
|
|
||||||
eos_token_id=self.tokenizer.eos_token_id
|
|
||||||
)
|
|
||||||
return generate_kwargs
|
|
||||||
|
|
||||||
def _messages2inputs(self, messages):
|
|
||||||
debug(f'{messages=}')
|
|
||||||
text = self.tokenizer.apply_chat_template(
|
|
||||||
messages,
|
|
||||||
tokenize=False,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
enable_thinking=True
|
|
||||||
)
|
|
||||||
return self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
|
||||||
|
|
||||||
llm_register("Qwen/Qwen3", Qwen3LLM)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import sys
|
|
||||||
model_path = sys.argv[1]
|
|
||||||
q3 = Qwen3LLM(model_path)
|
|
||||||
session = {}
|
|
||||||
while True:
|
|
||||||
print('input prompt')
|
|
||||||
p = input()
|
|
||||||
if p:
|
|
||||||
if p == 'q':
|
|
||||||
break;
|
|
||||||
for d in q3.stream_generate(session, p):
|
|
||||||
print(d)
|
|
||||||
"""
|
|
||||||
if not d['done']:
|
|
||||||
print(d['text'], end='', flush=True)
|
|
||||||
else:
|
|
||||||
x = {k:v for k,v in d.items() if k != 'text'}
|
|
||||||
print(f'\n{x}\n')
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
|||||||
import torch
|
|
||||||
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
|
||||||
from llmengine.base_reranker import BaseReranker, llm_register
|
|
||||||
|
|
||||||
class Qwen3Reranker(BaseReranker):
|
|
||||||
def __init__(self, model_id, max_length=8096):
|
|
||||||
if 'Qwen3-Reranker' not in model_id:
|
|
||||||
e = Exception(f'{model_id} is not a Qwen3-Reranker')
|
|
||||||
raise e
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(model_id).eval()
|
|
||||||
self.model_id = model_id
|
|
||||||
self.model_name = model_id.split('/')[-1]
|
|
||||||
self.max_length = 8192
|
|
||||||
|
|
||||||
llm_register('Qwen3-Reranker', Qwen3Reranker)
|
|
@ -1,22 +0,0 @@
|
|||||||
# Requires transformers>=4.51.0
|
|
||||||
# Requires sentence-transformers>=2.7.0
|
|
||||||
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
from llmengine.base_embedding import BaseEmbedding, llm_register
|
|
||||||
|
|
||||||
class Qwen3Embedding(BaseEmbedding):
|
|
||||||
def __init__(self, model_id, max_length=8096):
|
|
||||||
# Load the model
|
|
||||||
self.model = SentenceTransformer(model_id)
|
|
||||||
# We recommend enabling flash_attention_2 for better acceleration and memory saving,
|
|
||||||
# together with setting `padding_side` to "left":
|
|
||||||
# model = SentenceTransformer(
|
|
||||||
# "Qwen/Qwen3-Embedding-0.6B",
|
|
||||||
# model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto"},
|
|
||||||
# tokenizer_kwargs={"padding_side": "left"},
|
|
||||||
# )
|
|
||||||
self.max_length = max_length
|
|
||||||
self.model_id = model_id
|
|
||||||
self.model_name = model_id.split('/')[-1]
|
|
||||||
|
|
||||||
llm_register('Qwen3-Embedding', Qwen3Embedding)
|
|
@ -1,106 +0,0 @@
|
|||||||
from traceback import format_exc
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
from llmengine.qwen3_reranker import *
|
|
||||||
from llmengine.bge_reranker import *
|
|
||||||
from llmengine.base_reranker import get_llm_class
|
|
||||||
|
|
||||||
from appPublic.registerfunction import RegisterFunction
|
|
||||||
from appPublic.worker import awaitify
|
|
||||||
from appPublic.log import debug, exception
|
|
||||||
from ahserver.serverenv import ServerEnv
|
|
||||||
from ahserver.webapp import webserver
|
|
||||||
|
|
||||||
helptext = """rerank api:
|
|
||||||
path: /v1/rerank
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
data:
|
|
||||||
{
|
|
||||||
"model": "rerank-001",
|
|
||||||
"query": "什么是量子计算?",
|
|
||||||
"documents": [
|
|
||||||
"量子计算是一种使用量子比特进行计算的方式。",
|
|
||||||
"古典计算机使用的是二进制位。",
|
|
||||||
"天气预报依赖于统计模型。",
|
|
||||||
"量子计算与物理学密切相关。"
|
|
||||||
},
|
|
||||||
"top_n": 2
|
|
||||||
}
|
|
||||||
|
|
||||||
response is a json
|
|
||||||
{
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"relevance_score": 0.95
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"index": 3,
|
|
||||||
"relevance_score": 0.89
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"object": "rerank.result",
|
|
||||||
"model": "rerank-001",
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def init():
|
|
||||||
rf = RegisterFunction()
|
|
||||||
rf.register('rerank', rerank)
|
|
||||||
rf.register('docs', docs)
|
|
||||||
|
|
||||||
async def docs(request, params_kw, *params, **kw):
|
|
||||||
return helptext
|
|
||||||
|
|
||||||
async def rerank(request, params_kw, *params, **kw):
|
|
||||||
debug(f'{params_kw.query=}, {params_kw.documents=}, {params_kw.top_n=}')
|
|
||||||
se = ServerEnv()
|
|
||||||
engine = se.engine
|
|
||||||
f = awaitify(engine.rerank)
|
|
||||||
query = params_kw.query
|
|
||||||
if query is None:
|
|
||||||
e = Exception(f'query is None')
|
|
||||||
raise e
|
|
||||||
documents = params_kw.documents
|
|
||||||
if documents is None:
|
|
||||||
e = Exception(f'documents is None')
|
|
||||||
raise e
|
|
||||||
if isinstance(documents, str):
|
|
||||||
documents = [documents]
|
|
||||||
top_n = params_kw.top_n
|
|
||||||
if top_n is None:
|
|
||||||
top_n = 5
|
|
||||||
arr = await f(query, params_kw.documents, top_n)
|
|
||||||
debug(f'{arr=}, type(arr)')
|
|
||||||
return arr
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(prog="Rerank")
|
|
||||||
parser.add_argument('-w', '--workdir')
|
|
||||||
parser.add_argument('-p', '--port')
|
|
||||||
parser.add_argument('model_path')
|
|
||||||
args = parser.parse_args()
|
|
||||||
Klass = get_llm_class(args.model_path)
|
|
||||||
if Klass is None:
|
|
||||||
e = Exception(f'{args.model_path} has not mapping to a model class')
|
|
||||||
exception(f'{e}, {format_exc()}')
|
|
||||||
raise e
|
|
||||||
se = ServerEnv()
|
|
||||||
se.engine = Klass(args.model_path)
|
|
||||||
se.engine.use_mps_if_prosible()
|
|
||||||
workdir = args.workdir or os.getcwd()
|
|
||||||
port = args.port
|
|
||||||
debug(f'{args=}')
|
|
||||||
webserver(init, workdir, port)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
|||||||
from traceback import format_exc
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from llmengine.base_chat_llm import BaseChatLLM, get_llm_class
|
|
||||||
from llmengine.gemma3_it import Gemma3LLM
|
|
||||||
from llmengine.medgemma3_it import MedgemmaLLM
|
|
||||||
from llmengine.qwen3 import Qwen3LLM
|
|
||||||
|
|
||||||
from appPublic.registerfunction import RegisterFunction
|
|
||||||
from appPublic.log import debug, exception
|
|
||||||
from ahserver.serverenv import ServerEnv
|
|
||||||
from ahserver.globalEnv import stream_response
|
|
||||||
from ahserver.webapp import webserver
|
|
||||||
|
|
||||||
from aiohttp_session import get_session
|
|
||||||
|
|
||||||
def init():
|
|
||||||
rf = RegisterFunction()
|
|
||||||
rf.register('chat_completions', chat_completions)
|
|
||||||
|
|
||||||
async def chat_completions(request, params_kw, *params, **kw):
|
|
||||||
async def gor():
|
|
||||||
se = ServerEnv()
|
|
||||||
engine = se.chat_engine
|
|
||||||
session = await get_session(request)
|
|
||||||
kwargs = {
|
|
||||||
}
|
|
||||||
if params_kw.image_path:
|
|
||||||
kwargs['image_path'] = fs.reapPath(params_kw.image_path)
|
|
||||||
if params_kw.video_path:
|
|
||||||
kwargs['video_path'] = fs.reapPath(params_kw.video_path)
|
|
||||||
if params_kw.audio_path:
|
|
||||||
kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
|
|
||||||
async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
|
|
||||||
debug(f'{d=}')
|
|
||||||
yield d
|
|
||||||
|
|
||||||
return await stream_response(request, gor)
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(prog="Sage")
|
|
||||||
parser.add_argument('-w', '--workdir')
|
|
||||||
parser.add_argument('-p', '--port')
|
|
||||||
parser.add_argument('model_path')
|
|
||||||
args = parser.parse_args()
|
|
||||||
Klass = get_llm_class(args.model_path)
|
|
||||||
if Klass is None:
|
|
||||||
e = Exception(f'{args.model_path} has not mapping to a model class')
|
|
||||||
exception(f'{e}, {format_exc()}')
|
|
||||||
raise e
|
|
||||||
se = ServerEnv()
|
|
||||||
se.engine = Klass(args.model_path)
|
|
||||||
se.engine.use_mps_if_prosible()
|
|
||||||
workdir = args.workdir or os.getcwd()
|
|
||||||
port = args.port
|
|
||||||
webserver(init, workdir, port)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
|||||||
Metadata-Version: 2.4
|
|
||||||
Name: llmengine
|
|
||||||
Version: 0.0.1
|
|
||||||
Summary: Your project description
|
|
||||||
Author-email: yu moqing <yumoqing@gmail.com>
|
|
||||||
License: MIT
|
|
||||||
Requires-Python: >=3.8
|
|
||||||
Description-Content-Type: text/markdown
|
|
||||||
Requires-Dist: torch
|
|
||||||
Requires-Dist: transformers
|
|
||||||
Requires-Dist: sentence-transformers>=2.7.0
|
|
||||||
Requires-Dist: mistral-common
|
|
||||||
Requires-Dist: accelerate
|
|
||||||
Provides-Extra: dev
|
|
||||||
Requires-Dist: pytest; extra == "dev"
|
|
||||||
Requires-Dist: black; extra == "dev"
|
|
||||||
Requires-Dist: mypy; extra == "dev"
|
|
@ -1,26 +0,0 @@
|
|||||||
README.md
|
|
||||||
pyproject.toml
|
|
||||||
llmengine/__init__.py
|
|
||||||
llmengine/ahserver.py
|
|
||||||
llmengine/base_chat_llm.py
|
|
||||||
llmengine/base_embedding.py
|
|
||||||
llmengine/base_entity.py
|
|
||||||
llmengine/base_reranker.py
|
|
||||||
llmengine/bge_reranker.py
|
|
||||||
llmengine/chatllm.py
|
|
||||||
llmengine/devstral.py
|
|
||||||
llmengine/embedding.py
|
|
||||||
llmengine/entity.py
|
|
||||||
llmengine/gemma3_it.py
|
|
||||||
llmengine/ltpentity.py
|
|
||||||
llmengine/medgemma3_it.py
|
|
||||||
llmengine/qwen3.py
|
|
||||||
llmengine/qwen3_reranker.py
|
|
||||||
llmengine/qwen3embedding.py
|
|
||||||
llmengine/rerank.py
|
|
||||||
llmengine/server.py
|
|
||||||
llmengine.egg-info/PKG-INFO
|
|
||||||
llmengine.egg-info/SOURCES.txt
|
|
||||||
llmengine.egg-info/dependency_links.txt
|
|
||||||
llmengine.egg-info/requires.txt
|
|
||||||
llmengine.egg-info/top_level.txt
|
|
@ -1 +0,0 @@
|
|||||||
|
|
@ -1,10 +0,0 @@
|
|||||||
torch
|
|
||||||
transformers
|
|
||||||
sentence-transformers>=2.7.0
|
|
||||||
mistral-common
|
|
||||||
accelerate
|
|
||||||
|
|
||||||
[dev]
|
|
||||||
pytest
|
|
||||||
black
|
|
||||||
mypy
|
|
@ -1 +0,0 @@
|
|||||||
llmengine
|
|
Loading…
Reference in New Issue
Block a user