from datasets import load_dataset


class DatasetManager:
    """Loads a JSON dataset, splits it into train/eval, and tokenizes both splits."""

    def __init__(self, config, tokenizer):
        self.config = config
        self.tokenizer = tokenizer

    def preprocess_function(self, examples, max_length):
        # Tokenize the "input" column, truncating to max_length tokens.
        return self.tokenizer(examples["input"], max_length=max_length, truncation=True)

    def load_data(self):
        # Load the raw JSON/JSONL file as a single "train" split.
        dataset = load_dataset("json", data_files=self.config.dataset_path, split="train")
        # Shuffle with a fixed seed so the train/eval split is reproducible.
        dataset = dataset.shuffle(seed=42)

        # Hold out the final validation_split_ratio fraction for evaluation.
        split_idx = int(len(dataset) * (1.0 - self.config.validation_split_ratio))
        train_dataset = dataset.select(range(split_idx))
        eval_dataset = dataset.select(range(split_idx, len(dataset)))

        # Tokenize both splits in batches for speed.
        train_dataset = train_dataset.map(
            lambda x: self.preprocess_function(x, self.config.max_seq_length),
            batched=True,
        )
        eval_dataset = eval_dataset.map(
            lambda x: self.preprocess_function(x, self.config.max_seq_length),
            batched=True,
        )
        return train_dataset, eval_dataset
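
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of driving DatasetManager. It assumes only what the class
# itself reads: a config exposing dataset_path, validation_split_ratio, and
# max_seq_length, plus a Hugging Face tokenizer. The "gpt2" checkpoint and the
# data.jsonl path below are placeholder assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    config = SimpleNamespace(
        dataset_path="data.jsonl",   # JSON/JSONL file with an "input" column
        validation_split_ratio=0.1,  # hold out 10% of rows for evaluation
        max_seq_length=512,
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    manager = DatasetManager(config, tokenizer)
    train_dataset, eval_dataset = manager.load_data()
    print(f"train: {len(train_dataset)} rows, eval: {len(eval_dataset)} rows")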