sampler / dataloader refactor

Dan Saunders
2025-03-17 03:08:39 +00:00
parent 7d7042f602
commit 64c203cdef

@@ -7,7 +7,7 @@ import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal
import datasets
import torch
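(Not part of the commit.) The import change above pairs with the annotation changes throughout the file: typing.Dict and typing.Optional are dropped in favor of builtin generics and PEP 604 unions. Evaluated at runtime, dict[str, float] needs Python 3.9+ and float | None needs 3.10+, unless annotations are kept unevaluated, as in this hypothetical snippet:

from __future__ import annotations  # new-style annotations stay as strings, so this also runs on 3.8/3.9

def log_metrics(metrics: dict[str, float], start_time: float | None = None) -> None:
    ...  # hypothetical function, shown only to illustrate the annotation syntax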
@@ -435,20 +435,25 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
drop_last=True,
)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
def _create_sp_sampler(self, dataset, shuffle=True, is_eval=False):
"""Create a sampler for sequence parallelism"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _get_train_sampler(self) -> torch.utils.data.Sampler | None:
# Handle sequence parallelism
if self.args.sequence_parallel_degree > 1:
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
# Create base sampler for SP groups
base_sampler = torch.utils.data.distributed.DistributedSampler(
self.train_dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if not self.args.curriculum_sampling else None,
shuffle=not self.args.curriculum_sampling,
drop_last=True,
base_sampler = self._create_sp_sampler(
self.train_dataset, shuffle=not self.args.curriculum_sampling
)
# Apply multipack wrapper if needed
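(Illustration, not part of the commit.) The arithmetic in _create_sp_sampler divides the world into world_size // sequence_parallel_degree data-parallel groups and points every rank in a group at the same shard, so the group can then split each sequence along its length. A minimal single-process sketch with hypothetical sizes (world_size=8, sequence_parallel_degree=2):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

world_size, sp_degree = 8, 2
num_sp_groups = world_size // sp_degree  # 4 sequence-parallel groups
dataset = TensorDataset(torch.arange(16))

for rank in range(world_size):
    sp_group_id = rank // sp_degree  # ranks 0-1 -> group 0, ranks 2-3 -> group 1, ...
    sampler = DistributedSampler(
        dataset, num_replicas=num_sp_groups, rank=sp_group_id, shuffle=False
    )
    # Ranks in the same SP group see identical indices; different groups see disjoint shards.
    print(f"rank {rank} (SP group {sp_group_id}): {list(sampler)}")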
@@ -458,16 +463,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataset=self.train_dataset,
group_size=self.args.sample_packing_group_size,
)
return base_sampler
# Regular training sampler logic
if self.args.sample_packing and not self.args.pretraining:
base_sampler = (
SequentialSampler(self.train_dataset)
if self.args.curriculum_sampling
else RandomSampler(self.train_dataset)
)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
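(Illustration, not part of the commit.) _create_multipack_sampler, whose body is outside this hunk, wraps the base sampler so that variable-length examples are packed into token-budgeted batches. The toy BatchSampler below only sketches that general idea; the class name, the greedy packing rule, and the lengths/max_tokens parameters are illustrative, not axolotl's. Subclassing BatchSampler matters because _prepare_dataloader later routes BatchSampler instances to DataLoader's batch_sampler argument.

from torch.utils.data import BatchSampler, Sampler

class ToyPackingBatchSampler(BatchSampler):
    """Yield lists of indices whose summed lengths stay within a token budget."""

    def __init__(self, base_sampler: Sampler, lengths: list[int], max_tokens: int):
        self.base_sampler = base_sampler
        self.lengths = lengths
        self.max_tokens = max_tokens

    def __iter__(self):
        batch, used = [], 0
        for idx in self.base_sampler:
            if batch and used + self.lengths[idx] > self.max_tokens:
                yield batch
                batch, used = [], 0
            batch.append(idx)
            used += self.lengths[idx]
        if batch:
            yield batch

    def __len__(self):
        # Rough upper bound; a real implementation derives this from the packed bins.
        return sum(self.lengths) // self.max_tokens + 1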
@@ -480,94 +484,76 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Optional[Dataset] = None
) -> Optional[torch.utils.data.Sampler]:
self, eval_dataset: Dataset | None = None
) -> torch.utils.data.Sampler | None:
"""Get evaluation sampler"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle sequence parallelism
if self.args.sequence_parallel_degree > 1:
# Create sampler for SP groups
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
# Create distributed sampler for the SP group
base_sampler = torch.utils.data.distributed.DistributedSampler(
eval_dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
shuffle=False,
drop_last=False,
)
if self.args.sample_packing and self.args.eval_sample_packing is not False:
group_size = (
self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size
)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=group_size,
)
return base_sampler
if self.args.sample_packing and self.args.eval_sample_packing is not False:
base_sampler = SequentialSampler(eval_dataset)
group_size = (
# Get the appropriate group size for sample packing
def get_pack_group_size():
return (
self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size
)
# Handle sequence parallelism
if self.args.sequence_parallel_degree > 1:
base_sampler = self._create_sp_sampler(
eval_dataset, shuffle=False, is_eval=True
)
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=get_pack_group_size(),
)
return base_sampler
# Regular evaluation sampler logic
if self.args.sample_packing and self.args.eval_sample_packing is not False:
base_sampler = SequentialSampler(eval_dataset)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=group_size,
group_size=get_pack_group_size(),
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
# Handle dataset preprocessing
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if (
self.args.sample_packing
and not self.args.pretraining
and "length" in train_dataset.features
):
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
# Build common dataloader parameters
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
sampler = self._get_train_sampler()
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
@@ -576,114 +562,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
dataloader = DataLoader(train_dataset, **dataloader_params)
if self.args.sample_packing and not self.args.pretraining:
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
# Return unprepared dataloader if using sequence parallelism
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if eval_dataset and "length" in eval_dataset.features:
eval_dataset = eval_dataset.remove_columns(["length"])
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
# Only remove length column if it exists
if "length" in eval_dataset.features:
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
dataloader = DataLoader(eval_dataset, **dataloader_params)
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
if self.args.sequence_parallel_degree > 1:
return dataloader
return self.accelerator.prepare_data_loader(dataloader)
if self.args.sequence_parallel_degree > 1:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
data_collator = (
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing as in the parent implementation
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="evaluation"
)
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if is_datasets_available() and isinstance(
eval_dataset, datasets.Dataset
):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Build dataloader parameters
dataloader_params = {
"batch_size": self.args.per_device_eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
sampler = self._get_eval_sampler(eval_dataset)
dataloader_params["sampler"] = sampler
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
return DataLoader(eval_dataset, **dataloader_params)
return dataloader
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
) -> torch.utils.data.Sampler | None:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
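(Illustration, not part of the commit.) The "return unprepared dataloader" branches in _prepare_dataloader and get_eval_dataloader follow from the comments above: with sequence parallelism, the DistributedSampler has already handed each SP group its shard, so the loader is not passed through accelerator.prepare_data_loader, which would otherwise shard it again per rank. The toy slicing below only illustrates what a second round of sharding would do to per-group data counts; it is not a literal model of accelerate's internals.

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

world_size, sp_degree = 8, 2
num_sp_groups = world_size // sp_degree
dataset = TensorDataset(torch.arange(32))

# First sharding: the SP sampler gives SP group 0 a quarter of the data.
group_shard = list(
    DistributedSampler(dataset, num_replicas=num_sp_groups, rank=0, shuffle=False)
)
print(len(group_shard))  # 8

# Hypothetical second sharding across all 8 ranks:
print(len(group_shard[0::world_size]))  # 1 -- most of the group's shard would be lost per rank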
@@ -920,15 +902,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return res
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
logs: The values to log.
start_time: The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -940,7 +920,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
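(Illustration, not part of the commit.) store_metrics accumulates values per key in the _stored_metrics defaultdict, split by train/eval, and log() later folds them into the outgoing logs dict; that reduction lives outside this hunk. A stripped-down sketch of the pattern follows, where the mean reduction in flush() is an assumption for illustration, not necessarily what the trainer does:

from collections import defaultdict
from typing import Literal

_stored_metrics: dict[str, dict[str, list[float]]] = {
    "train": defaultdict(list),
    "eval": defaultdict(list),
}

def store_metrics(
    metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
    for key, value in metrics.items():
        _stored_metrics[train_eval][key].append(value)

def flush(train_eval: Literal["train", "eval"]) -> dict[str, float]:
    # Hypothetical reduction at log() time: average each stored list, then reset it.
    out = {key: sum(vals) / len(vals) for key, vals in _stored_metrics[train_eval].items()}
    _stored_metrics[train_eval].clear()
    return out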
@@ -1043,7 +1023,7 @@ class ReLoRATrainer(AxolotlTrainer):
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer: torch.optim.Optimizer | None = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)