From 5731cdc0cf5805966e4a1f7ecdca9b219ecbc40d Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 12 Mar 2025 20:44:02 +0000
Subject: [PATCH] Fix sample packing with sequence parallelism

Unify train sampler construction so that sequence parallelism composes
with multipack sample packing instead of taking a separate dataloader
path: a DistributedSampler shards data across SP groups, and the
multipack batch sampler wraps it when packing is enabled.
---
 src/axolotl/core/trainers/base.py       | 117 ++++++++++++------
 src/axolotl/utils/collators/batching.py |  11 +++
 src/axolotl/utils/samplers/multipack.py |   4 +-
 3 files changed, 69 insertions(+), 63 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 83106868c..4d10e206b 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -398,34 +398,60 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         )
         return super()._wrap_model(model, training=training, dataloader=dataloader)
 
+    def _create_multipack_sampler(self, base_sampler):
+        """Create a MultipackBatchSampler that wraps the given base sampler."""
+        if self.args.multipack_real_batches:
+            batch_size = self.args.per_device_train_batch_size
+            batch_max_len = self.args.max_seq_length
+        else:
+            batch_size = 1
+            train_batch_size = (
+                self.state.train_batch_size or self.args.per_device_train_batch_size
+            )
+            batch_max_len = train_batch_size * self.args.max_seq_length
+
+        return MultipackBatchSampler(
+            base_sampler,
+            lengths=get_dataset_lengths(self.train_dataset),
+            packing_efficiency_estimate=self.args.sample_packing_efficiency,
+            batch_max_len=batch_max_len,
+            batch_size=batch_size,
+            group_size=self.args.sample_packing_group_size,
+            bin_size=self.args.sample_packing_bin_size,
+            drop_last=True,
+        )
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing and not self.args.pretraining:
-            if self.args.multipack_real_batches:
-                batch_size = self.args.per_device_train_batch_size
-                batch_max_len = self.args.max_seq_length
-            else:
-                batch_size = 1
-                train_batch_size = (
-                    self.state.train_batch_size or self.args.per_device_train_batch_size
-                )
-                batch_max_len = train_batch_size * self.args.max_seq_length
+        # Handle sequence parallelism: shard data across SP groups, not ranks
+        if self.args.sequence_parallel_size > 1:
+            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
+            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
 
-            if self.args.curriculum_sampling:
-                sampler = SequentialSampler(self.train_dataset)
-            else:
-                sampler = RandomSampler(self.train_dataset)
-
-            return MultipackBatchSampler(
-                sampler,
-                lengths=get_dataset_lengths(self.train_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                batch_max_len=batch_max_len,
-                batch_size=batch_size,
-                group_size=self.args.sample_packing_group_size,
-                bin_size=self.args.sample_packing_bin_size,
+            # Create base sampler for SP groups
+            base_sampler = torch.utils.data.distributed.DistributedSampler(
+                self.train_dataset,
+                num_replicas=num_sp_groups,
+                rank=sp_group_id,
+                seed=self.args.seed if not self.args.curriculum_sampling else None,
+                shuffle=not self.args.curriculum_sampling,
                 drop_last=True,
             )
 
+            # Apply multipack wrapper if needed
+            if self.args.sample_packing and not self.args.pretraining:
+                return self._create_multipack_sampler(base_sampler)
+
+            return base_sampler
+
+        # Handle non-SP mode
+        if self.args.sample_packing and not self.args.pretraining:
+            sampler = (
+                SequentialSampler(self.train_dataset)
+                if self.args.curriculum_sampling
+                else RandomSampler(self.train_dataset)
+            )
+            return self._create_multipack_sampler(sampler)
+
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
 
@@ -489,13 +515,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if self.args.dataloader_prefetch_factor:
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
-        # Handle sequence parallelism (takes precedence over other sampling methods)
-        if self.args.sequence_parallel_size > 1:
-            return self._get_sequence_parallel_dataloader(
-                train_dataset, dataloader_params
-            )
-
-        # Handle other sampler cases
+        # Use the same sampling logic for all modes, including sequence parallelism
         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
             sampler = self._get_train_sampler()
 
@@ -510,43 +530,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
 
             dataloader_params["worker_init_fn"] = seed_worker
 
-        # Create and possibly prepare the dataloader
+        # Create dataloader
        dataloader = DataLoader(train_dataset, **dataloader_params)
 
         # Sample packing with accelerator preparation
         if self.args.sample_packing and not self.args.pretraining:
             self.accelerator.even_batches = False
-            return self.accelerator.prepare_data_loader(dataloader)
 
-        return self.accelerator.prepare(dataloader)
+        # Don't let accelerate prepare the dataloader under sequence parallelism;
+        # the SP-aware distributed sampler already handles the sharding
+        if self.args.sequence_parallel_size > 1:
+            return dataloader
 
-    def _get_sequence_parallel_dataloader(self, dataset, dataloader_params):
-        """Create a dataloader with sequence-parallel-aware sampling."""
-        # Calculate SP group information
-        num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
-        global_rank = dist.get_rank()
-        sp_group_id = global_rank // self.args.sequence_parallel_size
-
-        # Create SP-group-aware sampler
-        if isinstance(dataset, torch.utils.data.IterableDataset):
-            sampler = None
-        else:
-            # Use different seeds for different SP groups to ensure they get different samples
-            sampler = torch.utils.data.distributed.DistributedSampler(
-                dataset,
-                num_replicas=num_sp_groups,
-                rank=sp_group_id,
-                seed=self.args.seed,
-                drop_last=self.args.dataloader_drop_last,
-            )
-
-        # Create dataloader without accelerator preparation
-        return DataLoader(
-            dataset,
-            sampler=sampler,
-            worker_init_fn=seed_worker,
-            **dataloader_params,
-        )
+        # Prepare dataloader for accelerate distributed training
+        return self.accelerator.prepare_data_loader(dataloader)
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         if self.args.sample_packing and self.args.eval_sample_packing is False:
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 885338651..53c039479 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -261,6 +261,12 @@ class DataCollatorForSeq2Seq:
             #     # Replace invalid labels with -100 (ignore index)
             #     batch["labels"][invalid_mask] = -100
 
+        if key == "attention_mask":
+            logger.debug(
+                f"GPU {self.rank}/{self.world_size}, SP rank {self.local_rank}/{self.local_world_size}: "
+                f"attention_mask shape: {batch['attention_mask'].shape}"
+            )
+
         # Handle position_ids if present
         if "position_ids" in batch:
             pos_ids = batch["position_ids"]
@@ -281,6 +287,11 @@ class DataCollatorForSeq2Seq:
                 batch["position_ids"], start_idx
            )
 
+            logger.debug(
+                f"GPU {self.rank}/{self.world_size}, SP rank {self.local_rank}/{self.local_world_size}: "
+                f"position_ids shape: {batch['position_ids'].shape}"
+            )
+
         # if dist.get_rank() == 0:
         #     import ipdb; ipdb.set_trace()
         # dist.barrier()
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 6119dff30..41095152e 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -104,9 +104,7 @@ def allocate(
 
 
 class MultipackBatchSampler(BatchSampler):
-    """
-    Batch Sampler class for multipack
-    """
+    """Batch sampler that packs multiple variable-length samples into each batch."""
 
     def __init__(
         self,
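
A note on the design: the heart of this change is that a single sampler
pipeline now serves every mode. A DistributedSampler shards the dataset
across sequence-parallel (SP) groups, and MultipackBatchSampler wraps
whatever base sampler it is given. The sketch below illustrates that
composition outside the trainer; it uses the import path and keyword
arguments visible in the diff, but the toy dataset, lengths, and packing
parameters are illustrative assumptions, not axolotl's actual defaults.

    import numpy as np
    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    from axolotl.utils.samplers.multipack import MultipackBatchSampler

    # Toy stand-ins for the training dataset and its per-sample token
    # lengths (axolotl derives the latter with get_dataset_lengths()).
    dataset = TensorDataset(torch.arange(256))
    lengths = np.random.default_rng(0).integers(32, 2048, size=len(dataset))

    # world_size=8 with sequence_parallel_size=4 gives 2 SP groups. Every
    # rank in a group must see the same samples (each holds a slice of the
    # sequence dimension), so data is sharded by SP group id, not by rank.
    world_size, sp_size, rank = 8, 4, 5
    num_sp_groups = world_size // sp_size  # 2 groups
    sp_group_id = rank // sp_size          # rank 5 -> SP group 1

    base_sampler = DistributedSampler(
        dataset,
        num_replicas=num_sp_groups,
        rank=sp_group_id,
        shuffle=True,
        seed=42,
        drop_last=True,
    )

    # Packing then wraps the SP-aware sampler, mirroring
    # _create_multipack_sampler() above (parameter values are illustrative).
    batch_sampler = MultipackBatchSampler(
        base_sampler,
        lengths=lengths,
        packing_efficiency_estimate=1.0,
        batch_max_len=4096,
        batch_size=1,
        group_size=100_000,
        bin_size=200,
        drop_last=True,
    )

    # Iterating yields groups of dataset indices that pack together. All
    # ranks in SP group 1 build this identical sampler, so they receive
    # identical batches to split along the sequence dimension.
    first_batch = next(iter(batch_sampler))

Because every rank in an SP group must receive the same batches, the train
dataloader is intentionally returned without accelerate preparation when
sequence_parallel_size > 1; re-sharding it per rank would break that
invariant.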