qlora-train/callbacks.py
2025-04-28 08:50:31 +00:00

29 lines
1.0 KiB
Python

from transformers import TrainerCallback, EarlyStoppingCallback
import os
class SaveBestModelCallback(TrainerCallback):
    """Save a copy of the model whenever an evaluation metric improves.

    After each evaluation, the metric named ``metric_name`` (``eval_loss`` by
    default) is compared with the best value seen so far. On improvement, the
    model and, if one is available, its tokenizer/processing class are saved
    to ``<args.output_dir>/best_model``. Each new best overwrites the previous
    copy.

    Args:
        metric_name: Key in the evaluation metrics dict to track.
        greater_is_better: If True, larger metric values are better.
            Otherwise smaller values are better, which is the default and
            suits losses.
    """

    def __init__(self, metric_name="eval_loss", greater_is_better=False):
        self.metric_name = metric_name
        self.greater_is_better = greater_is_better
        # Best metric value seen so far; None until the first evaluation
        # that reports `metric_name`.
        self.best_metric = None

    def _is_improvement(self, value):
        """Return True if ``value`` beats the best metric recorded so far."""
        if self.best_metric is None:
            return True
        if self.greater_is_better:
            return value > self.best_metric
        return value < self.best_metric

    def on_evaluate(self, args, state, control, **kwargs):
        """Record the tracked metric and save the model if it improved.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; ``is_world_process_zero`` gates disk writes.
                If the attribute is missing, this process is treated as the
                main process.
            control: Trainer control object; unused.
            **kwargs: Callback payload. Reads ``metrics`` and ``model``, plus
                ``processing_class`` (newer transformers) or ``tokenizer``
                (older transformers) for the tokenizer to save.
        """
        # `metrics` may be missing or explicitly None.
        metrics = kwargs.get("metrics") or {}
        value = metrics.get(self.metric_name)
        if value is None or not self._is_improvement(value):
            return

        previous = self.best_metric
        # Update on every process so all ranks agree on the current best.
        self.best_metric = value

        # Callbacks run on every process in distributed training. Only the
        # main process writes, so ranks do not race on the same directory.
        if not getattr(state, "is_world_process_zero", True):
            return

        print(f"🌟 Best model updated: {previous} -> {value}")
        best_model_path = os.path.join(args.output_dir, "best_model")
        # NOTE(review): for a PEFT/QLoRA-wrapped model this presumably saves
        # only the adapter weights, not the base model -- confirm.
        kwargs["model"].save_pretrained(best_model_path)

        # Newer transformers releases pass the tokenizer as `processing_class`
        # rather than `tokenizer`. Either may be absent or None (e.g. no
        # tokenizer was given to the Trainer), so look both up defensively.
        processor = kwargs.get("processing_class")
        if processor is None:
            processor = kwargs.get("tokenizer")
        if processor is not None:
            processor.save_pretrained(best_model_path)
def build_callbacks(config):
    """Create the list of Trainer callbacks for a training run.

    Args:
        config: Any object exposing ``early_stopping_patience`` and
            ``early_stopping_threshold`` attributes. They are forwarded
            unchanged to ``EarlyStoppingCallback``.

    Returns:
        A two-element list: the early-stopping callback, followed by a
        ``SaveBestModelCallback`` built with its default settings.
    """
    patience = config.early_stopping_patience
    threshold = config.early_stopping_threshold
    stopper = EarlyStoppingCallback(
        early_stopping_patience=patience,
        early_stopping_threshold=threshold,
    )
    saver = SaveBestModelCallback()
    return [stopper, saver]