65 lines
2.6 KiB
Python
65 lines
2.6 KiB
Python
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
||
import torch
|
||
from threading import Thread
|
||
|
||
class TransformersChatEngine:
    """Thin wrapper around a Hugging Face causal LM for blocking and streaming generation.

    Loads a tokenizer and an ``AutoModelForCausalLM`` once. It then exposes
    :meth:`generate`, which returns the whole completion, and
    :meth:`stream_generate`, which yields text chunks as they are produced.
    """

    def __init__(self, model_name: str, device: str = None, fp16: bool = True, gpus: int = 1):
        """Load the tokenizer and model.

        :param model_name: Hugging Face model id or local path.
        :param device: Target device such as ``"cuda:0"``. Defaults to ``"cuda"`` when
            CUDA is available, otherwise ``"cpu"``.
        :param fp16: Load weights in float16 when ``device`` names a CUDA device.
            Otherwise float32 is used.
        :param gpus: Number of GPUs requested. When > 1 and at least that many GPUs
            are visible, the model is loaded with ``device_map="auto"``. Otherwise the
            whole model is moved to ``device``.

        Note: ``gpus`` is only a threshold. ``device_map="auto"`` may spread the model
        across *all* visible GPUs, not just the first ``gpus``. Restrict visibility
        (e.g. ``CUDA_VISIBLE_DEVICES``) to pin specific cards.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Shard only if enough GPUs are actually visible; otherwise quietly fall back
        # to single-device placement (the print below reports which mode was chosen).
        self.is_multi_gpu = gpus > 1 and torch.cuda.device_count() >= gpus

        print(f"✅ Using device: {self.device}, GPUs: {gpus}, Multi-GPU: {self.is_multi_gpu}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if fp16 and "cuda" in self.device else torch.float32,
            # In multi-GPU mode accelerate places the layers. A single-device model
            # is moved explicitly below instead.
            device_map="auto" if self.is_multi_gpu else None,
        )

        if not self.is_multi_gpu:
            self.model.to(self.device)

        self.model.eval()

    def _generation_kwargs(self, max_tokens, temperature):
        """Build the decoding arguments shared by ``generate`` and ``stream_generate``.

        :param max_tokens: Maximum number of new tokens.
        :param temperature: Sampling temperature. ``<= 0`` selects greedy decoding,
            because transformers rejects a non-positive temperature when sampling.
        :return: Keyword arguments for ``model.generate``.
        """
        # transformers' temperature warper insists on a float, so e.g. ``2`` would raise.
        temperature = float(temperature)
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id
        kwargs = dict(
            max_new_tokens=max_tokens,
            eos_token_id=eos_id,
            # Many causal LMs ship without a pad token. Falling back to EOS is what
            # transformers would do anyway, minus its per-call warning.
            pad_token_id=pad_id if pad_id is not None else eos_id,
        )
        if temperature > 0:
            kwargs.update(do_sample=True, temperature=temperature)
        else:
            kwargs["do_sample"] = False
        return kwargs

    @staticmethod
    def _truncate_at_stop(text, stop):
        """Cut ``text`` at the earliest occurrence of any stop string.

        :param text: Text to truncate.
        :param stop: ``None``, a single string, or an iterable of strings. Empty stop
            strings are ignored.
        :return: ``text`` up to (not including) the first match, or unchanged if none.
        """
        if not stop:
            return text
        candidates = [stop] if isinstance(stop, str) else list(stop)
        hits = [pos for pos in (text.find(s) for s in candidates if s) if pos >= 0]
        return text[: min(hits)] if hits else text

    def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7,
                 stop: "str | list[str] | None" = None) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        :param prompt: Raw prompt text, passed straight to the tokenizer (no chat
            template is applied here).
        :param max_tokens: Maximum number of new tokens to generate.
        :param temperature: Sampling temperature. ``<= 0`` means greedy decoding.
        :param stop: Optional stop string, or list of strings. The returned text is
            cut at the earliest occurrence. Generation itself still runs until EOS or
            ``max_tokens``; the stop sequence is applied afterwards.
        :return: The completion, without the prompt and with special tokens removed.
        """
        # Send inputs to the model's own device. Under device_map="auto" that is where
        # its first parameters live, which may differ from self.device.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output_ids = self.model.generate(**inputs, **self._generation_kwargs(max_tokens, temperature))
        # The returned sequence starts with the prompt tokens. Drop them by count
        # rather than stripping the decoded prompt string, because decoding does not
        # always reproduce the prompt exactly (special tokens, whitespace cleanup).
        prompt_len = len(inputs["input_ids"][0])
        completion = self.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        return self._truncate_at_stop(completion, stop)

    def stream_generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7):
        """Yield the completion for ``prompt`` incrementally as text chunks.

        Generation runs in a background thread that feeds a ``TextIteratorStreamer``,
        and this generator drains it. If generation raises, the stream is closed and
        the exception is re-raised here after any already-produced text is yielded.

        Note: if the caller stops iterating early, the background thread keeps
        generating until EOS or ``max_tokens``.

        :param prompt: Raw prompt text.
        :param max_tokens: Maximum number of new tokens to generate.
        :param temperature: Sampling temperature. ``<= 0`` means greedy decoding.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            **self._generation_kwargs(max_tokens, temperature),
        )

        failure = []

        def _run():
            # Runs model.generate in the worker thread and records any error.
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:
                failure.append(exc)
                # The streamer only stops iterating on its end signal. Without this,
                # the consumer loop below would block forever after a crash.
                streamer.end()

        thread = Thread(target=_run)
        thread.start()

        for new_text in streamer:
            yield new_text

        thread.join()
        if failure:
            raise failure[0]