eval dataloader and sampler changes

This commit is contained in:
Dan Saunders
2025-03-13 19:24:30 +00:00
parent d0e178d52f
commit 4ff97bc9d4
2 changed files with 173 additions and 53 deletions

View File

@@ -398,7 +398,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
) )
return super()._wrap_model(model, training=training, dataloader=dataloader) return super()._wrap_model(model, training=training, dataloader=dataloader)
def _create_multipack_sampler(self, base_sampler): def _create_multipack_sampler(self, base_sampler, dataset, group_size):
"""Helper method to create a MultipackBatchSampler""" """Helper method to create a MultipackBatchSampler"""
if self.args.multipack_real_batches: if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size batch_size = self.args.per_device_train_batch_size
@@ -412,11 +412,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return MultipackBatchSampler( return MultipackBatchSampler(
base_sampler, base_sampler,
lengths=get_dataset_lengths(self.train_dataset), lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len, batch_max_len=batch_max_len,
batch_size=batch_size, batch_size=batch_size,
group_size=self.args.sample_packing_group_size, group_size=group_size,
bin_size=self.args.sample_packing_bin_size, bin_size=self.args.sample_packing_bin_size,
drop_last=True, drop_last=True,
) )
@@ -439,18 +439,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
# Apply multipack wrapper if needed # Apply multipack wrapper if needed
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
return self._create_multipack_sampler(base_sampler) return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
group_size=self.args.sample_packing_group_size,
)
return base_sampler return base_sampler
# Handle non-SP mode
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
sampler = ( base_sampler = (
SequentialSampler(self.train_dataset) SequentialSampler(self.train_dataset)
if self.args.curriculum_sampling if self.args.curriculum_sampling
else RandomSampler(self.train_dataset) else RandomSampler(self.train_dataset)
) )
return self._create_multipack_sampler(sampler)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
group_size=self.args.sample_packing_group_size,
)
if self.args.curriculum_sampling: if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset) return SequentialSampler(self.train_dataset)
@@ -458,27 +466,55 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return super()._get_train_sampler() return super()._get_train_sampler()
def _get_eval_sampler( def _get_eval_sampler(
self, eval_dataset: Dataset self, eval_dataset: Optional[Dataset] = None
) -> Optional[torch.utils.data.Sampler]: ) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False: """Get evaluation sampler"""
if self.args.multipack_real_batches: eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length # Handle sequence parallelism
else: if self.args.sequence_parallel_size > 1:
batch_size = 1 # Create sampler for SP groups
batch_max_len = ( num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
self.args.per_device_eval_batch_size * self.args.max_seq_length sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
)
return MultipackBatchSampler( # Create distributed sampler for the SP group
SequentialSampler(eval_dataset), base_sampler = torch.utils.data.distributed.DistributedSampler(
lengths=get_dataset_lengths(self.eval_dataset), eval_dataset,
packing_efficiency_estimate=self.args.sample_packing_efficiency, num_replicas=num_sp_groups,
batch_max_len=batch_max_len, rank=sp_group_id,
batch_size=batch_size, shuffle=False,
group_size=self.args.sample_packing_group_size, drop_last=False,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
) )
if self.args.sample_packing and self.args.eval_sample_packing is not False:
group_size = (
self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size
)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=group_size,
)
return base_sampler
if self.args.sample_packing and self.args.eval_sample_packing is not False:
base_sampler = SequentialSampler(eval_dataset)
group_size = (
self.args.eval_packing_group_size
if hasattr(self.args, "eval_packing_group_size")
else self.args.sample_packing_group_size
)
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
group_size=group_size,
)
return super()._get_eval_sampler(eval_dataset) return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
@@ -546,25 +582,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
return self.accelerator.prepare_data_loader(dataloader) return self.accelerator.prepare_data_loader(dataloader)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
"""Get dataloader for evaluation"""
if self.args.sample_packing and self.args.eval_sample_packing is False: if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator self.eval_data_collator
) )
if eval_dataset: if eval_dataset and "length" in eval_dataset.features:
eval_dataset = eval_dataset.remove_columns(["length"]) eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset) dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator self.train_data_collator
) )
return dataloader
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False: if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = ( eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset if eval_dataset is not None else self.eval_dataset
) )
eval_sampler = self._get_eval_sampler(eval_dataset) eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
# Only remove length column if it exists
if "length" in eval_dataset.features:
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator data_collator = self.data_collator
dataloader_params = { dataloader_params = {
"batch_size": self.args.eval_batch_size, "batch_size": self.args.eval_batch_size,
@@ -572,6 +613,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"num_workers": self.args.dataloader_num_workers, "num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory, "pin_memory": self.args.dataloader_pin_memory,
} }
if self.args.dataloader_prefetch_factor: if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = ( dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor self.args.dataloader_prefetch_factor
@@ -585,9 +627,50 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params) # Create dataloader
dataloader = DataLoader(eval_dataset, **dataloader_params)
# Don't prepare dataloader for sequence parallelism
if self.args.sequence_parallel_size > 1:
return dataloader
return self.accelerator.prepare_data_loader(dataloader)
if self.args.sequence_parallel_size > 1:
# We need to customize the default dataloader for sequence parallelism
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
) )
data_collator = (
self.eval_data_collator
if self.eval_data_collator
else self.data_collator
)
# Handle dataset preprocessing as in the parent implementation
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="evaluation"
)
# Build dataloader parameters
dataloader_params = {
"batch_size": self.args.per_device_eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
sampler = self._get_eval_sampler(eval_dataset)
dataloader_params["sampler"] = sampler
# Create dataloader without accelerator preparation for sequence parallelism
return DataLoader(eval_dataset, **dataloader_params)
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)

