first commit
This commit is contained in:
commit
7376c939ee
64
chatllm.py
Normal file
64
chatllm.py
Normal file
@ -0,0 +1,64 @@
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
||||
import torch
|
||||
from threading import Thread
|
||||
|
||||
class TransformersChatEngine:
|
||||
def __init__(self, model_name: str, device: str = None, fp16: bool = True, gpus: int = 1):
|
||||
"""
|
||||
通用大模型加载器,支持 GPU 数量与编号控制
|
||||
:param model_name: 模型名称或路径
|
||||
:param device: 指定设备如 "cuda:0",默认自动选择
|
||||
:param fp16: 是否使用 fp16 精度(适用于支持的 GPU)
|
||||
:param gpus: 使用的 GPU 数量,1 表示单卡,>1 表示多卡推理(使用 device_map='auto')
|
||||
"""
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.is_multi_gpu = gpus > 1 and torch.cuda.device_count() >= gpus
|
||||
|
||||
print(f"✅ Using device: {self.device}, GPUs: {gpus}, Multi-GPU: {self.is_multi_gpu}")
|
||||
|
||||
# Tokenizer 加载
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
|
||||
# 模型加载
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if fp16 and "cuda" in self.device else torch.float32,
|
||||
device_map="auto" if self.is_multi_gpu else None
|
||||
)
|
||||
|
||||
if not self.is_multi_gpu:
|
||||
self.model.to(self.device)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7, stop: str = None) -> str:
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
eos_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
return output_text[len(prompt):] if output_text.startswith(prompt) else output_text
|
||||
|
||||
def stream_generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7):
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
generation_kwargs = dict(
|
||||
**inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=max_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
eos_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
for new_text in streamer:
|
||||
yield new_text
|
||||
|
21
pyproject.toml
Normal file
21
pyproject.toml
Normal file
@ -0,0 +1,21 @@
|
||||
[project]
|
||||
name="llmengine"
|
||||
version = "0.0.1"
|
||||
description = "Your project description"
|
||||
authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
license = {text = "MIT"}
|
||||
dependencies = [
|
||||
"torch",
|
||||
"tramsformers",
|
||||
"acelerate"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest", "black", "mypy"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
Loading…
Reference in New Issue
Block a user