first commit
This commit is contained in:
commit
7376c939ee
64
chatllm.py
Normal file
64
chatllm.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
||||||
|
import torch
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
class TransformersChatEngine:
|
||||||
|
def __init__(self, model_name: str, device: str = None, fp16: bool = True, gpus: int = 1):
|
||||||
|
"""
|
||||||
|
通用大模型加载器,支持 GPU 数量与编号控制
|
||||||
|
:param model_name: 模型名称或路径
|
||||||
|
:param device: 指定设备如 "cuda:0",默认自动选择
|
||||||
|
:param fp16: 是否使用 fp16 精度(适用于支持的 GPU)
|
||||||
|
:param gpus: 使用的 GPU 数量,1 表示单卡,>1 表示多卡推理(使用 device_map='auto')
|
||||||
|
"""
|
||||||
|
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.is_multi_gpu = gpus > 1 and torch.cuda.device_count() >= gpus
|
||||||
|
|
||||||
|
print(f"✅ Using device: {self.device}, GPUs: {gpus}, Multi-GPU: {self.is_multi_gpu}")
|
||||||
|
|
||||||
|
# Tokenizer 加载
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
|
||||||
|
# 模型加载
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.float16 if fp16 and "cuda" in self.device else torch.float32,
|
||||||
|
device_map="auto" if self.is_multi_gpu else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.is_multi_gpu:
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7, stop: str = None) -> str:
|
||||||
|
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||||
|
output_ids = self.model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=temperature,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||||
|
return output_text[len(prompt):] if output_text.startswith(prompt) else output_text
|
||||||
|
|
||||||
|
def stream_generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7):
|
||||||
|
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
||||||
|
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
|
generation_kwargs = dict(
|
||||||
|
**inputs,
|
||||||
|
streamer=streamer,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=temperature,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for new_text in streamer:
|
||||||
|
yield new_text
|
||||||
|
|
21
pyproject.toml
Normal file
21
pyproject.toml
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
[project]
|
||||||
|
name="llmengine"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "Your project description"
|
||||||
|
authors = [{ name = "yu moqing", email = "yumoqing@gmail.com" }]
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
dependencies = [
|
||||||
|
"torch",
|
||||||
|
"tramsformers",
|
||||||
|
"acelerate"
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = ["pytest", "black", "mypy"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user