View File

@@ -7,6 +7,8 @@ import pytest
import torch import torch
from accelerate.state import PartialState from accelerate.state import PartialState
from axolotl.utils.dict import DictDefault
# Use a single patch for ring_flash_attn if it's not available # Use a single patch for ring_flash_attn if it's not available
ring_flash_attn_mock = MagicMock() ring_flash_attn_mock = MagicMock()
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}): with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
@@ -14,15 +16,38 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
from axolotl.utils.collators.batching import adjust_position_ids_for_slice from axolotl.utils.collators.batching import adjust_position_ids_for_slice
# Create a fixture for PartialState
@pytest.fixture @pytest.fixture
def partial_state(): def partial_state():
"""Create a real PartialState instance for testing.""" """Create a real PartialState instance for testing."""
# This initializes a PartialState for a non-distributed environment
state = PartialState() state = PartialState()
return state return state
@pytest.fixture(name="cfg")
def fixture_cfg():
    """Build the base configuration handed to tests that request ``cfg``."""
    # Single dataset entry, declared up front to keep the settings dict flat.
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "datasets": [alpaca_dataset],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-3,
        "output_dir": "./model-out",
        "sequence_len": 512,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
    }
    # Wrap in DictDefault, the same container type the original fixture returns.
    return DictDefault(settings)
class TestSequenceParallelHelpers: class TestSequenceParallelHelpers:
"""Test helper functions used in sequence parallelism.""" """Test helper functions used in sequence parallelism."""
@@ -95,7 +120,7 @@ class TestRingAttention:
# Mock a simplified DataCollator test # Mock a simplified DataCollator test
@patch("axolotl.utils.collators.sequence_parallel.get_ring_attn_group") @patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank") @patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size") @patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing( def test_sequence_parallel_slicing(
@@ -145,24 +170,36 @@ def test_sequence_parallel_slicing(
assert torch.all(result["input_ids"] == expected_input_ids) assert torch.all(result["input_ids"] == expected_input_ids)
# Simple test for configuration validation def test_config_validation_with_valid_inputs(cfg):
@pytest.mark.parametrize( """Test that valid sequence parallelism configurations pass validation."""
"config,should_validate", # Import the actual model class with appropriate mocks
[ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
({"sequence_parallel_size": 2, "flash_attention": True}, True),
({"sequence_parallel_size": 2, "flash_attention": False}, False),
({"sequence_parallel_size": 1, "flash_attention": False}, True),
],
)
def test_sequence_parallel_config_requirements(config, should_validate):
"""Test basic sequence parallelism configuration requirements."""
# Simple validation function that mimics the actual validator # Valid configuration: sequence_parallel_size > 1 and flash_attention is True
def validate_sp_config(config): cfg = cfg | {
if config.get("sequence_parallel_size", 1) > 1 and not config.get( "sequence_parallel_size": 2,
"flash_attention", False "flash_attention": True,
): }
return False
return True
assert validate_sp_config(config) == should_validate # Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_size == 2
assert config.flash_attention is True
def test_config_validation_with_invalid_inputs(cfg):
    """Check that sequence parallelism without flash attention is rejected."""
    from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

    # The combination under test: sequence_parallel_size > 1 while
    # flash_attention is switched off.
    overrides = {
        "sequence_parallel_size": 2,
        "flash_attention": False,
    }
    invalid_cfg = cfg | overrides

    # Building the config model is expected to raise a ValueError.
    with pytest.raises(ValueError) as raised:
        AxolotlInputConfig(**invalid_cfg)

    # The raised error must mention the flash-attention requirement.
    message = str(raised.value)
    assert "flash_attention: true must be set" in message