Added llmengine/base_entity, llmengine/ltpentity, llmengine/entity
This commit is contained in:
parent
fcaee5e657
commit
5e6b03fbe5
0
build/lib/llmengine/__init__.py
Normal file
5
build/lib/llmengine/ahserver.py
Normal file
@@ -0,0 +1,5 @@
from ahserver.configuredServer import ConfiguredServer

if __name__ == '__main__':
    server = ConfiguredServer()
    server.run()
246
build/lib/llmengine/base_chat_llm.py
Normal file
@@ -0,0 +1,246 @@
import threading
import asyncio
import json
import torch
from time import time
from transformers import TextIteratorStreamer
from appPublic.log import debug
from appPublic.worker import awaitify
from appPublic.uniqueID import getID

model_pathMap = {
}
def llm_register(model_key, Klass):
    model_pathMap[model_key] = Klass

def get_llm_class(model_path):
    for k, klass in model_pathMap.items():
        if len(model_path.split(k)) > 1:
            return klass
    print(f'{model_pathMap=}')
    return None

class BaseChatLLM:
    def use_mps_if_prosible(self):
        if torch.backends.mps.is_available():
            device = torch.device("mps")
            self.model = self.model.to(device)

    def get_session_key(self):
        return self.model_id + ':messages'

    def _get_session_messages(self, session):
        key = self.get_session_key()
        messages = session.get(key) or []
        return messages

    def _set_session_messages(self, session, messages):
        key = self.get_session_key()
        session[key] = messages

    def get_streamer(self):
        return TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True
        )

    def output_generator(self, streamer):
        all_txt = ''
        t1 = time()
        i = 0
        id = f'chatllm-{getID()}'
        for txt in streamer:
            if txt == '':
                continue
            if i == 0:
                t2 = time()
            i += 1
            all_txt += txt
            yield {
                "id": id,
                "object": "chat.completion.chunk",
                "created": time(),
                "model": self.model_id,
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": txt
                        },
                        "logprobs": None,
                        "finish_reason": None
                    }
                ]
            }
        t3 = time()
        t = all_txt
        unk = self.tokenizer(t, return_tensors="pt")
        output_tokens = len(unk["input_ids"][0])
        yield {
            "id": id,
            "object": "chat.completion.chunk",
            "created": time(),
            "model": self.model_id,
            "response_time": t2 - t1,
            "finish_time": t3 - t1,
            "output_token": output_tokens,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": ""
                    },
                    "logprobs": None,
                    "finish_reason": "stop"
                }
            ]
        }

    def _generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
        messages = self._get_session_messages(session)
        if sys_prompt:
            messages.append(self._build_sys_message(sys_prompt))
        messages.append(self._build_user_message(prompt, image_path=image_path))
        # debug(f'{messages=}')
        all_txt = ''
        for d in self._gen(messages):
            if d['choices'][0]['finish_reason'] == 'stop':
                messages.append(self._build_assistant_message(all_txt))
            else:
                all_txt += d['choices'][0]['delta']['content']
            yield d
        self._set_session_messages(session, messages)

    async def _async_generator(self, session, prompt, image_path, video_path, audio_path, sys_prompt):
        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            await asyncio.sleep(0)
            yield d

    def generate(self, session, prompt,
            image_path=None,
            video_path=None,
            audio_path=None,
            sys_prompt=None):
        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            if d['choices'][0]['finish_reason'] == 'stop':
                return d

    def stream_generate(self, session, prompt,
            image_path=None,
            video_path=None,
            audio_path=None,
            sys_prompt=None):
        for d in self._generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            s = f'data: {json.dumps(d)}\n'
            yield s

    async def async_generate(self, session, prompt,
            image_path=None,
            video_path=None,
            audio_path=None,
            sys_prompt=None):
        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            await asyncio.sleep(0)
            if d['choices'][0]['finish_reason'] == 'stop':
                return d

    async def async_stream_generate(self, session, prompt,
            image_path=None,
            video_path=None,
            audio_path=None,
            sys_prompt=None):
        async for d in self._async_generator(session, prompt, image_path, video_path, audio_path, sys_prompt):
            s = f'data: {json.dumps(d)}\n'
            yield s
        yield 'data: [DONE]'

    def build_kwargs(self, inputs, streamer):
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        return generate_kwargs

    def _messages2inputs(self, messages):
        return self.processor.apply_chat_template(
            messages, add_generation_prompt=True,
            tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)

    def _gen(self, messages):
        inputs = self._messages2inputs(messages)
        input_len = inputs["input_ids"].shape[-1]
        streamer = self.get_streamer()
        kwargs = self.build_kwargs(inputs, streamer)
        thread = threading.Thread(target=self.model.generate,
            kwargs=kwargs)
        thread.start()
        for d in self.output_generator(streamer):
            if d['choices'][0]['finish_reason'] == 'stop':
                d['input_tokens'] = input_len
            yield d

class T2TChatLLM(BaseChatLLM):
    def _build_assistant_message(self, prompt):
        return {
            "role": "assistant",
            "content": prompt
        }

    def _build_sys_message(self, prompt):
        return {
            "role": "system",
            "content": prompt
        }

    def _build_user_message(self, prompt, **kw):
        return {
            "role": "user",
            "content": prompt
        }

class MMChatLLM(BaseChatLLM):
    """ multiple modal chat LLM """
    def _build_assistant_message(self, prompt):
        return {
            "role": "assistant",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_sys_message(self, prompt):
        return {
            "role": "system",
            "content": [{"type": "text", "text": prompt}]
        }

    def _build_user_message(self, prompt, image_path=None,
            video_path=None, audio_path=None):
        contents = [
            {
                "type": "text", "text": prompt
            }
        ]
        if image_path:
            contents.append({
                "type": "image",
                "image": image_path
            })
        if video_path:
            contents.append({
                "type": "video",
                "video": video_path
            })
        if audio_path:
            contents.append({
                "type": "audio",
                "audio": audio_path
            })
        return {
            "role": "user",
            "content": contents
        }
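The model_pathMap registry above resolves a model path to an engine class by substring match. A minimal usage sketch (the DummyLLM class and the model path are hypothetical, for illustration only):

from llmengine.base_chat_llm import llm_register, get_llm_class, T2TChatLLM

class DummyLLM(T2TChatLLM):
    # hypothetical engine used only to illustrate registration
    def __init__(self, model_id):
        self.model_id = model_id

llm_register('dummy-llm', DummyLLM)
# any path containing a registered key resolves to that class
assert get_llm_class('/share/models/acme/dummy-llm-7b') is DummyLLM
assert get_llm_class('/share/models/unknown') is None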
46
build/lib/llmengine/base_embedding.py
Normal file
@@ -0,0 +1,46 @@
import torch

model_pathMap = {
}
def llm_register(model_key, Klass):
    global model_pathMap
    model_pathMap[model_key] = Klass

def get_llm_class(model_path):
    for k, klass in model_pathMap.items():
        if len(model_path.split(k)) > 1:
            return klass
    print(f'{model_pathMap=}')
    return None

class BaseEmbedding:

    def use_mps_if_prosible(self):
        if torch.backends.mps.is_available():
            device = torch.device("mps")
            self.model = self.model.to(device)

    def embeddings(self, input):
        es = self.model.encode(input)
        data = []
        for i, e in enumerate(es):
            d = {
                "object": "embedding",
                "index": i,
                "embedding": e.tolist()
            }
            data.append(d)
        return {
            "object": "list",
            "data": data,
            "model": self.model_name,
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0
            }
        }

    def similarity(self, qvector, docvectors):
        s = self.model.similarity([qvector], docvectors)
        return s[0]
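The embeddings() method wraps sentence-transformers output in an OpenAI-style payload. A minimal sketch of calling it directly, assuming a local Qwen3-Embedding checkpoint (the path is an assumption):

from llmengine.qwen3embedding import Qwen3Embedding

engine = Qwen3Embedding('/share/models/Qwen/Qwen3-Embedding-0.6B')  # hypothetical path
result = engine.embeddings(['this is a test'])
print(result['model'], len(result['data'][0]['embedding']))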
23
build/lib/llmengine/base_entity.py
Normal file
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from typing import List

model_pathMap = {}

def ltp_register(model_key, Klass):
    """Register a model class for a given model key."""
    global model_pathMap
    model_pathMap[model_key] = Klass

def get_ltp_class(model_path):
    """Find the model class for a given model path."""
    for k, klass in model_pathMap.items():
        if len(model_path.split(k)) > 1:
            return klass
    print(f'{model_pathMap=}')
    return None

class BaseLtp(ABC):
    @abstractmethod
    def extract_entities(self, query: str) -> List[str]:
        """Extract entities from query text."""
        pass
84
build/lib/llmengine/base_reranker.py
Normal file
@@ -0,0 +1,84 @@
import torch

model_pathMap = {
}
def llm_register(model_key, Klass):
    model_pathMap[model_key] = Klass

def get_llm_class(model_path):
    for k, klass in model_pathMap.items():
        if len(model_path.split(k)) > 1:
            return klass
    print(f'{model_pathMap=}')
    return None

class BaseReranker:
    def __init__(self, model_id, **kw):
        self.model_id = model_id

    def use_mps_if_prosible(self):
        if torch.backends.mps.is_available():
            device = torch.device("mps")
            self.model = self.model.to(device)

    def process_inputs(self, pairs):
        inputs = self.tokenizer(
            pairs, padding=False, truncation='longest_first',
            return_attention_mask=False, max_length=self.max_length
        )
        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=self.max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(self.model.device)
        return inputs

    def build_sys_prompt(self, sys_prompt):
        return f"<|im_start|>system\n{sys_prompt}\n<|im_end|>"

    def build_user_prompt(self, query, doc, instruct=''):
        return f'<|im_start|>user\n<Instruct>: {instruct}\n<Query>:{query}\n<Document>:\n{doc}<|im_end|>'

    def build_assistant_prompt(self):
        return "<|im_start|>assistant\n<think>\n\n</think>\n\n"

    def compute_logits(self, inputs, **kwargs):
        batch_scores = self.model(**inputs).logits[:, -1, :]
        # true_vector = batch_scores[:, token_true_id]
        # false_vector = batch_scores[:, token_false_id]
        # batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def build_pairs(self, query, docs, sys_prompt="", task=""):
        sys_str = self.build_sys_prompt(sys_prompt)
        ass_str = self.build_assistant_prompt()
        pairs = [sys_str + '\n' + self.build_user_prompt(query, doc, instruct=task) + '\n' + ass_str for doc in docs]
        return pairs

    def rerank(self, query, docs, top_n, sys_prompt="", task=""):
        pairs = self.build_pairs(query, docs, sys_prompt=sys_prompt, task=task)
        with torch.no_grad():
            inputs = self.process_inputs(pairs)
            scores = self.compute_logits(inputs)
        data = []
        for i, s in enumerate(scores):
            d = {
                'index': i,
                'relevance_score': s
            }
            data.append(d)
        data = sorted(data,
            key=lambda x: x["relevance_score"],
            reverse=True)
        if len(data) > top_n:
            data = data[:top_n]
        ret = {
            "data": data,
            "object": "rerank.result",
            "model": self.model_name,
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0
            }
        }
        return ret
31
build/lib/llmengine/bge_reranker.py
Normal file
@@ -0,0 +1,31 @@
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llmengine.base_reranker import BaseReranker, llm_register

class BgeReranker(BaseReranker):
    def __init__(self, model_id, max_length=8096):
        if 'bge-reranker' not in model_id:
            e = Exception(f'{model_id} is not a bge-reranker')
            raise e
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        model.eval()
        self.model = model
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]

    def build_pairs(self, query, docs, **kw):
        return [[query, doc] for doc in docs]

    def process_inputs(self, pairs):
        inputs = self.tokenizer(pairs, padding=True,
            truncation=True, return_tensors='pt', max_length=512)
        return inputs

    def compute_logits(self, inputs):
        scores = self.model(**inputs,
            return_dict=True).logits.view(-1, ).float()
        scores = [s.item() for s in scores]
        return scores

llm_register('bge-reranker', BgeReranker)
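A short usage sketch for the reranker interface defined above, using BgeReranker from this commit; the checkpoint path is an assumption:

from llmengine.bge_reranker import BgeReranker

rr = BgeReranker('/share/models/BAAI/bge-reranker-base')  # hypothetical path
out = rr.rerank('什么是量子计算?',
                ['量子计算是一种使用量子比特进行计算的方式。', '天气预报依赖于统计模型。'],
                top_n=1)
# returns the top document index and its relevance score
print(out['data'][0]['index'], out['data'][0]['relevance_score'])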
212
build/lib/llmengine/chatllm.py
Normal file
@@ -0,0 +1,212 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from time import time
import torch
from threading import Thread

def is_chat_model(model_name: str, tokenizer) -> bool:
    chat_keywords = ["chat", "chatml", "phi", "llama-chat", "mistral-instruct"]
    if any(k in model_name.lower() for k in chat_keywords):
        return True
    if tokenizer and hasattr(tokenizer, "additional_special_tokens"):
        if any(tag in tokenizer.additional_special_tokens for tag in ["<|user|>", "<|system|>", "<|assistant|>"]):
            return True
    return False

def build_chat_prompt(messages):
    prompt = ""
    for message in messages:
        role = message["role"]
        content = message["content"]
        prompt += f"<|{role}|>\n{content}\n"
    prompt += "<|assistant|>\n"  # generation starts here
    return prompt

class CountingStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer, skip_prompt=True, **kw):
        super().__init__(tokenizer, skip_prompt=skip_prompt, **kw)
        self.token_count = 0

    def __next__(self, *args, **kw):
        output_ids = super().__iter__(*args, **kw)
        self.token_count += output_ids.sequences.shape[1]
        return output_ids

class TransformersChatEngine:
    def __init__(self, model_name: str, device: str = None, fp16: bool = True,
            output_json=True,
            gpus: int = 1):
        """
        Generic model loader with control over GPU count and placement.
        :param model_name: model name or path
        :param device: device such as "cuda:0"; chosen automatically by default
        :param fp16: whether to use fp16 precision (on GPUs that support it)
        :param gpus: number of GPUs to use; 1 means a single card, >1 enables multi-GPU inference (device_map='auto')
        """
        self.output_json = output_json
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.is_multi_gpu = gpus > 1 and torch.cuda.device_count() >= gpus

        print(f"✅ Using device: {self.device}, GPUs: {gpus}, Multi-GPU: {self.is_multi_gpu}")

        # load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        # load the model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if fp16 and "cuda" in self.device else torch.float32,
            device_map="auto" if self.is_multi_gpu else None
        )

        if not self.is_multi_gpu:
            self.model.to(self.device)

        self.model.eval()
        self.is_chat = is_chat_model(model_name, self.tokenizer)
        if self.is_chat:
            self.messages = []

        print(f'{self.model.generation_config=}')

    def set_system_prompt(self, prompt):
        if self.is_chat:
            self.messages = [{
                'role': 'system',
                'content': prompt
            }]

    def set_assistant_prompt(self, prompt):
        if self.is_chat:
            self.messages.append({
                'role': 'assistant',
                'content': prompt
            })

    def set_user_prompt(self, prompt):
        if self.is_chat:
            self.messages.append({
                'role': 'user',
                'content': prompt
            })
            return build_chat_prompt(self.messages)
        return prompt

    def generate(self, prompt: str):
        t1 = time()
        prompt = self.set_user_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=128,
            generation_config=self.model.generation_config
        )
        output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        t2 = time()
        text = output_text[len(prompt):] if output_text.startswith(prompt) else output_text
        self.set_assistant_prompt(text)
        if not self.output_json:
            return text
        input_tokens = inputs["input_ids"].shape[1]
        output_tokens = len(self.tokenizer(text, return_tensors="pt")["input_ids"][0])
        return {
            'content': text,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'finish_time': t2 - t1,
            'response_time': t2 - t1
        }

    def stream_generate(self, prompt: str):
        t1 = time()
        prompt = self.set_user_prompt(prompt)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_tokens = inputs["input_ids"].shape[1]
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=16000,
            generation_config=self.model.generation_config
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        first = True
        all_txt = ''
        for new_text in streamer:
            all_txt += new_text
            if first:
                t2 = time()
                first = False
            if not self.output_json:
                yield new_text
                continue
            yield {
                'content': new_text,
                'done': False
            }
        output_tokens = len(self.tokenizer(all_txt, return_tensors="pt")["input_ids"][0])
        self.set_assistant_prompt(all_txt)
        t3 = time()
        if self.output_json:
            yield {
                'done': True,
                'content': '',
                'response_time': t2 - t1,
                'finish_time': t3 - t1,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens
            }

if __name__ == '__main__':
    import os
    import sys
    import argparse
    def parse_args():
        parser = argparse.ArgumentParser(description="Transformers Chat CLI")
        parser.add_argument("--model", type=str, required=True, help="model path or Hugging Face name")
        parser.add_argument("--gpus", type=int, default=1, help="number of GPUs to use")
        parser.add_argument("--stream", action="store_true", help="stream the output")
        return parser.parse_args()

    def print_content(outd):
        if isinstance(outd, dict):
            print(outd['content'], end="", flush=True)
        else:
            print(outd, end="", flush=True)

    def print_info(outd):
        if isinstance(outd, dict):
            if outd['done']:
                print(f"response_time={outd['response_time']}, finish_time={outd['finish_time']}, input_tokens={outd['input_tokens']}, output_tokens={outd['output_tokens']}\n")
        else:
            print('\n')

    def generate(engine, stream):
        while True:
            print('prompt("q" to exit):')
            p = input()
            if p == 'q':
                break
            if not p:
                continue
            if stream:
                for outd in engine.stream_generate(p):
                    print_content(outd)
                print('\n')
                print_info(outd)
            else:
                outd = engine.generate(p)
                print_content(outd)
                print('\n')
                print_info(outd)

    def main():
        args = parse_args()
        print(f'{args=}')
        engine = TransformersChatEngine(
            model_name=args.model,
            gpus=args.gpus
        )
        generate(engine, args.stream)

    main()
59
build/lib/llmengine/devstral.py
Normal file
@@ -0,0 +1,59 @@
# for model mistralai/Devstral-Small-2505
from appPublic.worker import awaitify
from appPublic.log import debug
from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from mistral_common.protocol.instruct.messages import (
    SystemMessage, UserMessage, AssistantMessage
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register

class DevstralLLM(T2TChatLLM):
    def __init__(self, model_id):
        tekken_file = f'{model_id}/tekken.json'
        self.tokenizer = MistralTokenizer.from_file(tekken_file)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto"
        )
        self.model_id = model_id

    def _build_assistant_message(self, prompt):
        return AssistantMessage(content=prompt)

    def _build_sys_message(self, prompt):
        return SystemMessage(content=prompt)

    def _build_user_message(self, prompt, **kw):
        return UserMessage(content=prompt)

    def get_streamer(self):
        return TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True
        )

    def build_kwargs(self, inputs, streamer):
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=32768,
            do_sample=True
        )
        return generate_kwargs

    def _messages2inputs(self, messages):
        tokenized = self.tokenizer.encode_chat_completion(
            ChatCompletionRequest(messages=messages)
        )
        return {
            'input_ids': torch.tensor([tokenized.tokens])
        }

llm_register('mistralai/Devstral', DevstralLLM)
95
build/lib/llmengine/embedding.py
Normal file
@@ -0,0 +1,95 @@
from traceback import format_exc
import os
import sys
import argparse
from llmengine.qwen3embedding import *
from llmengine.base_embedding import get_llm_class

from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver

from aiohttp_session import get_session

helptext = """embeddings api:
path: /v1/embeddings
headers: {
    "Content-Type": "application/json"
}
data: {
    "input": "this is a test"
}
or {
    "input": [
        "this is first sentence",
        "this is second sentence"
    ]
}

response is a json
{
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [0.0123, -0.0456, ...]
        }
    ],
    "model": "text-embedding-3-small",
    "usage": {
        "prompt_tokens": 0,
        "total_tokens": 0
    }
}
"""

def init():
    rf = RegisterFunction()
    rf.register('embeddings', embeddings)
    rf.register('docs', docs)

async def docs(request, params_kw, *params, **kw):
    return helptext

async def embeddings(request, params_kw, *params, **kw):
    debug(f'{params_kw.input=}')
    se = ServerEnv()
    engine = se.engine
    f = awaitify(engine.embeddings)
    input = params_kw.input
    if input is None:
        e = Exception('input is None')
        exception(f'{e}')
        raise e
    if isinstance(input, str):
        input = [input]
    arr = await f(input)
    debug(f'{arr=}, {type(arr)=}')
    return arr

def main():
    parser = argparse.ArgumentParser(prog="Embedding")
    parser.add_argument('-w', '--workdir')
    parser.add_argument('-p', '--port')
    parser.add_argument('model_path')
    args = parser.parse_args()
    Klass = get_llm_class(args.model_path)
    if Klass is None:
        e = Exception(f'{args.model_path} has no mapping to a model class')
        exception(f'{e}, {format_exc()}')
        raise e
    se = ServerEnv()
    se.engine = Klass(args.model_path)
    se.engine.use_mps_if_prosible()
    workdir = args.workdir or os.getcwd()
    port = args.port
    debug(f'{args=}')
    webserver(init, workdir, port)

if __name__ == '__main__':
    main()
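A hedged client-side sketch of the request/response shape documented in helptext, assuming the embedding service was started with -p 9999 (the port is hypothetical, it is whatever -p was given):

import requests

r = requests.post('http://127.0.0.1:9999/v1/embeddings',
                  headers={'Content-Type': 'application/json'},
                  json={'input': ['this is a test']})
print(r.json()['data'][0]['index'])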
87
build/lib/llmengine/entity.py
Normal file
@@ -0,0 +1,87 @@
from traceback import format_exc
import os
import sys
import argparse
from llmengine.ltpentity import *
from llmengine.base_entity import get_ltp_class
from typing import List

from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver

from aiohttp_session import get_session

helptext = """LTP Entities API:

1. Entities Endpoint:
path: /v1/entities
headers: {
    "Content-Type": "application/json"
}
data: {
    "query": "苹果公司在北京开设新店"
}
response: {
    "object": "list",
    "data": [
        "苹果公司",
        "北京",
        "新店",
        "开设",
        ...
    ]
}

2. Docs Endpoint:
path: /v1/docs
response: This help text
"""

def init():
    rf = RegisterFunction()
    rf.register('entities', entities)
    rf.register('docs', docs)

async def docs(request, params_kw, *params, **kw):
    return helptext

async def entities(request, params_kw, *params, **kw):
    debug(f'{params_kw.query=}')
    se = ServerEnv()
    engine = se.engine
    f = awaitify(engine.extract_entities)
    query = params_kw.query
    if query is None:
        e = Exception('query is None')
        exception(f'{e}')
        raise e
    entities = await f(query)
    debug(f'{entities=}, {type(entities)=}')
    return {
        "object": "list",
        "data": entities
    }

def main():
    parser = argparse.ArgumentParser(prog="LTP Entity Service")
    parser.add_argument('-w', '--workdir')
    parser.add_argument('-p', '--port')
    parser.add_argument('model_path')
    args = parser.parse_args()
    Klass = get_ltp_class(args.model_path)
    if Klass is None:
        e = Exception(f'{args.model_path} has no mapping to a model class')
        exception(f'{e}, {format_exc()}')
        raise e
    se = ServerEnv()
    se.engine = Klass(args.model_path)
    workdir = args.workdir or os.getcwd()
    port = args.port
    debug(f'{args=}')
    webserver(init, workdir, port)

if __name__ == '__main__':
    main()
44
build/lib/llmengine/gemma3_it.py
Normal file
@@ -0,0 +1,44 @@
#!/share/vllm-0.8.5/bin/python

# pip install accelerate
import threading
from time import time
from appPublic.worker import awaitify
from ahserver.serverenv import get_serverenv
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register

class Gemma3LLM(MMChatLLM):
    def __init__(self, model_id):
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, device_map="auto"
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.tokenizer = self.processor.tokenizer
        self.messages = []
        self.model_id = model_id

llm_register("gemma-3", Gemma3LLM)

if __name__ == '__main__':
    gemma3 = Gemma3LLM('/share/models/google/gemma-3-4b-it')
    session = {}
    while True:
        print('input prompt')
        p = input()
        if p:
            if p == 'q':
                break
            print('input image path')
            imgpath = input()
            for d in gemma3.stream_generate(session, p, image_path=imgpath):
                if not d['done']:
                    print(d['text'], end='', flush=True)
                else:
                    x = {k: v for k, v in d.items() if k != 'text'}
                    print(f'\n{x}\n')
76
build/lib/llmengine/ltpentity.py
Normal file
@@ -0,0 +1,76 @@
# Requires ltp>=0.2.0

from ltp import LTP
from typing import List
import logging
from llmengine.base_entity import BaseLtp, ltp_register

logger = logging.getLogger(__name__)

class LtpEntity(BaseLtp):
    def __init__(self, model_id):
        # Load LTP model for CWS, POS, and NER
        self.ltp = LTP(model_id)
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]

    def extract_entities(self, query: str) -> List[str]:
        """
        Extract entities from the query text, including:
        - entities recognized by LTP NER (all types);
        - words tagged as nouns ('n') by LTP POS;
        - words tagged as verbs ('v') by LTP POS;
        - consecutive nouns merged into one entity (e.g. '苹果 公司' -> '苹果公司'), with the sub-words removed.
        """
        try:
            if not query:
                raise ValueError("查询文本不能为空")

            result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"])
            words = result.cws[0]
            pos_list = result.pos[0]
            ner = result.ner[0]

            entities = []
            subword_set = set()

            logger.debug(f"NER 结果: {ner}")
            for entity_type, entity, start, end in ner:
                entities.append(entity)

            combined = ""
            combined_words = []
            for i in range(len(words)):
                if pos_list[i] == 'n':
                    combined += words[i]
                    combined_words.append(words[i])
                    if i + 1 < len(words) and pos_list[i + 1] == 'n':
                        continue
                    if combined:
                        entities.append(combined)
                        subword_set.update(combined_words)
                        logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
                    combined = ""
                    combined_words = []
                else:
                    combined = ""
                    combined_words = []
            logger.debug(f"连续名词子词集合: {subword_set}")

            for word, pos in zip(words, pos_list):
                if pos == 'n' and word not in subword_set:
                    entities.append(word)

            for word, pos in zip(words, pos_list):
                if pos == 'v':
                    entities.append(word)

            unique_entities = list(dict.fromkeys(entities))
            logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
            return unique_entities

        except Exception as e:
            logger.error(f"实体抽取失败: {str(e)}")
            return []

ltp_register('LTP', LtpEntity)
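A short usage sketch for LtpEntity, assuming an LTP model directory like the one used in test/entities/start.sh:

from llmengine.ltpentity import LtpEntity

ltp = LtpEntity('/share/models/LTP/small')
print(ltp.extract_entities('苹果公司在北京开设新店'))
# expected: a de-duplicated list such as ['苹果公司', '北京', ...] (see the helptext in entity.py)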
53
build/lib/llmengine/medgemma3_it.py
Normal file
@@ -0,0 +1,53 @@
# pip install accelerate
import time
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
import torch
from llmengine.base_chat_llm import MMChatLLM, llm_register

model_id = "google/medgemma-4b-it"

class MedgemmaLLM(MMChatLLM):
    def __init__(self, model_id):
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.tokenizer = self.processor.tokenizer
        self.model_id = model_id

    def _messages2inputs(self, messages):
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)
        return inputs

llm_register("google/medgemma", MedgemmaLLM)

if __name__ == '__main__':
    med = MedgemmaLLM('/share/models/google/medgemma-4b-it')
    session = {}
    while True:
        print(f'chat with {med.model_id}')
        print('input prompt')
        p = input()
        if p:
            if p == 'q':
                break
            print('input image path')
            imgpath = input()
            for d in med.stream_generate(session, p, image_path=imgpath):
                if not d['done']:
                    print(d['text'], end='', flush=True)
                else:
                    x = {k: v for k, v in d.items() if k != 'text'}
                    print(f'\n{x}\n')
68
build/lib/llmengine/qwen3.py
Normal file
@@ -0,0 +1,68 @@
#!/share/vllm-0.8.5/bin/python

# pip install accelerate
from appPublic.worker import awaitify
from appPublic.log import debug
from ahserver.serverenv import get_serverenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
from llmengine.base_chat_llm import BaseChatLLM, T2TChatLLM, llm_register

class Qwen3LLM(T2TChatLLM):
    def __init__(self, model_id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto"
        )
        if torch.backends.mps.is_available():
            device = torch.device("mps")
            self.model = self.model.to(device)
        self.model_id = model_id

    def build_kwargs(self, inputs, streamer):
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=32768,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        return generate_kwargs

    def _messages2inputs(self, messages):
        debug(f'{messages=}')
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
        return self.tokenizer([text], return_tensors="pt").to(self.model.device)

llm_register("Qwen/Qwen3", Qwen3LLM)

if __name__ == '__main__':
    import sys
    model_path = sys.argv[1]
    q3 = Qwen3LLM(model_path)
    session = {}
    while True:
        print('input prompt')
        p = input()
        if p:
            if p == 'q':
                break
            for d in q3.stream_generate(session, p):
                print(d)
                """
                if not d['done']:
                    print(d['text'], end='', flush=True)
                else:
                    x = {k:v for k,v in d.items() if k != 'text'}
                    print(f'\n{x}\n')
                """
16
build/lib/llmengine/qwen3_reranker.py
Normal file
@@ -0,0 +1,16 @@
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from llmengine.base_reranker import BaseReranker, llm_register

class Qwen3Reranker(BaseReranker):
    def __init__(self, model_id, max_length=8096):
        if 'Qwen3-Reranker' not in model_id:
            e = Exception(f'{model_id} is not a Qwen3-Reranker')
            raise e
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
        self.model = AutoModelForCausalLM.from_pretrained(model_id).eval()
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]
        self.max_length = 8192

llm_register('Qwen3-Reranker', Qwen3Reranker)
22
build/lib/llmengine/qwen3embedding.py
Normal file
@@ -0,0 +1,22 @@
# Requires transformers>=4.51.0
# Requires sentence-transformers>=2.7.0

from sentence_transformers import SentenceTransformer
from llmengine.base_embedding import BaseEmbedding, llm_register

class Qwen3Embedding(BaseEmbedding):
    def __init__(self, model_id, max_length=8096):
        # Load the model
        self.model = SentenceTransformer(model_id)
        # We recommend enabling flash_attention_2 for better acceleration and memory saving,
        # together with setting `padding_side` to "left":
        # model = SentenceTransformer(
        #     "Qwen/Qwen3-Embedding-0.6B",
        #     model_kwargs={"attn_implementation": "flash_attention_2", "device_map": "auto"},
        #     tokenizer_kwargs={"padding_side": "left"},
        # )
        self.max_length = max_length
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]

llm_register('Qwen3-Embedding', Qwen3Embedding)
106
build/lib/llmengine/rerank.py
Normal file
@@ -0,0 +1,106 @@
from traceback import format_exc
import os
import sys
import argparse
from llmengine.qwen3_reranker import *
from llmengine.bge_reranker import *
from llmengine.base_reranker import get_llm_class

from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.webapp import webserver

helptext = """rerank api:
path: /v1/rerank
headers: {
    "Content-Type": "application/json"
}
data:
{
    "model": "rerank-001",
    "query": "什么是量子计算?",
    "documents": [
        "量子计算是一种使用量子比特进行计算的方式。",
        "古典计算机使用的是二进制位。",
        "天气预报依赖于统计模型。",
        "量子计算与物理学密切相关。"
    ],
    "top_n": 2
}

response is a json
{
    "data": [
        {
            "index": 0,
            "relevance_score": 0.95
        },
        {
            "index": 3,
            "relevance_score": 0.89
        }
    ],
    "object": "rerank.result",
    "model": "rerank-001",
    "usage": {
        "prompt_tokens": 0,
        "total_tokens": 0
    }
}
"""

def init():
    rf = RegisterFunction()
    rf.register('rerank', rerank)
    rf.register('docs', docs)

async def docs(request, params_kw, *params, **kw):
    return helptext

async def rerank(request, params_kw, *params, **kw):
    debug(f'{params_kw.query=}, {params_kw.documents=}, {params_kw.top_n=}')
    se = ServerEnv()
    engine = se.engine
    f = awaitify(engine.rerank)
    query = params_kw.query
    if query is None:
        e = Exception('query is None')
        raise e
    documents = params_kw.documents
    if documents is None:
        e = Exception('documents is None')
        raise e
    if isinstance(documents, str):
        documents = [documents]
    top_n = params_kw.top_n
    if top_n is None:
        top_n = 5
    arr = await f(query, documents, top_n)
    debug(f'{arr=}, {type(arr)=}')
    return arr

def main():
    parser = argparse.ArgumentParser(prog="Rerank")
    parser.add_argument('-w', '--workdir')
    parser.add_argument('-p', '--port')
    parser.add_argument('model_path')
    args = parser.parse_args()
    Klass = get_llm_class(args.model_path)
    if Klass is None:
        e = Exception(f'{args.model_path} has no mapping to a model class')
        exception(f'{e}, {format_exc()}')
        raise e
    se = ServerEnv()
    se.engine = Klass(args.model_path)
    se.engine.use_mps_if_prosible()
    workdir = args.workdir or os.getcwd()
    port = args.port
    debug(f'{args=}')
    webserver(init, workdir, port)

if __name__ == '__main__':
    main()
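A hedged client sketch matching the helptext above, assuming the rerank service was started with -p 9997 (the port is hypothetical):

import requests

payload = {
    'query': '什么是量子计算?',
    'documents': [
        '量子计算是一种使用量子比特进行计算的方式。',
        '天气预报依赖于统计模型。'
    ],
    'top_n': 1
}
r = requests.post('http://127.0.0.1:9997/v1/rerank',
                  headers={'Content-Type': 'application/json'},
                  json=payload)
print(r.json()['data'])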
62
build/lib/llmengine/server.py
Normal file
@@ -0,0 +1,62 @@
from traceback import format_exc
import os
import sys
import argparse

from llmengine.base_chat_llm import BaseChatLLM, get_llm_class
from llmengine.gemma3_it import Gemma3LLM
from llmengine.medgemma3_it import MedgemmaLLM
from llmengine.qwen3 import Qwen3LLM

from appPublic.registerfunction import RegisterFunction
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver

from aiohttp_session import get_session

def init():
    rf = RegisterFunction()
    rf.register('chat_completions', chat_completions)

async def chat_completions(request, params_kw, *params, **kw):
    async def gor():
        se = ServerEnv()
        engine = se.chat_engine
        session = await get_session(request)
        kwargs = {
        }
        if params_kw.image_path:
            kwargs['image_path'] = fs.reapPath(params_kw.image_path)
        if params_kw.video_path:
            kwargs['video_path'] = fs.reapPath(params_kw.video_path)
        if params_kw.audio_path:
            kwargs['audio_path'] = fs.reapPath(params_kw.audio_path)
        async for d in engine.async_stream_generate(session, params_kw.prompt, **kwargs):
            debug(f'{d=}')
            yield d

    return await stream_response(request, gor)

def main():
    parser = argparse.ArgumentParser(prog="Sage")
    parser.add_argument('-w', '--workdir')
    parser.add_argument('-p', '--port')
    parser.add_argument('model_path')
    args = parser.parse_args()
    Klass = get_llm_class(args.model_path)
    if Klass is None:
        e = Exception(f'{args.model_path} has no mapping to a model class')
        exception(f'{e}, {format_exc()}')
        raise e
    se = ServerEnv()
    se.engine = Klass(args.model_path)
    se.engine.use_mps_if_prosible()
    workdir = args.workdir or os.getcwd()
    port = args.port
    webserver(init, workdir, port)

if __name__ == '__main__':
    main()
17
llmengine.egg-info/PKG-INFO
Normal file
@@ -0,0 +1,17 @@
Metadata-Version: 2.4
Name: llmengine
Version: 0.0.1
Summary: Your project description
Author-email: yu moqing <yumoqing@gmail.com>
License: MIT
Requires-Python: >=3.8
Description-Content-Type: text/markdown
Requires-Dist: torch
Requires-Dist: transformers
Requires-Dist: sentence-transformers>=2.7.0
Requires-Dist: mistral-common
Requires-Dist: accelerate
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: black; extra == "dev"
Requires-Dist: mypy; extra == "dev"
26
llmengine.egg-info/SOURCES.txt
Normal file
@@ -0,0 +1,26 @@
README.md
pyproject.toml
llmengine/__init__.py
llmengine/ahserver.py
llmengine/base_chat_llm.py
llmengine/base_embedding.py
llmengine/base_entity.py
llmengine/base_reranker.py
llmengine/bge_reranker.py
llmengine/chatllm.py
llmengine/devstral.py
llmengine/embedding.py
llmengine/entity.py
llmengine/gemma3_it.py
llmengine/ltpentity.py
llmengine/medgemma3_it.py
llmengine/qwen3.py
llmengine/qwen3_reranker.py
llmengine/qwen3embedding.py
llmengine/rerank.py
llmengine/server.py
llmengine.egg-info/PKG-INFO
llmengine.egg-info/SOURCES.txt
llmengine.egg-info/dependency_links.txt
llmengine.egg-info/requires.txt
llmengine.egg-info/top_level.txt
1
llmengine.egg-info/dependency_links.txt
Normal file
@@ -0,0 +1 @@

10
llmengine.egg-info/requires.txt
Normal file
@@ -0,0 +1,10 @@
torch
transformers
sentence-transformers>=2.7.0
mistral-common
accelerate

[dev]
pytest
black
mypy
1
llmengine.egg-info/top_level.txt
Normal file
@@ -0,0 +1 @@
llmengine
0
llmengine/__init__.py
Normal file
BIN
llmengine/__pycache__/ahserver.cpython-310.pyc
Normal file
Binary file not shown.
BIN
llmengine/__pycache__/base_entity.cpython-310.pyc
Normal file
Binary file not shown.
BIN
llmengine/__pycache__/entity.cpython-310.pyc
Normal file
Binary file not shown.
BIN
llmengine/__pycache__/ltpentity.cpython-310.pyc
Normal file
Binary file not shown.
5
llmengine/ahserver.py
Normal file
@@ -0,0 +1,5 @@
from ahserver.configuredServer import ConfiguredServer

if __name__ == '__main__':
    server = ConfiguredServer()
    server.run()
23
llmengine/base_entity.py
Normal file
@@ -0,0 +1,23 @@
from abc import ABC, abstractmethod
from typing import List

model_pathMap = {}

def ltp_register(model_key, Klass):
    """Register a model class for a given model key."""
    global model_pathMap
    model_pathMap[model_key] = Klass

def get_ltp_class(model_path):
    """Find the model class for a given model path."""
    for k, klass in model_pathMap.items():
        if len(model_path.split(k)) > 1:
            return klass
    print(f'{model_pathMap=}')
    return None

class BaseLtp(ABC):
    @abstractmethod
    def extract_entities(self, query: str) -> List[str]:
        """Extract entities from query text."""
        pass
87
llmengine/entity.py
Normal file
@@ -0,0 +1,87 @@
from traceback import format_exc
import os
import sys
import argparse
from llmengine.ltpentity import *
from llmengine.base_entity import get_ltp_class
from typing import List

from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
from appPublic.log import debug, exception
from ahserver.serverenv import ServerEnv
from ahserver.globalEnv import stream_response
from ahserver.webapp import webserver

from aiohttp_session import get_session

helptext = """LTP Entities API:

1. Entities Endpoint:
path: /v1/entities
headers: {
    "Content-Type": "application/json"
}
data: {
    "query": "苹果公司在北京开设新店"
}
response: {
    "object": "list",
    "data": [
        "苹果公司",
        "北京",
        "新店",
        "开设",
        ...
    ]
}

2. Docs Endpoint:
path: /v1/docs
response: This help text
"""

def init():
    rf = RegisterFunction()
    rf.register('entities', entities)
    rf.register('docs', docs)

async def docs(request, params_kw, *params, **kw):
    return helptext

async def entities(request, params_kw, *params, **kw):
    debug(f'{params_kw.query=}')
    se = ServerEnv()
    engine = se.engine
    f = awaitify(engine.extract_entities)
    query = params_kw.query
    if query is None:
        e = Exception('query is None')
        exception(f'{e}')
        raise e
    entities = await f(query)
    debug(f'{entities=}, {type(entities)=}')
    return {
        "object": "list",
        "data": entities
    }

def main():
    parser = argparse.ArgumentParser(prog="LTP Entity Service")
    parser.add_argument('-w', '--workdir')
    parser.add_argument('-p', '--port')
    parser.add_argument('model_path')
    args = parser.parse_args()
    Klass = get_ltp_class(args.model_path)
    if Klass is None:
        e = Exception(f'{args.model_path} has no mapping to a model class')
        exception(f'{e}, {format_exc()}')
        raise e
    se = ServerEnv()
    se.engine = Klass(args.model_path)
    workdir = args.workdir or os.getcwd()
    port = args.port
    debug(f'{args=}')
    webserver(init, workdir, port)

if __name__ == '__main__':
    main()
76
llmengine/ltpentity.py
Normal file
@@ -0,0 +1,76 @@
# Requires ltp>=0.2.0

from ltp import LTP
from typing import List
import logging
from llmengine.base_entity import BaseLtp, ltp_register

logger = logging.getLogger(__name__)

class LtpEntity(BaseLtp):
    def __init__(self, model_id):
        # Load LTP model for CWS, POS, and NER
        self.ltp = LTP(model_id)
        self.model_id = model_id
        self.model_name = model_id.split('/')[-1]

    def extract_entities(self, query: str) -> List[str]:
        """
        Extract entities from the query text, including:
        - entities recognized by LTP NER (all types);
        - words tagged as nouns ('n') by LTP POS;
        - words tagged as verbs ('v') by LTP POS;
        - consecutive nouns merged into one entity (e.g. '苹果 公司' -> '苹果公司'), with the sub-words removed.
        """
        try:
            if not query:
                raise ValueError("查询文本不能为空")

            result = self.ltp.pipeline([query], tasks=["cws", "pos", "ner"])
            words = result.cws[0]
            pos_list = result.pos[0]
            ner = result.ner[0]

            entities = []
            subword_set = set()

            logger.debug(f"NER 结果: {ner}")
            for entity_type, entity, start, end in ner:
                entities.append(entity)

            combined = ""
            combined_words = []
            for i in range(len(words)):
                if pos_list[i] == 'n':
                    combined += words[i]
                    combined_words.append(words[i])
                    if i + 1 < len(words) and pos_list[i + 1] == 'n':
                        continue
                    if combined:
                        entities.append(combined)
                        subword_set.update(combined_words)
                        logger.debug(f"合并连续名词: {combined}, 子词: {combined_words}")
                    combined = ""
                    combined_words = []
                else:
                    combined = ""
                    combined_words = []
            logger.debug(f"连续名词子词集合: {subword_set}")

            for word, pos in zip(words, pos_list):
                if pos == 'n' and word not in subword_set:
                    entities.append(word)

            for word, pos in zip(words, pos_list):
                if pos == 'v':
                    entities.append(word)

            unique_entities = list(dict.fromkeys(entities))
            logger.info(f"从查询中提取到 {len(unique_entities)} 个唯一实体: {unique_entities}")
            return unique_entities

        except Exception as e:
            logger.error(f"实体抽取失败: {str(e)}")
            return []

ltp_register('LTP', LtpEntity)
50
test/entities/conf/config.json
Normal file
@@ -0,0 +1,50 @@
{
    "filesroot":"$[workdir]$/files",
    "logger":{
        "name":"llmengine",
        "levelname":"info",
        "logfile":"$[workdir]$/logs/llmengine.log"
    },
    "website":{
        "paths":[
            ["$[workdir]$/wwwroot",""]
        ],
        "client_max_size":10000,
        "host":"0.0.0.0",
        "port":9990,
        "coding":"utf-8",
        "indexes":[
            "index.html",
            "index.ui"
        ],
        "startswiths":[
            {
                "leading":"/idfile",
                "registerfunction":"idfile"
            },{
                "leading": "/v1/entities",
                "registerfunction": "entities"
            },{
                "leading": "/docs",
                "registerfunction": "docs"
            }
        ],
        "processors":[
            [".tmpl","tmpl"],
            [".app","app"],
            [".ui","bui"],
            [".dspy","dspy"],
            [".md","md"]
        ],
        "rsakey_oops":{
            "privatekey":"$[workdir]$/conf/rsa_private_key.pem",
            "publickey":"$[workdir]$/conf/rsa_public_key.pem"
        },
        "session_max_time":3000,
        "session_issue_time":2500,
        "session_redis_notuse":{
            "url":"redis://127.0.0.1:6379"
        }
    }
}
test/entities/logs/llmengine.log
Normal file
0
test/entities/logs/llmengine.log
Normal file
3
test/entities/start.sh
Executable file
3
test/entities/start.sh
Executable file
@ -0,0 +1,3 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
/share/vllm-0.8.5/bin/python -m llmengine.entity -p 9990 /share/models/LTP/small
|
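start.sh launches the entity service on port 9990, matching test/entities/conf/config.json. A minimal client sketch against that configuration:

import requests

r = requests.post('http://127.0.0.1:9990/v1/entities',
                  headers={'Content-Type': 'application/json'},
                  json={'query': '苹果公司在北京开设新店'})
print(r.json())  # expected shape: {"object": "list", "data": ["苹果公司", "北京", ...]}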