llmengine/chatllm.py
2025-05-29 22:17:12 +08:00

65 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
class TransformersChatEngine:
    """Load a Hugging Face causal language model and generate text from prompts.

    Provides a blocking ``generate`` call that returns the completion as one
    string and a ``stream_generate`` generator that yields text pieces while
    generation runs in a background thread.
    """

    def __init__(self, model_name: str, device: str = None, fp16: bool = True, gpus: int = 1):
        """Load the tokenizer and model and place the model on the target device(s).

        :param model_name: Model name or path, passed to ``from_pretrained``.
        :param device: Explicit device such as "cuda:0". When omitted, "cuda"
            is used if CUDA is available, otherwise "cpu".
        :param fp16: Load the weights as float16 when the device is a CUDA
            device; float32 is used in every other case.
        :param gpus: Number of GPUs to use. 1 means single-device inference;
            a value > 1 enables multi-GPU inference (``device_map="auto"``),
            but only when at least that many GPUs are visible.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Multi-GPU mode is silently disabled when fewer GPUs exist than requested.
        self.is_multi_gpu = gpus > 1 and torch.cuda.device_count() >= gpus
        print(f"✅ Using device: {self.device}, GPUs: {gpus}, Multi-GPU: {self.is_multi_gpu}")

        # Tokenizer loading.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        # Model loading. str() lets a torch.device object be passed as well as a string.
        use_half = fp16 and "cuda" in str(self.device)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_half else torch.float32,
            device_map="auto" if self.is_multi_gpu else None,
        )
        # With device_map="auto" the weights are already placed at load time,
        # so the explicit move is only done for the single-device case.
        if not self.is_multi_gpu:
            self.model.to(self.device)
        self.model.eval()

    def _prepare_inputs(self, prompt: str):
        """Tokenize *prompt* and move the resulting tensors to the model's input device."""
        # NOTE(review): in multi-GPU mode the inputs follow ``model.device``
        # rather than the user-supplied device string, which may not be where
        # the first layers were placed — confirm on a multi-GPU host.
        target = self.model.device if self.is_multi_gpu else self.device
        return self.tokenizer(prompt, return_tensors="pt").to(target)

    def _generation_kwargs(self, max_tokens: int, temperature: float) -> dict:
        """Build the decoding options shared by ``generate`` and ``stream_generate``."""
        options = {
            "max_new_tokens": max_tokens,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if temperature is not None and temperature > 0:
            options["do_sample"] = True
            options["temperature"] = temperature
        else:
            # Sampling needs a strictly positive temperature; treat 0 (or a
            # negative value) as a request for greedy decoding instead.
            options["do_sample"] = False
        return options

    def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7, stop: str = None) -> str:
        """Generate a completion for *prompt* and return only the new text.

        :param prompt: Text to continue.
        :param max_tokens: Maximum number of newly generated tokens.
        :param temperature: Sampling temperature; a value <= 0 selects greedy
            decoding.
        :param stop: Optional stop string. When it occurs in the completion,
            the completion is cut right before its first occurrence.
        :return: The generated text without the prompt.
        """
        inputs = self._prepare_inputs(prompt)
        output_ids = self.model.generate(
            **inputs,
            **self._generation_kwargs(max_tokens, temperature),
        )
        # The returned sequence starts with the prompt tokens. Slice them off by
        # length instead of comparing decoded strings: decoding does not always
        # reproduce the prompt verbatim (e.g. special tokens are dropped).
        prompt_length = inputs["input_ids"].shape[-1]
        completion = self.tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
        if stop:
            # partition() returns the whole string when ``stop`` is absent.
            completion = completion.partition(stop)[0]
        return completion

    def stream_generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7):
        """Yield pieces of the completion for *prompt* as they are generated.

        Generation runs in a background thread while this generator consumes
        the streamer. An exception raised by the model in that thread is
        re-raised here once the stream has ended.

        :param prompt: Text to continue.
        :param max_tokens: Maximum number of newly generated tokens.
        :param temperature: Sampling temperature; a value <= 0 selects greedy
            decoding.
        """
        inputs = self._prepare_inputs(prompt)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            **self._generation_kwargs(max_tokens, temperature),
        )
        failures = []

        def _run_generation():
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:
                failures.append(exc)
                # Without the end signal the consumer loop below would block
                # forever waiting for text that never arrives.
                streamer.end()

        thread = Thread(target=_run_generation)
        thread.start()
        for new_text in streamer:
            yield new_text
        thread.join()
        if failures:
            raise failures[0]