From 64c203cdef90f2f681e1a49cfb067dddea609417 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 17 Mar 2025 03:08:39 +0000
Subject: [PATCH] Refactor sampler and dataloader construction in
 AxolotlTrainer

Deduplicate the sequence-parallel sampler setup into _create_sp_sampler,
share dataloader construction between train and eval via
_create_dataloader_params and _prepare_dataloader, and modernize type
hints to PEP 604 unions.
---
 src/axolotl/core/trainers/base.py | 308 ++++++++++++++----------------
 1 file changed, 144 insertions(+), 164 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 793dcc6a0..70f3a86ea 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -7,7 +7,7 @@ import logging
 import os
 from collections import defaultdict
 from functools import wraps
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Literal
 
 import datasets
 import torch
@@ -435,20 +435,25 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             drop_last=True,
         )
 
-    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+    def _create_sp_sampler(self, dataset, shuffle=True, is_eval=False):
+        """Create a distributed sampler that shards data across sequence parallelism groups."""
+        num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
+        sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
+
+        return torch.utils.data.distributed.DistributedSampler(
+            dataset,
+            num_replicas=num_sp_groups,
+            rank=sp_group_id,
+            seed=self.args.seed if shuffle else None,
+            shuffle=shuffle,
+            drop_last=not is_eval,
+        )
+
+    def _get_train_sampler(self) -> torch.utils.data.Sampler | None:
         # Handle sequence parallelism
         if self.args.sequence_parallel_degree > 1:
-            num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
-            sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
-
-            # Create base sampler for SP groups
-            base_sampler = torch.utils.data.distributed.DistributedSampler(
-                self.train_dataset,
-                num_replicas=num_sp_groups,
-                rank=sp_group_id,
-                seed=self.args.seed if not self.args.curriculum_sampling else None,
-                shuffle=not self.args.curriculum_sampling,
-                drop_last=True,
+            base_sampler = self._create_sp_sampler(
+                self.train_dataset, shuffle=not self.args.curriculum_sampling
             )
 
             # Apply multipack wrapper if needed
@@ -458,16 +463,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                     dataset=self.train_dataset,
                     group_size=self.args.sample_packing_group_size,
                 )
-            return base_sampler
+
+            return base_sampler
 
+        # Regular training sampler logic
         if self.args.sample_packing and not self.args.pretraining:
             base_sampler = (
                 SequentialSampler(self.train_dataset)
                 if self.args.curriculum_sampling
                 else RandomSampler(self.train_dataset)
             )
-
             return self._create_multipack_sampler(
                 base_sampler=base_sampler,
                 dataset=self.train_dataset,
@@ -480,94 +484,76 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return super()._get_train_sampler()
 
     def _get_eval_sampler(
-        self, eval_dataset: Optional[Dataset] = None
-    ) -> Optional[torch.utils.data.Sampler]:
+        self, eval_dataset: Dataset | None = None
+    ) -> torch.utils.data.Sampler | None:
         """Get evaluation sampler"""
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
 
-        # Handle sequence parallelism
-        if self.args.sequence_parallel_degree > 1:
-            # Create sampler for SP groups
-            num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
-            sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
-
-            # Create distributed sampler for the SP group
-            base_sampler = torch.utils.data.distributed.DistributedSampler(
-                eval_dataset,
-                num_replicas=num_sp_groups,
-                rank=sp_group_id,
-                shuffle=False,
-                drop_last=False,
-            )
-
-            if self.args.sample_packing and self.args.eval_sample_packing is not False:
-                group_size = (
-                    self.args.eval_packing_group_size
-                    if hasattr(self.args, "eval_packing_group_size")
-                    else self.args.sample_packing_group_size
-                )
-
-                return self._create_multipack_sampler(
-                    base_sampler=base_sampler,
-                    dataset=eval_dataset,
-                    group_size=group_size,
-                )
-
-            return base_sampler
-
-        if self.args.sample_packing and self.args.eval_sample_packing is not False:
-            base_sampler = SequentialSampler(eval_dataset)
-            group_size = (
+        # Get the appropriate group size for sample packing
+        def get_pack_group_size():
+            return (
                 self.args.eval_packing_group_size
                 if hasattr(self.args, "eval_packing_group_size")
                 else self.args.sample_packing_group_size
             )
 
+        # Handle sequence parallelism
+        if self.args.sequence_parallel_degree > 1:
+            base_sampler = self._create_sp_sampler(
+                eval_dataset, shuffle=False, is_eval=True
+            )
+
+            if self.args.sample_packing and self.args.eval_sample_packing is not False:
+                return self._create_multipack_sampler(
+                    base_sampler=base_sampler,
+                    dataset=eval_dataset,
+                    group_size=get_pack_group_size(),
+                )
+            return base_sampler
+
+        # Regular evaluation sampler logic
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            base_sampler = SequentialSampler(eval_dataset)
             return self._create_multipack_sampler(
                 base_sampler=base_sampler,
                 dataset=eval_dataset,
-                group_size=group_size,
+                group_size=get_pack_group_size(),
             )
 
         return super()._get_eval_sampler(eval_dataset)
 
-    def get_train_dataloader(self) -> DataLoader:
-        """Get dataloader for training"""
-        train_dataset = self.train_dataset
-        data_collator = self.data_collator
+    def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
+        """Create common dataloader parameters for train or eval."""
+        batch_size = custom_batch_size or (
+            self.args.eval_batch_size if is_eval else self._train_batch_size
+        )
 
-        # Handle dataset preprocessing
-        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
-            if (
-                self.args.sample_packing
-                and not self.args.pretraining
-                and "length" in train_dataset.features
-            ):
-                train_dataset = train_dataset.remove_columns(["length"])
-            if not (self.args.sample_packing and not self.args.pretraining):
-                train_dataset = self._remove_unused_columns(
-                    train_dataset, description="training"
-                )
-        else:
-            data_collator = self._get_collator_with_removed_columns(
-                data_collator, description="training"
-            )
-
-        # Build common dataloader parameters
-        dataloader_params = {
-            "batch_size": self._train_batch_size,
-            "collate_fn": data_collator,
+        params = {
+            "batch_size": batch_size,
+            "collate_fn": self.data_collator,
             "num_workers": self.args.dataloader_num_workers,
             "pin_memory": self.args.dataloader_pin_memory,
-            "persistent_workers": self.args.dataloader_persistent_workers,
         }
 
+        # Add persistent workers only for training
+        if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
+            params["persistent_workers"] = self.args.dataloader_persistent_workers
+
+        # Add prefetch factor if specified
         if self.args.dataloader_prefetch_factor:
-            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+            params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
-        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
-            sampler = self._get_train_sampler()
+        return params
 
+    def _prepare_dataloader(
+        self, dataset, sampler, is_eval=False, custom_batch_size=None
+    ):
+        """Prepare a dataloader with the given dataset and sampler."""
+        # Get base parameters
+        dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
+
+        # Add sampler configuration
+        if not isinstance(dataset, torch.utils.data.IterableDataset):
             if isinstance(sampler, BatchSampler):
                 # batch_size and batch_sampler are mutually exclusive
                 dataloader_params["batch_sampler"] = sampler
@@ -576,114 +562,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 dataloader_params["sampler"] = sampler
                 dataloader_params["drop_last"] = self.args.dataloader_drop_last
 
-            dataloader_params["worker_init_fn"] = seed_worker
+            if not is_eval:
+                dataloader_params["worker_init_fn"] = seed_worker
 
-        dataloader = DataLoader(train_dataset, **dataloader_params)
-        if self.args.sample_packing and not self.args.pretraining:
+        # Create the dataloader
+        dataloader = DataLoader(dataset, **dataloader_params)
+
+        if self.args.sample_packing and (
+            (not is_eval and not self.args.pretraining)
+            or (is_eval and self.args.eval_sample_packing is not False)
+        ):
             self.accelerator.even_batches = False
 
-        # Don't prepare dataloader for sequence parallelism
-        # We use a distributed sampler in this case
+        # Return unprepared dataloader if using sequence parallelism
        if self.args.sequence_parallel_degree > 1:
             return dataloader
 
+        # Otherwise prepare with accelerator
         return self.accelerator.prepare_data_loader(dataloader)
 
-    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+    def get_train_dataloader(self) -> DataLoader:
+        """Get dataloader for training"""
+        train_dataset = self.train_dataset
+        data_collator = self.data_collator  # type: ignore
+
+        # Handle dataset preprocessing
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            if (
+                self.args.sample_packing
+                and not self.args.pretraining
+                and "length" in train_dataset.features
+            ):
+                train_dataset = train_dataset.remove_columns(["length"])
+            if not (self.args.sample_packing and not self.args.pretraining):
+                train_dataset = self._remove_unused_columns(
+                    train_dataset, description="training"
+                )
+        else:
+            self.data_collator = self._get_collator_with_removed_columns(  # pylint: disable=attribute-defined-outside-init
+                data_collator,
+                description="training",
+            )
+
+        # Get sampler and create dataloader
+        sampler = self._get_train_sampler()
+        return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
+
+    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
         """Get dataloader for evaluation"""
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+        # Handle special case: sample packing is enabled but eval_sample_packing is False
         if self.args.sample_packing and self.args.eval_sample_packing is False:
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.eval_data_collator
             )
             if eval_dataset and "length" in eval_dataset.features:
                 eval_dataset = eval_dataset.remove_columns(["length"])
             dataloader = super().get_eval_dataloader(eval_dataset)
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.train_data_collator
             )
             return dataloader
 
-        if self.args.sample_packing and self.args.eval_sample_packing is not False:
-            eval_dataset = (
-                eval_dataset if eval_dataset is not None else self.eval_dataset
-            )
-            eval_sampler = self._get_eval_sampler(eval_dataset)
-
-            # Only remove length column if it exists
-            if "length" in eval_dataset.features:
-                eval_dataset = eval_dataset.remove_columns(["length"])
-
-            data_collator = self.data_collator
-            dataloader_params = {
-                "batch_size": self.args.eval_batch_size,
-                "collate_fn": data_collator,
-                "num_workers": self.args.dataloader_num_workers,
-                "pin_memory": self.args.dataloader_pin_memory,
-            }
-
-            if self.args.dataloader_prefetch_factor:
-                dataloader_params["prefetch_factor"] = (
-                    self.args.dataloader_prefetch_factor
-                )
-
-            if isinstance(eval_sampler, BatchSampler):
-                dataloader_params["batch_sampler"] = eval_sampler
-                del dataloader_params["batch_size"]
-            else:
-                dataloader_params["sampler"] = eval_sampler
-                dataloader_params["drop_last"] = self.args.dataloader_drop_last
-
-            self.accelerator.even_batches = False
-            dataloader = DataLoader(eval_dataset, **dataloader_params)
-
-            # Don't prepare dataloader for sequence parallelism
-            # We use a distributed sampler in this case
-            if self.args.sequence_parallel_degree > 1:
-                return dataloader
-
-            return self.accelerator.prepare_data_loader(dataloader)
-
-        if self.args.sequence_parallel_degree > 1:
-            eval_dataset = (
-                eval_dataset if eval_dataset is not None else self.eval_dataset
-            )
-            data_collator = (
+        # Handle sample packing or sequence parallelism
+        if (
+            self.args.sample_packing
+            and self.args.eval_sample_packing is not False
+        ) or self.args.sequence_parallel_degree > 1:
+            # Get appropriate data collator
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.eval_data_collator
-                if self.eval_data_collator
+                if hasattr(self, "eval_data_collator") and self.eval_data_collator
                 else self.data_collator
             )
+            if "length" in eval_dataset.features:
+                eval_dataset = eval_dataset.remove_columns(["length"])
 
-            # Handle dataset preprocessing as in the parent implementation
-            if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
-                eval_dataset = self._remove_unused_columns(
-                    eval_dataset, description="evaluation"
-                )
-            else:
-                data_collator = self._get_collator_with_removed_columns(
-                    data_collator, description="evaluation"
-                )
+            # Handle dataset preprocessing for SP
+            if self.args.sequence_parallel_degree > 1:
+                if is_datasets_available() and isinstance(
+                    eval_dataset, datasets.Dataset
+                ):
+                    eval_dataset = self._remove_unused_columns(
+                        eval_dataset, description="evaluation"
+                    )
+                else:
+                    self.data_collator = self._get_collator_with_removed_columns(  # pylint: disable=attribute-defined-outside-init
+                        self.data_collator, description="evaluation"
+                    )
 
-            # Build dataloader parameters
-            dataloader_params = {
-                "batch_size": self.args.per_device_eval_batch_size,
-                "collate_fn": data_collator,
-                "num_workers": self.args.dataloader_num_workers,
-                "pin_memory": self.args.dataloader_pin_memory,
-            }
+            # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
+            batch_size = (
+                self.args.eval_batch_size
+                if self.args.sample_packing
+                else self.args.per_device_eval_batch_size
+            )
+            sampler = self._get_eval_sampler(eval_dataset)
+            dataloader = self._prepare_dataloader(
+                eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
+            )
 
-            if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
-                sampler = self._get_eval_sampler(eval_dataset)
-                dataloader_params["sampler"] = sampler
-
-            # Don't prepare dataloader for sequence parallelism
-            # We use a distributed sampler in this case
-            return DataLoader(eval_dataset, **dataloader_params)
+            return dataloader
 
         return super().get_eval_dataloader(eval_dataset)
 
     def _get_bench_sampler(
         self, bench_dataset: Dataset
-    ) -> Optional[torch.utils.data.Sampler]:
+    ) -> torch.utils.data.Sampler | None:
         if self.args.world_size <= 1:
             return SequentialSampler(bench_dataset)
         return None
@@ -920,15 +902,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return res
 
-    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
         """
         Log `logs` on the various objects watching training, including stored metrics.
 
         Args:
-            logs (`Dict[str, float]`):
-                The values to log.
-            start_time (`Optional[float]`):
-                The start of training.
+            logs: The values to log.
+            start_time: The start of training.
         """
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
@@ -940,7 +920,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return super().log(logs, start_time)
 
     def store_metrics(
-        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
+        self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
     ) -> None:
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)
@@ -1043,7 +1023,7 @@ class ReLoRATrainer(AxolotlTrainer):
     def create_scheduler(
         self,
         num_training_steps: int,
-        optimizer: Optional[torch.optim.Optimizer] = None,
+        optimizer: torch.optim.Optimizer | None = None,
    ):
         optimizer = self.optimizer if optimizer is None else optimizer
         lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
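---

Reviewer note: a minimal, runnable sketch of the rank-to-group arithmetic that
_create_sp_sampler relies on, i.e. sharding data across sequence-parallel
groups rather than individual ranks. The concrete values (world_size=8,
sequence_parallel_degree=4) and the toy dataset are illustrative assumptions,
not taken from the patch; only the group arithmetic and the DistributedSampler
arguments mirror the helper.

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    world_size = 8   # assumed total number of ranks, not from the patch
    sp_degree = 4    # assumed sequence_parallel_degree, not from the patch
    num_sp_groups = world_size // sp_degree  # 2 data-parallel shards

    dataset = TensorDataset(torch.arange(16))  # toy dataset of 16 samples

    for rank in range(world_size):
        # Same arithmetic as _create_sp_sampler: every rank in an SP group
        # maps to the same group id, so the group shares one data shard and
        # each rank works on its slice of the sequence dimension.
        sp_group_id = rank // sp_degree
        sampler = DistributedSampler(
            dataset,
            num_replicas=num_sp_groups,  # shard across groups, not ranks
            rank=sp_group_id,            # identical within an SP group
            shuffle=False,
            drop_last=True,
        )
        print(f"rank {rank} -> group {sp_group_id}: {list(sampler)}")

Ranks 0-3 print the same eight indices and ranks 4-7 the other eight. This is
also why the patch returns the dataloader unprepared when
sequence_parallel_degree > 1: accelerator.prepare_data_loader would otherwise
re-shard it per rank, defeating the per-group sharding above.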