eval dataloader and sampler changes
This commit is contained in:
@@ -398,7 +398,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _create_multipack_sampler(self, base_sampler):
|
def _create_multipack_sampler(self, base_sampler, dataset, group_size):
|
||||||
"""Helper method to create a MultipackBatchSampler"""
|
"""Helper method to create a MultipackBatchSampler"""
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
batch_size = self.args.per_device_train_batch_size
|
batch_size = self.args.per_device_train_batch_size
|
||||||
@@ -412,11 +412,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
base_sampler,
|
base_sampler,
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
group_size=self.args.sample_packing_group_size,
|
group_size=group_size,
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
@@ -439,18 +439,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
# Apply multipack wrapper if needed
|
# Apply multipack wrapper if needed
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
return self._create_multipack_sampler(base_sampler)
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=self.train_dataset,
|
||||||
|
group_size=self.args.sample_packing_group_size,
|
||||||
|
)
|
||||||
|
|
||||||
return base_sampler
|
return base_sampler
|
||||||
|
|
||||||
# Handle non-SP mode
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
sampler = (
|
base_sampler = (
|
||||||
SequentialSampler(self.train_dataset)
|
SequentialSampler(self.train_dataset)
|
||||||
if self.args.curriculum_sampling
|
if self.args.curriculum_sampling
|
||||||
else RandomSampler(self.train_dataset)
|
else RandomSampler(self.train_dataset)
|
||||||
)
|
)
|
||||||
return self._create_multipack_sampler(sampler)
|
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=self.train_dataset,
|
||||||
|
group_size=self.args.sample_packing_group_size,
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
return SequentialSampler(self.train_dataset)
|
return SequentialSampler(self.train_dataset)
|
||||||
@@ -458,27 +466,55 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def _get_eval_sampler(
|
def _get_eval_sampler(
|
||||||
self, eval_dataset: Dataset
|
self, eval_dataset: Optional[Dataset] = None
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
"""Get evaluation sampler"""
|
||||||
if self.args.multipack_real_batches:
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
batch_size = self.args.per_device_eval_batch_size
|
|
||||||
batch_max_len = self.args.max_seq_length
|
# Handle sequence parallelism
|
||||||
else:
|
if self.args.sequence_parallel_size > 1:
|
||||||
batch_size = 1
|
# Create sampler for SP groups
|
||||||
batch_max_len = (
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
|
||||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
|
||||||
)
|
|
||||||
return MultipackBatchSampler(
|
# Create distributed sampler for the SP group
|
||||||
SequentialSampler(eval_dataset),
|
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||||
lengths=get_dataset_lengths(self.eval_dataset),
|
eval_dataset,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
num_replicas=num_sp_groups,
|
||||||
batch_max_len=batch_max_len,
|
rank=sp_group_id,
|
||||||
batch_size=batch_size,
|
shuffle=False,
|
||||||
group_size=self.args.sample_packing_group_size,
|
drop_last=False,
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
|
group_size = (
|
||||||
|
self.args.eval_packing_group_size
|
||||||
|
if hasattr(self.args, "eval_packing_group_size")
|
||||||
|
else self.args.sample_packing_group_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=eval_dataset,
|
||||||
|
group_size=group_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return base_sampler
|
||||||
|
|
||||||
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
|
group_size = (
|
||||||
|
self.args.eval_packing_group_size
|
||||||
|
if hasattr(self.args, "eval_packing_group_size")
|
||||||
|
else self.args.sample_packing_group_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=eval_dataset,
|
||||||
|
group_size=group_size,
|
||||||
|
)
|
||||||
|
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
@@ -546,25 +582,30 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
|
"""Get dataloader for evaluation"""
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if eval_dataset:
|
if eval_dataset and "length" in eval_dataset.features:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
return dataloader
|
|
||||||
|
|
||||||
|
return dataloader
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
|
||||||
|
# Only remove length column if it exists
|
||||||
|
if "length" in eval_dataset.features:
|
||||||
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
|
||||||
data_collator = self.data_collator
|
data_collator = self.data_collator
|
||||||
dataloader_params = {
|
dataloader_params = {
|
||||||
"batch_size": self.args.eval_batch_size,
|
"batch_size": self.args.eval_batch_size,
|
||||||
@@ -572,6 +613,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
"num_workers": self.args.dataloader_num_workers,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
if self.args.dataloader_prefetch_factor:
|
||||||
dataloader_params["prefetch_factor"] = (
|
dataloader_params["prefetch_factor"] = (
|
||||||
self.args.dataloader_prefetch_factor
|
self.args.dataloader_prefetch_factor
|
||||||
@@ -585,9 +627,50 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
DataLoader(eval_dataset, **dataloader_params)
|
# Create dataloader
|
||||||
|
dataloader = DataLoader(eval_dataset, **dataloader_params)
|
||||||
|
|
||||||
|
# Don't prepare dataloader for sequence parallelism
|
||||||
|
if self.args.sequence_parallel_size > 1:
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
if self.args.sequence_parallel_size > 1:
|
||||||
|
# We need to customize the default dataloader for sequence parallelism
|
||||||
|
eval_dataset = (
|
||||||
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
)
|
||||||
|
data_collator = (
|
||||||
|
self.eval_data_collator
|
||||||
|
if self.eval_data_collator
|
||||||
|
else self.data_collator
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle dataset preprocessing as in the parent implementation
|
||||||
|
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||||
|
eval_dataset = self._remove_unused_columns(
|
||||||
|
eval_dataset, description="evaluation"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(
|
||||||
|
data_collator, description="evaluation"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build dataloader parameters
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_size": self.args.per_device_eval_batch_size,
|
||||||
|
"collate_fn": data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
||||||
|
sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
dataloader_params["sampler"] = sampler
|
||||||
|
|
||||||
|
# Create dataloader without accelerator preparation for sequence parallelism
|
||||||
|
return DataLoader(eval_dataset, **dataloader_params)
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from accelerate.state import PartialState
|
from accelerate.state import PartialState
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
# Use a single patch for ring_flash_attn if it's not available
|
# Use a single patch for ring_flash_attn if it's not available
|
||||||
ring_flash_attn_mock = MagicMock()
|
ring_flash_attn_mock = MagicMock()
|
||||||
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||||
@@ -14,15 +16,38 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
|||||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||||
|
|
||||||
|
|
||||||
# Create a fixture for PartialState
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def partial_state():
|
def partial_state():
|
||||||
"""Create a real PartialState instance for testing."""
|
"""Create a real PartialState instance for testing."""
|
||||||
# This initializes a PartialState for a non-distributed environment
|
|
||||||
state = PartialState()
|
state = PartialState()
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="cfg")
|
||||||
|
def fixture_cfg():
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 1e-3,
|
||||||
|
"output_dir": "./model-out",
|
||||||
|
"sequence_len": 512,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
class TestSequenceParallelHelpers:
|
class TestSequenceParallelHelpers:
|
||||||
"""Test helper functions used in sequence parallelism."""
|
"""Test helper functions used in sequence parallelism."""
|
||||||
|
|
||||||
@@ -95,7 +120,7 @@ class TestRingAttention:
|
|||||||
|
|
||||||
|
|
||||||
# Mock a simplified DataCollator test
|
# Mock a simplified DataCollator test
|
||||||
@patch("axolotl.utils.collators.sequence_parallel.get_ring_attn_group")
|
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
|
||||||
@patch("torch.distributed.get_rank")
|
@patch("torch.distributed.get_rank")
|
||||||
@patch("torch.distributed.get_world_size")
|
@patch("torch.distributed.get_world_size")
|
||||||
def test_sequence_parallel_slicing(
|
def test_sequence_parallel_slicing(
|
||||||
@@ -145,24 +170,36 @@ def test_sequence_parallel_slicing(
|
|||||||
assert torch.all(result["input_ids"] == expected_input_ids)
|
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||||
|
|
||||||
|
|
||||||
# Simple test for configuration validation
|
def test_config_validation_with_valid_inputs(cfg):
|
||||||
@pytest.mark.parametrize(
|
"""Test that valid sequence parallelism configurations pass validation."""
|
||||||
"config,should_validate",
|
# Import the actual model class with appropriate mocks
|
||||||
[
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
({"sequence_parallel_size": 2, "flash_attention": True}, True),
|
|
||||||
({"sequence_parallel_size": 2, "flash_attention": False}, False),
|
|
||||||
({"sequence_parallel_size": 1, "flash_attention": False}, True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sequence_parallel_config_requirements(config, should_validate):
|
|
||||||
"""Test basic sequence parallelism configuration requirements."""
|
|
||||||
|
|
||||||
# Simple validation function that mimics the actual validator
|
# Valid configuration: sequence_parallel_size > 1 and flash_attention is True
|
||||||
def validate_sp_config(config):
|
cfg = cfg | {
|
||||||
if config.get("sequence_parallel_size", 1) > 1 and not config.get(
|
"sequence_parallel_size": 2,
|
||||||
"flash_attention", False
|
"flash_attention": True,
|
||||||
):
|
}
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
assert validate_sp_config(config) == should_validate
|
# Should validate without errors
|
||||||
|
config = AxolotlInputConfig(**cfg)
|
||||||
|
assert config.sequence_parallel_size == 2
|
||||||
|
assert config.flash_attention is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_validation_with_invalid_inputs(cfg):
|
||||||
|
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
|
# Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
|
||||||
|
cfg = cfg | {
|
||||||
|
"sequence_parallel_size": 2,
|
||||||
|
"flash_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should raise ValidationError
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
AxolotlInputConfig(**cfg)
|
||||||
|
|
||||||
|
# Verify error message
|
||||||
|
assert "flash_attention: true must be set" in str(excinfo.value)
|
||||||
|
|||||||
Reference in New Issue
Block a user