42 lines
1010 B
Python
42 lines
1010 B
Python
|
|
import yaml
|
|
from dataclasses import dataclass
|
|
|
|
@dataclass
class Config:
    """Training configuration record populated from a YAML file.

    Every field is required: none declares a default, so building a
    ``Config`` (directly or through :meth:`from_yaml`) raises ``TypeError``
    when a field is missing or an unknown one is supplied.  The annotations
    are not enforced at runtime; each attribute keeps whatever type the
    YAML parser produced.

    NOTE(review): the field names look like they mirror Hugging Face
    Trainer / PEFT LoRA / bitsandbytes options, and the grouping comments
    below are based on those names only -- confirm against the code that
    consumes this object.
    """

    # Model and data locations.
    model_name_or_path: str
    output_dir: str
    dataset_path: str
    validation_split_ratio: float

    # Optimisation schedule.
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    num_train_epochs: int
    learning_rate: float
    lr_scheduler_type: str
    warmup_steps: int

    # LoRA adapter settings.
    lora_r: int
    lora_alpha: int
    lora_dropout: float
    target_modules: list  # presumably a list of module-name strings; verify

    # Quantization settings.
    quantization_enable: bool
    load_in_4bit: bool
    bnb_4bit_compute_dtype: str
    bnb_4bit_quant_type: str
    bnb_4bit_use_double_quant: bool

    # Logging, evaluation and checkpointing cadence.
    logging_steps: int
    save_steps: int
    evaluation_strategy: str
    eval_steps: int
    save_strategy: str
    save_total_limit: int

    max_seq_length: int

    # Early stopping.
    early_stopping_patience: int
    early_stopping_threshold: float

    # NOTE(review): presumably a path to a DeepSpeed config file.  It is
    # required like every other field -- confirm whether it should be optional.
    deepspeed: str

    @classmethod
    def from_yaml(cls, file_path) -> "Config":
        """Build a ``Config`` from the YAML document stored at *file_path*.

        Args:
            file_path: Location of the YAML file, in any form ``open()``
                accepts.

        Returns:
            A new instance whose fields are filled from the top-level keys
            of the document.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
            yaml.YAMLError: If the content is not well-formed YAML.
            TypeError: If the document is not a mapping (for example an
                empty file, which parses to ``None``), or if its keys do
                not match the declared fields exactly.
        """
        # Decode as UTF-8 explicitly so the result does not depend on the
        # platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)

        # ``cls(**x)`` on a non-mapping already fails with TypeError, but the
        # message ("argument after ** must be a mapping") does not say which
        # file is at fault; raise the same exception type with context.
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"{file_path!r} must contain a YAML mapping at the top level, "
                f"got {type(config_dict).__name__}"
            )

        return cls(**config_dict)
|