29 lines
1.0 KiB
Python
29 lines
1.0 KiB
Python
|
|
from transformers import TrainerCallback, EarlyStoppingCallback
|
|
import os
|
|
|
|
class SaveBestModelCallback(TrainerCallback):
    """Trainer callback that snapshots the best model seen so far.

    After every evaluation, if ``eval_loss`` improved on the best value
    recorded, the current model (and tokenizer, when the Trainer provides
    one) is saved to ``<output_dir>/best_model``.
    """

    def __init__(self):
        # Lowest eval_loss observed so far; None until the first evaluation.
        self.best_metric = None

    def on_evaluate(self, args, state, control, **kwargs):
        """Save the model to ``<output_dir>/best_model`` when eval_loss improves.

        Args:
            args: ``TrainingArguments``; only ``output_dir`` is read.
            state: ``TrainerState``; used to restrict disk writes to the
                main process in distributed runs.
            control: ``TrainerControl``; not modified.
            **kwargs: Trainer-supplied objects. Reads ``metrics`` and
                ``model``, plus ``tokenizer`` or ``processing_class``
                when present.
        """
        metrics = kwargs.get("metrics", {})
        eval_loss = metrics.get("eval_loss")
        if eval_loss is None:
            # Evaluation reported no loss (e.g. metrics-only eval); nothing to compare.
            return

        if self.best_metric is None or eval_loss < self.best_metric:
            # Print old -> new before overwriting the record.
            print(f"🌟 Best model updated: {self.best_metric} -> {eval_loss}")
            self.best_metric = eval_loss

            # In distributed training only the main process should write to
            # disk; other ranks still update best_metric so state stays in sync.
            if not state.is_world_process_zero:
                return

            best_model_path = os.path.join(args.output_dir, "best_model")
            model = kwargs.get("model")
            if model is not None:
                model.save_pretrained(best_model_path)

            # Newer transformers releases pass "processing_class" instead of
            # "tokenizer"; the original kwargs["tokenizer"] raised KeyError there.
            tokenizer = kwargs.get("tokenizer") or kwargs.get("processing_class")
            if tokenizer is not None:
                tokenizer.save_pretrained(best_model_path)
|
|
|
|
def build_callbacks(config):
    """Assemble the Trainer callback list driven by *config*.

    Args:
        config: Object exposing ``early_stopping_patience`` and
            ``early_stopping_threshold``.

    Returns:
        list: An ``EarlyStoppingCallback`` configured from *config*,
        followed by a ``SaveBestModelCallback``.
    """
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=config.early_stopping_patience,
        early_stopping_threshold=config.early_stopping_threshold,
    )
    return [early_stopping, SaveBestModelCallback()]
|