# qlora-train/configs.py
# 2025-04-28 08:50:31 +00:00
#
# 42 lines
# 1010 B
# Python

import yaml
from dataclasses import dataclass
@dataclass
class Config:
    """Hyper-parameters and paths for a (Q)LoRA training run, loaded from YAML.

    Every field is required (none has a default), so the YAML document or
    mapping used to build a ``Config`` must provide a value for each one.

    ``@dataclass`` does not check field types at runtime: the annotations
    document the expected types, but values are stored as given. The one
    exception is that ``from_dict``/``from_yaml`` convert *string* values of
    ``float`` fields to ``float`` (see ``from_dict``).
    """

    # --- model / data ---
    model_name_or_path: str
    output_dir: str
    dataset_path: str
    validation_split_ratio: float  # presumably the held-out eval fraction — confirm in trainer

    # --- optimization ---
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    num_train_epochs: int
    learning_rate: float
    lr_scheduler_type: str
    warmup_steps: int

    # --- LoRA adapter ---
    lora_r: int
    lora_alpha: int
    lora_dropout: float
    target_modules: list  # NOTE(review): presumably a list of module-name strings — confirm

    # --- quantization (bitsandbytes-style option names) ---
    quantization_enable: bool
    load_in_4bit: bool
    bnb_4bit_compute_dtype: str  # a dtype *name*; conversion to a real dtype happens elsewhere, if at all
    bnb_4bit_quant_type: str
    bnb_4bit_use_double_quant: bool

    # --- logging / evaluation / checkpointing ---
    logging_steps: int
    save_steps: int
    evaluation_strategy: str
    eval_steps: int
    save_strategy: str
    save_total_limit: int
    max_seq_length: int

    # --- early stopping ---
    early_stopping_patience: int
    early_stopping_threshold: float

    # NOTE(review): likely a DeepSpeed config path; the key must be present,
    # but a YAML null is stored as None since types are not enforced.
    deepspeed: str

    @classmethod
    def from_dict(cls, config_dict, source="<dict>"):
        """Build a ``Config`` from an already-parsed mapping.

        Args:
            config_dict: Mapping of field name -> value, e.g. the result of
                ``yaml.safe_load``. ``None`` (what ``safe_load`` returns for an
                empty document) is treated as an empty mapping and is then
                reported as missing every required key.
            source: Label prefixed to error messages; ``from_yaml`` passes
                the file path.

        Returns:
            A new ``Config``. The input mapping is not modified.

        Raises:
            TypeError: if ``config_dict`` is not a mapping, contains keys that
                are not ``Config`` fields, or omits required fields.
            ValueError: if a string value for a ``float`` field cannot be
                parsed as a float.
        """
        from collections.abc import Mapping
        from dataclasses import MISSING, fields

        if config_dict is None:
            config_dict = {}
        if not isinstance(config_dict, Mapping):
            raise TypeError(
                f"{source}: expected a mapping of config keys at the top level, "
                f"got {type(config_dict).__name__}"
            )

        # Only fields accepted by __init__ can be supplied.
        known = {f.name: f for f in fields(cls) if f.init}

        # Report every bad key at once instead of the first one __init__ hits.
        unknown = sorted(str(key) for key in config_dict if key not in known)
        if unknown:
            raise TypeError(f"{source}: unknown config key(s): {', '.join(unknown)}")

        missing = [
            name
            for name, f in known.items()
            if f.default is MISSING
            and f.default_factory is MISSING
            and name not in config_dict
        ]
        if missing:
            raise TypeError(f"{source}: missing config key(s): {', '.join(missing)}")

        kwargs = {}
        for name, value in config_dict.items():
            # PyYAML follows YAML 1.1, whose float pattern requires a "." and a
            # signed exponent, so e.g. `learning_rate: 2e-4` loads as the
            # *string* "2e-4". Convert such strings for float-typed fields.
            # ("float" covers string annotations under postponed evaluation.)
            if isinstance(value, str) and known[name].type in (float, "float"):
                try:
                    value = float(value)
                except ValueError:
                    raise ValueError(
                        f"{source}: config key {name!r} must be a float, got {value!r}"
                    ) from None
            kwargs[name] = value
        return cls(**kwargs)

    @classmethod
    def from_yaml(cls, file_path):
        """Load a ``Config`` from the YAML file at *file_path*.

        The file is decoded as UTF-8 (not the locale default) and parsed with
        ``yaml.safe_load``; validation is delegated to ``from_dict``.

        Raises:
            OSError: if the file cannot be opened.
            yaml.YAMLError: if the file is not valid YAML.
            TypeError, ValueError: see ``from_dict``.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        return cls.from_dict(config_dict, source=str(file_path))