first commit

yumoqing 2025-05-29 22:17:12 +08:00
commit 7376c939ee
3 changed files with 85 additions and 0 deletions

README.md (new file, 0 additions)

chatllm.py (new file, 64 additions)

@@ -0,0 +1,64 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread


class TransformersChatEngine:
    def __init__(self, model_name: str, device: str = None, fp16: bool = True, gpus: int = 1):
        """
        Generic LLM loader with control over the number and placement of GPUs.

        :param model_name: model name or path
        :param device: target device, e.g. "cuda:0"; auto-selected when omitted
        :param fp16: use fp16 precision (on GPUs that support it)
        :param gpus: number of GPUs to use; 1 means single-GPU, >1 enables multi-GPU inference via device_map='auto'
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.is_multi_gpu = gpus > 1 and torch.cuda.device_count() >= gpus
        print(f"✅ Using device: {self.device}, GPUs: {gpus}, Multi-GPU: {self.is_multi_gpu}")

        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

        # Load the model; device_map='auto' shards it across GPUs in multi-GPU mode
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if fp16 and "cuda" in self.device else torch.float32,
            device_map="auto" if self.is_multi_gpu else None
        )
        if not self.is_multi_gpu:
            self.model.to(self.device)
        self.model.eval()

    def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7, stop: str = None) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            eos_token_id=self.tokenizer.eos_token_id
        )
        output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the completion is returned
        output_text = output_text[len(prompt):] if output_text.startswith(prompt) else output_text
        # Truncate at the stop string, if one was given
        if stop and stop in output_text:
            output_text = output_text.split(stop)[0]
        return output_text

    def stream_generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            eos_token_id=self.tokenizer.eos_token_id
        )
        # Run generation in a background thread and yield text chunks as they arrive
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        for new_text in streamer:
            yield new_text
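
A minimal usage sketch for the engine above; the model id "Qwen/Qwen2.5-0.5B-Instruct" is only an illustrative placeholder, any causal LM on the Hugging Face Hub or a local path should work:

# Usage sketch for TransformersChatEngine (model id is a placeholder, not part of this repo)
from chatllm import TransformersChatEngine

engine = TransformersChatEngine("Qwen/Qwen2.5-0.5B-Instruct", fp16=True, gpus=1)

# Blocking generation: returns the full completion as a string
print(engine.generate("Write a haiku about GPUs.", max_tokens=64))

# Streaming generation: chunks are yielded while the background thread generates
for chunk in engine.stream_generate("Explain tensor parallelism in one sentence.", max_tokens=64):
    print(chunk, end="", flush=True)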

pyproject.toml (new file, 21 additions)

@@ -0,0 +1,21 @@
[project]
name = "llmengine"
version = "0.0.1"
description = "Chat engine wrapper around Hugging Face Transformers LLMs"
authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
readme = "README.md"
requires-python = ">=3.8"
license = {text = "MIT"}
dependencies = [
    "torch",
    "transformers",
    "accelerate"
]

[project.optional-dependencies]
dev = ["pytest", "black", "mypy"]

[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
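
With this pyproject.toml, the package can be installed locally with `pip install -e .`, or with `pip install -e .[dev]` to pull in the development extras.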