eval dataloader and sampler changes
@@ -7,6 +7,8 @@ import pytest
 import torch
+from accelerate.state import PartialState
 
+from axolotl.utils.dict import DictDefault
 
 # Use a single patch for ring_flash_attn if it's not available
 ring_flash_attn_mock = MagicMock()
 
 with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
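Aside on the pattern above: inserting a MagicMock into sys.modules is the standard way to let a module import succeed when an optional dependency may be absent. A minimal, self-contained sketch of the same trick (the package name here is hypothetical):

    from unittest.mock import MagicMock, patch

    fake_pkg = MagicMock()
    with patch.dict("sys.modules", {"optional_pkg": fake_pkg}):
        # Python consults sys.modules before searching for a real package,
        # so this import resolves to the mock even if optional_pkg
        # is not installed.
        import optional_pkg

        assert optional_pkg is fake_pkg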
@@ -14,15 +16,38 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
     from axolotl.utils.collators.batching import adjust_position_ids_for_slice
 
 
+# Create a fixture for PartialState
+@pytest.fixture
+def partial_state():
+    """Create a real PartialState instance for testing."""
+    # This initializes a PartialState for a non-distributed environment
+    state = PartialState()
+    return state
+
+
+@pytest.fixture(name="cfg")
+def fixture_cfg():
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                },
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "learning_rate": 1e-3,
+            "output_dir": "./model-out",
+            "sequence_len": 512,
+            "special_tokens": {
+                "pad_token": "<|endoftext|>",
+            },
+        }
+    )
+
+    return cfg
+
+
 class TestSequenceParallelHelpers:
     """Test helper functions used in sequence parallelism."""
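The cfg fixture above is a plain DictDefault that individual tests specialize with the | merge operator. A small sketch of that composition, assuming DictDefault merges like a built-in dict (right-hand keys win):

    from axolotl.utils.dict import DictDefault

    base = DictDefault({"sequence_len": 512, "flash_attention": False})
    merged = base | {"flash_attention": True}

    # Right-hand keys win and untouched keys carry over, which is how the
    # tests below build per-case configs from the shared fixture.
    assert merged["flash_attention"] is True
    assert merged["sequence_len"] == 512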
@@ -95,7 +120,7 @@ class TestRingAttention:
 
 
     # Mock a simplified DataCollator test
-    @patch("axolotl.utils.collators.sequence_parallel.get_ring_attn_group")
+    @patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
     @patch("torch.distributed.get_rank")
     @patch("torch.distributed.get_world_size")
     def test_sequence_parallel_slicing(
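The retargeted @patch above follows the usual unittest.mock rule: patch a name where it is looked up, not where it is defined. A self-contained demonstration with hypothetical module names:

    import sys
    import types
    from unittest.mock import patch

    # Two in-memory modules: helpers_demo defines get_group, and collator_demo
    # binds it by value (what `from helpers_demo import get_group` does).
    helpers = types.ModuleType("helpers_demo")
    helpers.get_group = lambda: "real"
    sys.modules["helpers_demo"] = helpers

    collator = types.ModuleType("collator_demo")
    collator.get_group = helpers.get_group
    sys.modules["collator_demo"] = collator

    # Patching the definition site does not affect the consumer's bound name:
    with patch("helpers_demo.get_group", return_value="mocked"):
        assert collator.get_group() == "real"

    # Patching where the name is looked up does:
    with patch("collator_demo.get_group", return_value="mocked"):
        assert collator.get_group() == "mocked"

Note also that stacked @patch decorators apply bottom-up, so the corresponding mock arguments arrive in the test signature in reverse order of the decorators.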
@@ -145,24 +170,36 @@ def test_sequence_parallel_slicing(
         assert torch.all(result["input_ids"] == expected_input_ids)
 
 
-# Simple test for configuration validation
-@pytest.mark.parametrize(
-    "config,should_validate",
-    [
-        ({"sequence_parallel_size": 2, "flash_attention": True}, True),
-        ({"sequence_parallel_size": 2, "flash_attention": False}, False),
-        ({"sequence_parallel_size": 1, "flash_attention": False}, True),
-    ],
-)
-def test_sequence_parallel_config_requirements(config, should_validate):
-    """Test basic sequence parallelism configuration requirements."""
+def test_config_validation_with_valid_inputs(cfg):
+    """Test that valid sequence parallelism configurations pass validation."""
+    # Import the actual model class with appropriate mocks
+    from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
-    # Simple validation function that mimics the actual validator
-    def validate_sp_config(config):
-        if config.get("sequence_parallel_size", 1) > 1 and not config.get(
-            "flash_attention", False
-        ):
-            return False
-        return True
+    # Valid configuration: sequence_parallel_size > 1 and flash_attention is True
+    cfg = cfg | {
+        "sequence_parallel_size": 2,
+        "flash_attention": True,
+    }
 
-    assert validate_sp_config(config) == should_validate
+    # Should validate without errors
+    config = AxolotlInputConfig(**cfg)
+    assert config.sequence_parallel_size == 2
+    assert config.flash_attention is True
+
+
+def test_config_validation_with_invalid_inputs(cfg):
+    """Test that invalid sequence parallelism configurations fail validation."""
+    from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
+
+    # Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
+    cfg = cfg | {
+        "sequence_parallel_size": 2,
+        "flash_attention": False,
+    }
+
+    # Should raise ValidationError
+    with pytest.raises(ValueError) as excinfo:
+        AxolotlInputConfig(**cfg)
+
+    # Verify error message
+    assert "flash_attention: true must be set" in str(excinfo.value)
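For context, the rule these two tests pin down is a cross-field check: sequence parallelism requires flash attention. A minimal sketch of such a validator in pydantic v2 style (the real check lives inside AxolotlInputConfig; this standalone model is illustrative):

    from pydantic import BaseModel, model_validator

    class SPConfig(BaseModel):
        sequence_parallel_size: int = 1
        flash_attention: bool = False

        @model_validator(mode="after")
        def check_sequence_parallel(self):
            # Reject SP > 1 without flash attention, mirroring the asserted message.
            if self.sequence_parallel_size > 1 and not self.flash_attention:
                raise ValueError(
                    "flash_attention: true must be set with sequence_parallel_size > 1"
                )
            return self

With this model, SPConfig(sequence_parallel_size=2) raises pydantic's ValidationError, which pytest.raises(ValueError) also catches because ValidationError derives from ValueError.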