From 4ff97bc9d439246fead7d8d7e32027079efcb448 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 13 Mar 2025 19:24:30 +0000
Subject: [PATCH] eval dataloader and sampler changes

---
 src/axolotl/core/trainers/base.py            | 145 ++++++++++++++----
 .../e2e/patched/test_sequence_parallelism.py |  81 +++++++---
 2 files changed, 173 insertions(+), 53 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 4d10e206b..63195a55a 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -398,7 +398,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         )
         return super()._wrap_model(model, training=training, dataloader=dataloader)
 
-    def _create_multipack_sampler(self, base_sampler):
+    def _create_multipack_sampler(self, base_sampler, dataset, group_size):
         """Helper method to create a MultipackBatchSampler"""
         if self.args.multipack_real_batches:
             batch_size = self.args.per_device_train_batch_size
@@ -412,11 +412,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
 
         return MultipackBatchSampler(
             base_sampler,
-            lengths=get_dataset_lengths(self.train_dataset),
+            lengths=get_dataset_lengths(dataset),
             packing_efficiency_estimate=self.args.sample_packing_efficiency,
             batch_max_len=batch_max_len,
             batch_size=batch_size,
-            group_size=self.args.sample_packing_group_size,
+            group_size=group_size,
             bin_size=self.args.sample_packing_bin_size,
             drop_last=True,
         )
@@ -439,18 +439,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
 
             # Apply multipack wrapper if needed
             if self.args.sample_packing and not self.args.pretraining:
-                return self._create_multipack_sampler(base_sampler)
+                return self._create_multipack_sampler(
+                    base_sampler=base_sampler,
+                    dataset=self.train_dataset,
+                    group_size=self.args.sample_packing_group_size,
+                )
 
             return base_sampler
 
-        # Handle non-SP mode
         if self.args.sample_packing and not self.args.pretraining:
-            sampler = (
+            base_sampler = (
                 SequentialSampler(self.train_dataset)
                 if self.args.curriculum_sampling
                 else RandomSampler(self.train_dataset)
            )
-            return self._create_multipack_sampler(sampler)
+
+            return self._create_multipack_sampler(
+                base_sampler=base_sampler,
+                dataset=self.train_dataset,
+                group_size=self.args.sample_packing_group_size,
+            )
 
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
@@ -458,27 +466,55 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return super()._get_train_sampler()
 
     def _get_eval_sampler(
-        self, eval_dataset: Dataset
+        self, eval_dataset: Optional[Dataset] = None
     ) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing and self.args.eval_sample_packing is not False:
-            if self.args.multipack_real_batches:
-                batch_size = self.args.per_device_eval_batch_size
-                batch_max_len = self.args.max_seq_length
-            else:
-                batch_size = 1
-                batch_max_len = (
-                    self.args.per_device_eval_batch_size * self.args.max_seq_length
-                )
-            return MultipackBatchSampler(
-                SequentialSampler(eval_dataset),
-                lengths=get_dataset_lengths(self.eval_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                batch_max_len=batch_max_len,
-                batch_size=batch_size,
-                group_size=self.args.sample_packing_group_size,
-                bin_size=self.args.sample_packing_bin_size,
-                drop_last=True,
+        """Get evaluation sampler"""
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+        # Handle sequence parallelism
+        if self.args.sequence_parallel_size > 1:
+            # Create sampler for SP groups
+            num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
+            sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
+
+            # Create distributed sampler for the SP group
+            base_sampler = torch.utils.data.distributed.DistributedSampler(
+                eval_dataset,
+                num_replicas=num_sp_groups,
+                rank=sp_group_id,
+                shuffle=False,
+                drop_last=False,
             )
+
+            if self.args.sample_packing and self.args.eval_sample_packing is not False:
+                group_size = (
+                    self.args.eval_packing_group_size
+                    if hasattr(self.args, "eval_packing_group_size")
+                    else self.args.sample_packing_group_size
+                )
+
+                return self._create_multipack_sampler(
+                    base_sampler=base_sampler,
+                    dataset=eval_dataset,
+                    group_size=group_size,
+                )
+
+            return base_sampler
+
+        if self.args.sample_packing and self.args.eval_sample_packing is not False:
+            base_sampler = SequentialSampler(eval_dataset)
+            group_size = (
+                self.args.eval_packing_group_size
+                if hasattr(self.args, "eval_packing_group_size")
+                else self.args.sample_packing_group_size
+            )
+
+            return self._create_multipack_sampler(
+                base_sampler=base_sampler,
+                dataset=eval_dataset,
+                group_size=group_size,
+            )
+
         return super()._get_eval_sampler(eval_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
@@ -546,25 +582,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         return self.accelerator.prepare_data_loader(dataloader)
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+        """Get dataloader for evaluation"""
         if self.args.sample_packing and self.args.eval_sample_packing is False:
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.eval_data_collator
             )
-            if eval_dataset:
+            if eval_dataset and "length" in eval_dataset.features:
                 eval_dataset = eval_dataset.remove_columns(["length"])
             dataloader = super().get_eval_dataloader(eval_dataset)
             self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                 self.train_data_collator
             )
-            return dataloader
+            return dataloader
 
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            eval_dataset = eval_dataset.remove_columns(["length"])
+
+            # Only remove length column if it exists
+            if "length" in eval_dataset.features:
+                eval_dataset = eval_dataset.remove_columns(["length"])
+
             data_collator = self.data_collator
             dataloader_params = {
                 "batch_size": self.args.eval_batch_size,
@@ -572,6 +613,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 "num_workers": self.args.dataloader_num_workers,
                 "pin_memory": self.args.dataloader_pin_memory,
             }
+
             if self.args.dataloader_prefetch_factor:
                 dataloader_params["prefetch_factor"] = (
                     self.args.dataloader_prefetch_factor
@@ -585,9 +627,50 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
             self.accelerator.even_batches = False
-            return self.accelerator.prepare_data_loader(
-                DataLoader(eval_dataset, **dataloader_params)
+
+            # Create dataloader
+            dataloader = DataLoader(eval_dataset, **dataloader_params)
+
+            # Don't prepare dataloader for sequence parallelism
+            if self.args.sequence_parallel_size > 1:
+                return dataloader
+
+            return self.accelerator.prepare_data_loader(dataloader)
+
+        if self.args.sequence_parallel_size > 1:
+            # We need to customize the default dataloader for sequence parallelism
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
             )
+            data_collator = (
+                self.eval_data_collator
+                if self.eval_data_collator
+                else self.data_collator
+            )
+
+            # Handle dataset preprocessing as in the parent implementation
+            if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+                eval_dataset = self._remove_unused_columns(
+                    eval_dataset, description="evaluation"
+                )
+            else:
+                data_collator = self._get_collator_with_removed_columns(
+                    data_collator, description="evaluation"
+                )
+
+            # Build dataloader parameters
+            dataloader_params = {
+                "batch_size": self.args.per_device_eval_batch_size,
+                "collate_fn": data_collator,
+                "num_workers": self.args.dataloader_num_workers,
+                "pin_memory": self.args.dataloader_pin_memory,
+            }
+
+            if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
+                sampler = self._get_eval_sampler(eval_dataset)
+                dataloader_params["sampler"] = sampler
+
+            # Create dataloader without accelerator preparation for sequence parallelism
+            return DataLoader(eval_dataset, **dataloader_params)
 
         return super().get_eval_dataloader(eval_dataset)
 
diff --git a/tests/e2e/patched/test_sequence_parallelism.py b/tests/e2e/patched/test_sequence_parallelism.py
index 600debbbb..126af595f 100644
--- a/tests/e2e/patched/test_sequence_parallelism.py
+++ b/tests/e2e/patched/test_sequence_parallelism.py
@@ -7,6 +7,8 @@ import pytest
 import torch
 from accelerate.state import PartialState
 
+from axolotl.utils.dict import DictDefault
+
 # Use a single patch for ring_flash_attn if it's not available
 ring_flash_attn_mock = MagicMock()
 with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
@@ -14,15 +16,38 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
     from axolotl.utils.collators.batching import adjust_position_ids_for_slice
 
 
-# Create a fixture for PartialState
 @pytest.fixture
 def partial_state():
     """Create a real PartialState instance for testing."""
-    # This initializes a PartialState for a non-distributed environment
     state = PartialState()
     return state
 
 
+@pytest.fixture(name="cfg")
+def fixture_cfg():
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                },
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "learning_rate": 1e-3,
+            "output_dir": "./model-out",
+            "sequence_len": 512,
+            "special_tokens": {
+                "pad_token": "<|endoftext|>",
+            },
+        }
+    )
+
+    return cfg
+
+
 class TestSequenceParallelHelpers:
     """Test helper functions used in sequence parallelism."""
 
@@ -95,7 +120,7 @@ class TestRingAttention:
 
 
 # Mock a simplified DataCollator test
-@patch("axolotl.utils.collators.sequence_parallel.get_ring_attn_group")
+@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
 @patch("torch.distributed.get_rank")
 @patch("torch.distributed.get_world_size")
 def test_sequence_parallel_slicing(
@@ -145,24 +170,36 @@ def test_sequence_parallel_slicing(
     assert torch.all(result["input_ids"] == expected_input_ids)
 
 
-# Simple test for configuration validation
-@pytest.mark.parametrize(
-    "config,should_validate",
-    [
-        ({"sequence_parallel_size": 2, "flash_attention": True}, True),
-        ({"sequence_parallel_size": 2, "flash_attention": False}, False),
-        ({"sequence_parallel_size": 1, "flash_attention": False}, True),
-    ],
-)
-def test_sequence_parallel_config_requirements(config, should_validate):
-    """Test basic sequence parallelism configuration requirements."""
+def test_config_validation_with_valid_inputs(cfg):
+    """Test that valid sequence parallelism configurations pass validation."""
+    # Import the actual model class with appropriate mocks
+    from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
-    # Simple validation function that mimics the actual validator
-    def validate_sp_config(config):
-        if config.get("sequence_parallel_size", 1) > 1 and not config.get(
-            "flash_attention", False
-        ):
-            return False
-        return True
+    # Valid configuration: sequence_parallel_size > 1 and flash_attention is True
+    cfg = cfg | {
+        "sequence_parallel_size": 2,
+        "flash_attention": True,
+    }
 
-    assert validate_sp_config(config) == should_validate
+    # Should validate without errors
+    config = AxolotlInputConfig(**cfg)
+    assert config.sequence_parallel_size == 2
+    assert config.flash_attention is True
+
+
+def test_config_validation_with_invalid_inputs(cfg):
+    """Test that invalid sequence parallelism configurations fail validation."""
+    from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
+
+    # Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
+    cfg = cfg | {
+        "sequence_parallel_size": 2,
+        "flash_attention": False,
+    }
+
+    # Should raise ValidationError
+    with pytest.raises(ValueError) as excinfo:
+        AxolotlInputConfig(**cfg)
+
+    # Verify error message
+    assert "flash_attention: true must be set" in str(excinfo.value)
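
Illustration (not part of the patch): a minimal standalone sketch of the rank-to-SP-group mapping that the new _get_eval_sampler builds on. The world size and sequence-parallel size below are made-up example values, and DistributedSampler is given num_replicas and rank explicitly, so no process group needs to be initialized.

# sp_eval_sharding_sketch.py -- hypothetical example, not shipped with this patch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler


class ToyEvalDataset(Dataset):
    """Ten dummy samples standing in for a tokenized eval set."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx


world_size = 8              # assumed total rank count (illustrative)
sequence_parallel_size = 2  # assumed SP degree (illustrative)
num_sp_groups = world_size // sequence_parallel_size

for rank in range(world_size):
    # Same mapping as _get_eval_sampler: every rank in an SP group gets the same shard
    sp_group_id = rank // sequence_parallel_size
    sampler = DistributedSampler(
        ToyEvalDataset(),
        num_replicas=num_sp_groups,
        rank=sp_group_id,
        shuffle=False,
        drop_last=False,
    )
    print(f"rank={rank} sp_group={sp_group_id} indices={list(sampler)}")

Ranks 0 and 1 print identical index lists, as do ranks 2 and 3, and so on: each SP group consumes the same eval batches, which the sequence-parallel collator can then slice across the ranks of the group (the behavior exercised by test_sequence_parallel_slicing).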