qlora-train/dataset.py
2025-04-28 08:50:31 +00:00

22 lines
970 B
Python

from datasets import load_dataset
class DatasetManager:
    """Load a JSON dataset, split it into train/eval sets, and tokenize both.

    ``config`` is read for ``dataset_path``, ``validation_split_ratio`` and
    ``max_seq_length``; ``tokenizer`` is called on batches of raw text.
    """

    def __init__(self, config, tokenizer):
        """Store the configuration object and the tokenizer used for preprocessing."""
        self.config = config
        self.tokenizer = tokenizer

    def preprocess_function(self, examples, max_length):
        """Tokenize the ``"input"`` column of a batch, truncating to ``max_length``.

        Args:
            examples: A batch mapping whose ``"input"`` entry holds the texts.
            max_length: Maximum sequence length passed to the tokenizer.

        Returns:
            Whatever the tokenizer returns for the batch.
        """
        return self.tokenizer(examples["input"], max_length=max_length, truncation=True)

    @staticmethod
    def _train_size(total, validation_split_ratio):
        """Return how many of ``total`` shuffled rows belong to the training split.

        Raises:
            ValueError: if ``validation_split_ratio`` is outside ``[0, 1)`` or the
                resulting training split would contain no rows.
        """
        # The chained comparison also rejects NaN.
        if not 0.0 <= validation_split_ratio < 1.0:
            raise ValueError(
                f"validation_split_ratio must be in [0, 1), got {validation_split_ratio!r}"
            )
        # int() truncates, so any fractional row ends up in the eval split.
        train_size = int(total * (1.0 - validation_split_ratio))
        if train_size == 0:
            raise ValueError(
                f"training split is empty: {total} rows with "
                f"validation_split_ratio={validation_split_ratio!r}"
            )
        return train_size

    def _tokenize(self, dataset):
        """Apply ``preprocess_function`` to ``dataset`` in batched mode."""
        # fn_kwargs supplies max_length directly instead of via a wrapper lambda.
        # NOTE(review): the original columns (e.g. "input") are kept next to the
        # tokenizer outputs; pass remove_columns= here if the collator needs only
        # token fields.
        return dataset.map(
            self.preprocess_function,
            batched=True,
            fn_kwargs={"max_length": self.config.max_seq_length},
        )

    def load_data(self, seed=42):
        """Load, shuffle, split and tokenize the dataset at ``config.dataset_path``.

        Args:
            seed: Shuffle seed. Defaults to 42, the value previously hard-coded.

        Returns:
            ``(train_dataset, eval_dataset)``. The training split is the first
            ``int(len * (1 - validation_split_ratio))`` shuffled rows; the eval
            split is the remainder.

        Raises:
            ValueError: for an invalid ``validation_split_ratio`` or an empty
                training split.
        """
        dataset = load_dataset("json", data_files=self.config.dataset_path, split="train")
        dataset = dataset.shuffle(seed=seed)
        split_idx = self._train_size(len(dataset), self.config.validation_split_ratio)
        train_dataset = dataset.select(range(split_idx))
        eval_dataset = dataset.select(range(split_idx, len(dataset)))
        return self._tokenize(train_dataset), self._tokenize(eval_dataset)