sampler / dataloader refactor

This commit is contained in:
Dan Saunders
2025-03-17 03:08:39 +00:00
parent 7d7042f602
commit 64c203cdef

View File

@@ -7,7 +7,7 @@ import logging
import os import os
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
from typing import Any, Dict, Literal, Optional from typing import Any, Literal
import datasets import datasets
import torch import torch
@@ -435,20 +435,25 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
drop_last=True, drop_last=True,
) )
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _create_sp_sampler(self, dataset, shuffle=True, is_eval=False):
"""Create a sampler for sequence parallelism"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _get_train_sampler(self) -> torch.utils.data.Sampler | None:
# Handle sequence parallelism # Handle sequence parallelism
if self.args.sequence_parallel_degree > 1: if self.args.sequence_parallel_degree > 1:
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree base_sampler = self._create_sp_sampler(
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree self.train_dataset, shuffle=not self.args.curriculum_sampling
# Create base sampler for SP groups
base_sampler = torch.utils.data.distributed.DistributedSampler(
self.train_dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if not self.args.curriculum_sampling else None,
shuffle=not self.args.curriculum_sampling,
drop_last=True,
) )
# Apply multipack wrapper if needed # Apply multipack wrapper if needed
@@ -458,16 +463,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataset=self.train_dataset, dataset=self.train_dataset,
group_size=self.args.sample_packing_group_size, group_size=self.args.sample_packing_group_size,
) )
return base_sampler return base_sampler
# Regular training sampler logic
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
base_sampler = ( base_sampler = (
SequentialSampler(self.train_dataset) SequentialSampler(self.train_dataset)
if self.args.curriculum_sampling if self.args.curriculum_sampling
else RandomSampler(self.train_dataset) else RandomSampler(self.train_dataset)
) )
return self._create_multipack_sampler( return self._create_multipack_sampler(
base_sampler=base_sampler, base_sampler=base_sampler,
dataset=self.train_dataset, dataset=self.train_dataset,
@@ -480,94 +484,76 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super()._get_train_sampler() return super()._get_train_sampler()
def _get_eval_sampler( def _get_eval_sampler(
self, eval_dataset: Optional[Dataset] = None self, eval_dataset: Dataset | None = None
) -> Optional[torch.utils.data.Sampler]: ) -> torch.utils.data.Sampler | None:
"""Get evaluation sampler""" """Get evaluation sampler"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle sequence parallelism # Get the appropriate group size for sample packing
if self.args.sequence_parallel_degree > 1: def get_pack_group_size():
# Create sampler for SP groups return (
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
# Create distributed sampler for the SP group
base_sampler = torch.utils.data.distributed.DistributedSampler(
eval_dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
shuffle=False,
drop_last=False,
)
if self.args.sample_packing and self.args.eval_sample_packing is not False:
group_size = (
self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size
)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=group_size,
)
return base_sampler
if self.args.sample_packing and self.args.eval_sample_packing is not False:
base_sampler = SequentialSampler(eval_dataset)
group_size = (
self.args.eval_packing_group_size self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size") if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size else self.args.sample_packing_group_size
) )
# Handle sequence parallelism
if self.args.sequence_parallel_degree > 1:
base_sampler = self._create_sp_sampler(
eval_dataset, shuffle=False, is_eval=True
)
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=get_pack_group_size(),
)
return base_sampler
# Regular evaluation sampler logic
if self.args.sample_packing and self.args.eval_sample_packing is not False:
base_sampler = SequentialSampler(eval_dataset)
return self._create_multipack_sampler( return self._create_multipack_sampler(
base_sampler=base_sampler, base_sampler=base_sampler,
dataset=eval_dataset, dataset=eval_dataset,
group_size=group_size, group_size=get_pack_group_size(),
) )
return super()._get_eval_sampler(eval_dataset) return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader: def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Get dataloader for training""" """Create common dataloader parameters for train or eval."""
train_dataset = self.train_dataset batch_size = custom_batch_size or (
data_collator = self.data_collator self.args.eval_batch_size if is_eval else self._train_batch_size
)
# Handle dataset preprocessing params = {
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): "batch_size": batch_size,
if ( "collate_fn": self.data_collator,
self.args.sample_packing
and not self.args.pretraining
and "length" in train_dataset.features
):
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
# Build common dataloader parameters
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers, "num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory, "pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
} }
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor: if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(train_dataset, torch.utils.data.IterableDataset): return params
sampler = self._get_train_sampler()
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler): if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive # batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler dataloader_params["batch_sampler"] = sampler
@@ -576,114 +562,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataloader_params["sampler"] = sampler dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
dataloader = DataLoader(train_dataset, **dataloader_params) # Create the dataloader
if self.args.sample_packing and not self.args.pretraining: dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False self.accelerator.even_batches = False
# Don't prepare dataloader for sequence parallelism # Return unprepared dataloader if using sequence parallelism
# We use a distributed sampler in this case
if self.args.sequence_parallel_degree > 1: if self.args.sequence_parallel_degree > 1:
return dataloader return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader) return self.accelerator.prepare_data_loader(dataloader)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation""" """Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False: if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator self.eval_data_collator
) )
if eval_dataset and "length" in eval_dataset.features: eval_dataset = eval_dataset.remove_columns(["length"])
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset) dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator self.train_data_collator
) )
return dataloader return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset) # Handle sample packing or sequence parallelism
if (
# Only remove length column if it exists self.args.sample_packing
if "length" in eval_dataset.features: and self.args.eval_sample_packing is not False
eval_dataset = eval_dataset.remove_columns(["length"]) or self.args.sequence_parallel_degree > 1
):
data_collator = self.data_collator # Get appropriate data collator
dataloader_params = { self.data_collator = ( # pylint: disable=attribute-defined-outside-init
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
dataloader = DataLoader(eval_dataset, **dataloader_params)
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
if self.args.sequence_parallel_degree > 1:
return dataloader
return self.accelerator.prepare_data_loader(dataloader)
if self.args.sequence_parallel_degree > 1:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
data_collator = (
self.eval_data_collator self.eval_data_collator
if self.eval_data_collator if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator else self.data_collator
) )
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing as in the parent implementation # Handle dataset preprocessing for SP
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): if self.args.sequence_parallel_degree > 1:
eval_dataset = self._remove_unused_columns( if is_datasets_available() and isinstance(
eval_dataset, description="evaluation" eval_dataset, datasets.Dataset
) ):
else: eval_dataset = self._remove_unused_columns(
data_collator = self._get_collator_with_removed_columns( eval_dataset, description="evaluation"
data_collator, description="evaluation" )
) else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Build dataloader parameters # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
dataloader_params = { batch_size = (
"batch_size": self.args.per_device_eval_batch_size, self.args.eval_batch_size
"collate_fn": data_collator, if self.args.sample_packing
"num_workers": self.args.dataloader_num_workers, else self.args.per_device_eval_batch_size
"pin_memory": self.args.dataloader_pin_memory, )
} sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)
if not isinstance(eval_dataset, torch.utils.data.IterableDataset): return dataloader
sampler = self._get_eval_sampler(eval_dataset)
dataloader_params["sampler"] = sampler
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
return DataLoader(eval_dataset, **dataloader_params)
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler( def _get_bench_sampler(
self, bench_dataset: Dataset self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]: ) -> torch.utils.data.Sampler | None:
if self.args.world_size <= 1: if self.args.world_size <= 1:
return SequentialSampler(bench_dataset) return SequentialSampler(bench_dataset)
return None return None
@@ -920,15 +902,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return res return res
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
""" """
Log `logs` on the various objects watching training, including stored metrics. Log `logs` on the various objects watching training, including stored metrics.
Args: Args:
logs (`Dict[str, float]`): logs: The values to log.
The values to log. start_time: The start of training.
start_time (`Optional[float]`):
The start of training.
""" """
# logs either has 'loss' or 'eval_loss' # logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
@@ -940,7 +920,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super().log(logs, start_time) return super().log(logs, start_time)
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None: ) -> None:
for key, value in metrics.items(): for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value) self._stored_metrics[train_eval][key].append(value)
@@ -1043,7 +1023,7 @@ class ReLoRATrainer(AxolotlTrainer):
def create_scheduler( def create_scheduler(
self, self,
num_training_steps: int, num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None, optimizer: torch.optim.Optimizer | None = None,
): ):
optimizer = self.optimizer if optimizer is None else optimizer optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer) lr_scheduler = super().create_scheduler(num_training_steps, optimizer)