diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 891f154ec..0b969c3aa 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -413,13 +413,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): if self.args.curriculum_sampling: sampler = SequentialSampler(self.train_dataset) else: - generator = None - if self.args.sequence_parallel_size > 1: - generator = torch.Generator() - generator.manual_seed(self.args.seed) - sampler = RandomSampler( - self.train_dataset, generator=generator - ) + sampler = RandomSampler(self.train_dataset) return MultipackBatchSampler( sampler, @@ -434,13 +428,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): if self.args.curriculum_sampling: return SequentialSampler(self.train_dataset) - sampler = super()._get_train_sampler() - if self.args.sequence_parallel_size > 1: - generator = torch.Generator() - generator.manual_seed(self.args.seed) - sampler.generator = generator - - return sampler + return super()._get_train_sampler() def _get_eval_sampler( self, eval_dataset: Dataset @@ -498,28 +486,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): self.accelerator.even_batches = False if self.args.sequence_parallel_size > 1: return DataLoader(train_dataset, **dataloader_params) - else: - return self.accelerator.prepare_data_loader( - DataLoader(train_dataset, **dataloader_params) - ) - else: - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + return self.accelerator.prepare_data_loader( + DataLoader(train_dataset, **dataloader_params) + ) - if self.args.sequence_parallel_size > 1: - return DataLoader(train_dataset, **dataloader_params) - else: - return self.accelerator.prepare( - DataLoader(train_dataset, **dataloader_params) - ) + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns( + train_dataset, description="training" + ) + else: + data_collator = self._get_collator_with_removed_columns( + data_collator, description="training" + ) + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + if self.args.sequence_parallel_size > 1: + return DataLoader(train_dataset, **dataloader_params) + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: if self.args.sample_packing and self.args.eval_sample_packing is False: @@ -842,7 +832,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], - num_items_in_batch: int = None, + num_items_in_batch: int | None = None, ) -> torch.Tensor: """ Perform a training step on a batch of inputs. diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index ee4679230..eaf47c751 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -218,7 +218,9 @@ class DataCollatorForSeq2Seq: f"Non-padding tokens: {non_pad_tokens_total}" ) logger.info(f"GPU {self.rank} token IDs: {batch['input_ids']}") - logger.info(f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}") + logger.info( + f"GPU {self.rank} start_ids:end_idx: {start_idx}:{end_idx}" + ) batch[key] = batch[key][:, start_idx:end_idx] diff --git a/tests/e2e/patched/test_sequence_parallelism.py b/tests/e2e/patched/test_sequence_parallelism.py index a3483122b..600debbbb 100644 --- a/tests/e2e/patched/test_sequence_parallelism.py +++ b/tests/e2e/patched/test_sequence_parallelism.py @@ -11,10 +11,7 @@ from accelerate.state import PartialState ring_flash_attn_mock = MagicMock() with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}): from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group - from axolotl.utils.collators.sequence_parallel import ( - adjust_position_ids_for_slice, - check_for_boundary_splits, - ) + from axolotl.utils.collators.batching import adjust_position_ids_for_slice # Create a fixture for PartialState @@ -52,37 +49,6 @@ class TestSequenceParallelHelpers: assert torch.all(adjusted[0] == expected_first_seq) assert torch.all(adjusted[1] == expected_second_seq) - def test_check_for_boundary_splits(self): - """Test detection of boundaries near slice edges.""" - # Boundaries at positions 10, 25, 40 - boundaries = [10, 25, 40] - - # Test case where two boundaries are near edges (one at start, one at end) - problems = check_for_boundary_splits(boundaries, slice_start=8, slice_end=30) - assert ( - len(problems) == 2 - ) # Both boundary at 10 (near start) and 25 (near end) are problems - - # Check first problem - boundary near start - assert problems[0][0] == 10 # The boundary position - assert problems[0][1] == "start" # Type of issue - assert problems[0][2] == 2 # Distance from start - - # Check second problem - boundary near end - assert problems[1][0] == 25 # The boundary position - assert problems[1][1] == "end" # Type of issue - assert problems[1][2] == 5 # Distance from end - - # Test case with only one problem at the end - problems = check_for_boundary_splits(boundaries, slice_start=15, slice_end=27) - assert len(problems) == 1 # Only boundary at 25 is near the end - assert problems[0][0] == 25 # The boundary - assert problems[0][1] == "end" # Type of issue - - # Test case with no problems - problems = check_for_boundary_splits(boundaries, slice_start=12, slice_end=20) - assert len(problems) == 0 - class TestRingAttention: """Tests for the ring attention functionality."""