bugfix
This commit is contained in:
parent
f468757690
commit
42c8b12b17
28
callbacks.py
Normal file
28
callbacks.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
|
||||||
|
from transformers import TrainerCallback, EarlyStoppingCallback
|
||||||
|
import os
|
||||||
|
|
||||||
|
class SaveBestModelCallback(TrainerCallback):
|
||||||
|
def __init__(self):
|
||||||
|
self.best_metric = None
|
||||||
|
|
||||||
|
def on_evaluate(self, args, state, control, **kwargs):
|
||||||
|
metrics = kwargs.get("metrics", {})
|
||||||
|
eval_loss = metrics.get("eval_loss")
|
||||||
|
if eval_loss is None:
|
||||||
|
return
|
||||||
|
if self.best_metric is None or eval_loss < self.best_metric:
|
||||||
|
print(f"🌟 Best model updated: {self.best_metric} -> {eval_loss}")
|
||||||
|
self.best_metric = eval_loss
|
||||||
|
best_model_path = os.path.join(args.output_dir, "best_model")
|
||||||
|
kwargs["model"].save_pretrained(best_model_path)
|
||||||
|
kwargs["tokenizer"].save_pretrained(best_model_path)
|
||||||
|
|
||||||
|
def build_callbacks(config):
|
||||||
|
return [
|
||||||
|
EarlyStoppingCallback(
|
||||||
|
early_stopping_patience=config.early_stopping_patience,
|
||||||
|
early_stopping_threshold=config.early_stopping_threshold,
|
||||||
|
),
|
||||||
|
SaveBestModelCallback(),
|
||||||
|
]
|
30
config.yaml
Normal file
30
config.yaml
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
|
||||||
|
model_name_or_path: "your-model-name"
|
||||||
|
output_dir: "./output"
|
||||||
|
dataset_path: "ds_file_path.json"
|
||||||
|
validation_split_ratio: 0.1
|
||||||
|
per_device_train_batch_size: 4
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
num_train_epochs: 3
|
||||||
|
learning_rate: 1e-5
|
||||||
|
lr_scheduler_type: "linear"
|
||||||
|
warmup_steps: 500
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.1
|
||||||
|
target_modules: ["qkv", "query", "key", "value"]
|
||||||
|
quantization_enable: true
|
||||||
|
load_in_4bit: true
|
||||||
|
bnb_4bit_compute_dtype: "float16"
|
||||||
|
bnb_4bit_quant_type: "nf4"
|
||||||
|
bnb_4bit_use_double_quant: true
|
||||||
|
logging_steps: 200
|
||||||
|
save_steps: 1000
|
||||||
|
evaluation_strategy: "steps"
|
||||||
|
eval_steps: 500
|
||||||
|
save_strategy: "steps"
|
||||||
|
save_total_limit: 5
|
||||||
|
max_seq_length: 512
|
||||||
|
early_stopping_patience: 3
|
||||||
|
early_stopping_threshold: 0.001
|
||||||
|
deepspeed: "ds_config_zero2.json"
|
41
configs.py
Normal file
41
configs.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
|
||||||
|
import yaml
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
model_name_or_path: str
|
||||||
|
output_dir: str
|
||||||
|
dataset_path: str
|
||||||
|
validation_split_ratio: float
|
||||||
|
per_device_train_batch_size: int
|
||||||
|
gradient_accumulation_steps: int
|
||||||
|
num_train_epochs: int
|
||||||
|
learning_rate: float
|
||||||
|
lr_scheduler_type: str
|
||||||
|
warmup_steps: int
|
||||||
|
lora_r: int
|
||||||
|
lora_alpha: int
|
||||||
|
lora_dropout: float
|
||||||
|
target_modules: list
|
||||||
|
quantization_enable: bool
|
||||||
|
load_in_4bit: bool
|
||||||
|
bnb_4bit_compute_dtype: str
|
||||||
|
bnb_4bit_quant_type: str
|
||||||
|
bnb_4bit_use_double_quant: bool
|
||||||
|
logging_steps: int
|
||||||
|
save_steps: int
|
||||||
|
evaluation_strategy: str
|
||||||
|
eval_steps: int
|
||||||
|
save_strategy: str
|
||||||
|
save_total_limit: int
|
||||||
|
max_seq_length: int
|
||||||
|
early_stopping_patience: int
|
||||||
|
early_stopping_threshold: float
|
||||||
|
deepspeed: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, file_path):
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
config_dict = yaml.safe_load(f)
|
||||||
|
return cls(**config_dict)
|
21
dataset.py
Normal file
21
dataset.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
class DatasetManager:
|
||||||
|
def __init__(self, config, tokenizer):
|
||||||
|
self.config = config
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def preprocess_function(self, examples, max_length):
|
||||||
|
return self.tokenizer(examples["input"], max_length=max_length, truncation=True)
|
||||||
|
|
||||||
|
def load_data(self):
|
||||||
|
dataset = load_dataset("json", data_files=self.config.dataset_path, split="train")
|
||||||
|
dataset = dataset.shuffle(seed=42)
|
||||||
|
split_idx = int(len(dataset) * (1.0 - self.config.validation_split_ratio))
|
||||||
|
train_dataset = dataset.select(range(split_idx))
|
||||||
|
eval_dataset = dataset.select(range(split_idx, len(dataset)))
|
||||||
|
|
||||||
|
train_dataset = train_dataset.map(lambda x: self.preprocess_function(x, self.config.max_seq_length), batched=True)
|
||||||
|
eval_dataset = eval_dataset.map(lambda x: self.preprocess_function(x, self.config.max_seq_length), batched=True)
|
||||||
|
return train_dataset, eval_dataset
|
51
ds_config_zero2.json
Normal file
51
ds_config_zero2.json
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
|
||||||
|
{
|
||||||
|
"train_batch_size": 32,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"offload_param": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"overlap_comm": true,
|
||||||
|
"allgather_partitions": true,
|
||||||
|
"reduce_scatter": true
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": true,
|
||||||
|
"loss_scale": 0
|
||||||
|
},
|
||||||
|
"activation_checkpointing": {
|
||||||
|
"partition_activations": true,
|
||||||
|
"contiguous_memory_optimization": true,
|
||||||
|
"num_checkpointed_layers": 8
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": 1e-5,
|
||||||
|
"betas": [0.9, 0.999],
|
||||||
|
"eps": 1e-8
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": 0,
|
||||||
|
"warmup_max_lr": 1e-5,
|
||||||
|
"warmup_num_steps": 500
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"logging": {
|
||||||
|
"steps": 200
|
||||||
|
},
|
||||||
|
"checkpoint": {
|
||||||
|
"steps": 1000,
|
||||||
|
"save": true
|
||||||
|
}
|
||||||
|
}
|
23
main.py
Normal file
23
main.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
from configs import Config
|
||||||
|
from model import ModelManager
|
||||||
|
from dataset import DatasetManager
|
||||||
|
from trainer import TrainerManager
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = Config.from_yaml("config.yaml")
|
||||||
|
|
||||||
|
model_manager = ModelManager(config)
|
||||||
|
model, tokenizer = model_manager.load_model_and_tokenizer()
|
||||||
|
|
||||||
|
dataset_manager = DatasetManager(config, tokenizer)
|
||||||
|
train_dataset, eval_dataset = dataset_manager.load_data()
|
||||||
|
|
||||||
|
trainer_manager = TrainerManager(config, model, tokenizer, train_dataset, eval_dataset)
|
||||||
|
trainer = trainer_manager.create_trainer()
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
trainer.save_model(config.output_dir)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
44
model.py
Normal file
44
model.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
|
||||||
|
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class ModelManager:
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def load_model_and_tokenizer(self):
|
||||||
|
model_kwargs = {}
|
||||||
|
if self.config.quantization_enable:
|
||||||
|
bnb_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=self.config.load_in_4bit,
|
||||||
|
bnb_4bit_compute_dtype=getattr(torch, self.config.bnb_4bit_compute_dtype),
|
||||||
|
bnb_4bit_quant_type=self.config.bnb_4bit_quant_type,
|
||||||
|
bnb_4bit_use_double_quant=self.config.bnb_4bit_use_double_quant,
|
||||||
|
)
|
||||||
|
model_kwargs["quantization_config"] = bnb_config
|
||||||
|
print("✅ 使用量化加载模型")
|
||||||
|
else:
|
||||||
|
print("🚀 使用全精度加载模型")
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.config.model_name_or_path,
|
||||||
|
device_map="auto",
|
||||||
|
**model_kwargs
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.config.model_name_or_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
if self.config.quantization_enable:
|
||||||
|
model = prepare_model_for_kbit_training(model)
|
||||||
|
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
r=self.config.lora_r,
|
||||||
|
lora_alpha=self.config.lora_alpha,
|
||||||
|
lora_dropout=self.config.lora_dropout,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
target_modules=self.config.target_modules,
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
|
return model, tokenizer
|
10
requirements.txt
Normal file
10
requirements.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
transformers==4.28.0
|
||||||
|
torch==2.0.0
|
||||||
|
deepspeed==0.9.0
|
||||||
|
datasets==2.10.1
|
||||||
|
peft==0.1.0
|
||||||
|
bitsandbytes==0.39.0
|
||||||
|
accelerate==0.18.0
|
||||||
|
scipy==1.10.0
|
||||||
|
yaml==6.0
|
47
trainer.py
Normal file
47
trainer.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
|
||||||
|
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
class TrainerManager:
|
||||||
|
def __init__(self, config, model, tokenizer, train_dataset, eval_dataset):
|
||||||
|
self.config = config
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.train_dataset = train_dataset
|
||||||
|
self.eval_dataset = eval_dataset
|
||||||
|
|
||||||
|
def create_trainer(self):
|
||||||
|
args = TrainingArguments(
|
||||||
|
output_dir=self.config.output_dir,
|
||||||
|
per_device_train_batch_size=self.config.per_device_train_batch_size,
|
||||||
|
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
|
||||||
|
num_train_epochs=self.config.num_train_epochs,
|
||||||
|
learning_rate=self.config.learning_rate,
|
||||||
|
lr_scheduler_type=self.config.lr_scheduler_type,
|
||||||
|
warmup_steps=self.config.warmup_steps,
|
||||||
|
logging_steps=self.config.logging_steps,
|
||||||
|
save_steps=self.config.save_steps,
|
||||||
|
evaluation_strategy=self.config.evaluation_strategy,
|
||||||
|
eval_steps=self.config.eval_steps,
|
||||||
|
save_strategy=self.config.save_strategy,
|
||||||
|
save_total_limit=self.config.save_total_limit,
|
||||||
|
bf16=True,
|
||||||
|
report_to="none",
|
||||||
|
remove_unused_columns=False,
|
||||||
|
deepspeed=self.config.deepspeed,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="eval_loss",
|
||||||
|
greater_is_better=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
collator = DataCollatorForSeq2Seq(self.tokenizer, model=self.model, padding=True)
|
||||||
|
from callbacks import build_callbacks
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=self.model,
|
||||||
|
args=args,
|
||||||
|
train_dataset=self.train_dataset,
|
||||||
|
eval_dataset=self.eval_dataset,
|
||||||
|
data_collator=collator,
|
||||||
|
callbacks=build_callbacks(self.config),
|
||||||
|
)
|
||||||
|
return trainer
|
Loading…
Reference in New Issue
Block a user