This commit is contained in:
Dan Saunders
2025-04-24 16:19:47 +00:00
parent 072df89e0e
commit 3f1873cc62
3 changed files with 388 additions and 132 deletions

View File

@@ -6,4 +6,4 @@
from .optimizer import OptimizerMixin from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin

View File

@@ -2,6 +2,7 @@
Module for Axolotl trainer sequence parallelism mixin and training context manager Module for Axolotl trainer sequence parallelism mixin and training context manager
""" """
import functools
import logging import logging
import torch import torch
@@ -20,6 +21,66 @@ from axolotl.monkeypatch.attention.ring_attn import (
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def apply_sequence_parallelism(
    batch: dict[str, torch.Tensor],
    local_rank: int,
    local_world_size: int,
    ring_attn_func: RingAttnFunc,
) -> dict[str, torch.Tensor]:
    """
    Apply sequence parallelism slicing to a batch.

    Every tensor in `batch` that has more than one dimension and whose second
    dimension matches the length of `batch["input_ids"]` is replaced by this rank's
    shard. All other entries are left untouched. The dictionary is updated in place
    and is also returned.

    Args:
        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
        local_rank: Local rank in the sequence parallel group
        local_world_size: World size of the sequence parallel group
        ring_attn_func: The ring attention function to use

    Returns:
        The same batch dictionary, with sequence-length tensors sliced.

    Raises:
        KeyError: If `batch` has no `"input_ids"` entry.
    """
    # Update ring attention params if needed. This deliberately happens before
    # slicing, so the helper receives the full (unsharded) position ids.
    position_ids = batch.get("position_ids")
    if position_ids is not None:
        update_ring_attn_params(position_ids=position_ids)

    # Reference length, captured before any entry (including input_ids) is sliced
    total_seq_len = batch["input_ids"].size(1)

    # Iterate over a snapshot of the items: entries are reassigned inside the loop
    for key, value in list(batch.items()):
        if not isinstance(value, torch.Tensor):
            continue
        if value.dim() <= 1 or value.size(1) != total_seq_len:
            continue

        if ring_attn_func in (
            RingAttnFunc.VARLEN_LLAMA3,
            RingAttnFunc.BATCH_RING,
        ):
            # Split in sequential fashion and grab this rank's chunk
            batch[key] = value.chunk(local_world_size, dim=1)[local_rank].contiguous()
        elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
            chunks = value.chunk(2 * local_world_size, dim=1)

            # Take rank's chunk and the mirrored chunk from the other end of the
            # sequence for the zigzag pattern
            selected_chunks = [
                chunks[local_rank],
                chunks[2 * local_world_size - local_rank - 1],
            ]
            batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
        elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
            # Group the sequence into pieces of local_world_size positions, stack
            # them, and pick this rank's offset inside every piece.
            # NOTE(review): assumes the sequence length is a multiple of
            # local_world_size so that all pieces have equal size - confirm.
            striped = torch.stack(
                value.split(local_world_size, dim=1),
                dim=1,
            ).transpose(1, 2)
            batch[key] = striped[:, local_rank].contiguous()

    return batch
class SequenceParallelMixin: class SequenceParallelMixin:
""" """
Mixin class for sequence parallelism support in trainers. Mixin class for sequence parallelism support in trainers.
@@ -125,11 +186,20 @@ class SequenceParallelContextManager:
# Will store hook handles for removal # Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = [] self.hook_handles: list[RemovableHandle] = []
# Create a partially applied version of the apply_sequence_parallelism function
# with pre-configured params
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self): def __enter__(self):
# Forward pre-hook to apply sequence parallelism # Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs): def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs # Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(kwargs) kwargs = self.apply_sequence_parallelism(batch=kwargs)
return args, kwargs return args, kwargs
# Forward post-hook to gather outputs # Forward post-hook to gather outputs
@@ -155,61 +225,6 @@ class SequenceParallelContextManager:
handle.remove() handle.remove()
self.hook_handles = [] self.hook_handles = []
def apply_sequence_parallelism(
    self, batch: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """
    Slice a batch along the sequence dimension for this sequence parallel rank.

    Tensors with more than one dimension whose second dimension equals the length
    of `batch["input_ids"]` are replaced, in place, by this rank's shard.

    Args:
        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)

    Returns:
        Sliced batch dictionary (the same object that was passed in).
    """
    # Ring attention params are refreshed from the full position ids, if present
    if batch.get("position_ids") is not None:
        update_ring_attn_params(position_ids=batch["position_ids"])

    # Length of the unsliced sequence, used to recognise sequence-shaped tensors
    full_len = batch["input_ids"].size(1)

    rank = self.local_rank
    world_size = self.local_world_size
    strategy = self.ring_attn_func

    for name in batch:
        entry = batch[name]
        is_sequence_tensor = (
            isinstance(entry, torch.Tensor)
            and entry.dim() > 1
            and entry.size(1) == full_len
        )
        if not is_sequence_tensor:
            continue

        if strategy in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
            # Contiguous split: this rank keeps its own consecutive chunk
            pieces = entry.chunk(world_size, dim=1)
            batch[name] = pieces[rank].contiguous()
        elif strategy is RingAttnFunc.BATCH_ZIGZAG:
            # Zigzag: pair this rank's chunk with the mirrored chunk from the end
            pieces = entry.chunk(2 * world_size, dim=1)
            front = pieces[rank]
            back = pieces[2 * world_size - rank - 1]
            batch[name] = torch.cat([front, back], dim=1).contiguous()
        elif strategy is RingAttnFunc.BATCH_STRIPE:
            # Stripe: stack the pieces, then select this rank's offset in each one
            stacked = torch.stack(entry.split(world_size, dim=1), dim=1)
            batch[name] = stacked.transpose(1, 2)[:, rank].contiguous()

    return batch
def gather_outputs(self, output): def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor.""" """Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.) # Handle different output formats (dict, tensor, etc.)

View File

@@ -2,14 +2,19 @@
# pylint: disable=redefined-outer-name,unused-argument # pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from accelerate.state import PartialState from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import ( from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group, get_ring_attn_group,
register_ring_attn,
set_ring_attn_group, set_ring_attn_group,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -47,6 +52,27 @@ def fixture_cfg():
return cfg return cfg
@pytest.fixture
def sequence_parallel_batch():
    """Build a small single-sample batch for the sequence parallelism tests."""
    batch_size = 1
    seq_len = 8
    shape = (batch_size, seq_len)

    # Token ids are 0..seq_len-1 so that every position is distinguishable
    return {
        "input_ids": torch.arange(batch_size * seq_len).reshape(shape),
        "attention_mask": torch.ones(shape),
        "position_ids": torch.arange(seq_len).expand(shape),
    }
class TestRingAttention: class TestRingAttention:
"""Tests for the ring attention functionality.""" """Tests for the ring attention functionality."""
@@ -73,11 +99,6 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state self, mock_world_size, mock_rank, mock_new_group, partial_state
): ):
"""Test that ring attention groups are created correctly.""" """Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks # Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3 mock_rank.return_value = 3 # GPU #3
@@ -101,88 +122,308 @@ class TestRingAttention:
set_ring_attn_group(None) set_ring_attn_group(None)
# Mock a simplified DataCollator test class TestConfigValidation:
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group") """Tests for validating sequence parallelism configurations."""
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
# Create a sample batch @pytest.fixture(autouse=True)
batch = { def setup_mocks(self, monkeypatch):
"input_ids": torch.tensor( """Set up mocks for all tests in this class."""
[ # Mock the ring_flash_attn module
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112], monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
# Simplified slicing logic from SequenceParallelDataCollator # Mock the is_main_process function to return True
def slice_batch(batch, rank, world_size): monkeypatch.setattr(
result = {} "axolotl.utils.schemas.config.is_main_process", lambda: True
for key in batch: )
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
result[key] = batch[key][:, start_idx:end_idx]
return result
# Slice the batch @pytest.fixture
result = slice_batch( def base_cfg(self):
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value """Create a base configuration for testing."""
) return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
# Check slicing @pytest.mark.parametrize(
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU "config_updates, expected_values, should_pass, error_msg",
expected_input_ids = torch.tensor(
[ [
[104, 105, 106], # Second slice of first sequence # Valid configuration
[204, 205, 206], # Second slice of second sequence (
] {"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
],
) )
assert torch.all(result["input_ids"] == expected_input_ids) def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg
cfg.update(config_updates)
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
    "ring_attn_func, sample_packing, expected_func",
    [
        (None, True, RingAttnFunc.VARLEN_LLAMA3),
        (None, False, RingAttnFunc.BATCH_RING),
    ],
    ids=["default_with_sample_packing", "default_without_sample_packing"],
)
def test_ring_attn_func_validation(
    self, base_cfg, ring_attn_func, sample_packing, expected_func
):
    """Check which ring_attn_func the validated config ends up with."""
    from axolotl.utils.schemas.config import AxolotlInputConfig

    # Sequence parallel settings layered on top of the base config
    sp_settings = {
        "sequence_parallel_degree": 2,
        "flash_attention": True,
        "sample_packing": sample_packing,
    }
    cfg = base_cfg | sp_settings

    # Only set the key when the case provides an explicit value
    if ring_attn_func is not None:
        cfg["ring_attn_func"] = ring_attn_func

    # Validation is expected to succeed for every parametrized case
    validated = AxolotlInputConfig(**cfg)

    assert validated.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
    """An unknown ring_attn_func value must fail config validation."""
    from axolotl.utils.schemas.config import AxolotlInputConfig

    bad_cfg = base_cfg | {
        "sequence_parallel_degree": 2,
        "flash_attention": True,
        "ring_attn_func": "INVALID_FUNC",
    }

    # Building the config is expected to raise a ValueError
    with pytest.raises(ValueError) as excinfo:
        AxolotlInputConfig(**bad_cfg)

    # The message should name the rejected value
    message = str(excinfo.value)
    assert "ring_attn_func: INVALID_FUNC must be in" in message
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()}) class TestApplySequenceParallelism:
def test_config_validation_with_valid_inputs(cfg): """Tests for the apply_sequence_parallelism function."""
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True @pytest.fixture(autouse=True)
cfg = cfg | { def mock_distributed(self, monkeypatch):
"sequence_parallel_degree": 2, """Mock torch.distributed functions for testing."""
"flash_attention": True, # Mock is_initialized to return True
} monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
# Should validate without errors # Mock get_rank to return 0 by default
config = AxolotlInputConfig(**cfg) monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
def test_config_validation_with_invalid_inputs(cfg): # Mock the process group
"""Test that invalid sequence parallelism configurations fail validation.""" monkeypatch.setattr(
from axolotl.utils.schemas.config import AxolotlInputConfig "axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
MagicMock,
)
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False # Mock update_ring_attn_params
cfg = cfg | { monkeypatch.setattr(
"sequence_parallel_degree": 2, "axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
"flash_attention": False, lambda **kwargs: None,
} )
# Should raise ValidationError def test_world_size_one(self, sequence_parallel_batch):
with pytest.raises(ValueError) as excinfo: """Test that function returns original batch when world size is 1."""
AxolotlInputConfig(**cfg) result = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify error message # Should return the original batch unchanged
assert "flash_attention: true must be set" in str(excinfo.value) assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
    """Test BATCH_RING sharding for rank 0 in a 2-process group."""
    batch = sequence_parallel_batch
    seq_len = batch["input_ids"].size(1)
    half = seq_len // 2

    # The function under test stores the shards back into the dict it receives,
    # so keep copies of the unsliced tensors to compare against. Comparing with
    # `batch[...]` after the call would check the result against itself.
    original_input_ids = batch["input_ids"].clone()
    original_position_ids = batch["position_ids"].clone()

    result = apply_sequence_parallelism(
        batch=batch,
        local_rank=0,
        local_world_size=2,
        ring_attn_func=RingAttnFunc.BATCH_RING,
    )

    # Check that sequence dimension was sharded correctly
    assert result["input_ids"].shape[1] == half
    assert result["attention_mask"].shape[1] == half

    # Verify content: rank 0 should get the first half of the sequence
    assert torch.equal(result["input_ids"], original_input_ids[:, :half])
    assert torch.equal(result["position_ids"], original_position_ids[:, :half])
def test_batch_ring_rank1(self, sequence_parallel_batch):
    """BATCH_RING with two ranks: rank 1 receives the trailing half."""
    full_batch = sequence_parallel_batch
    seq_len = full_batch["input_ids"].size(1)

    # Copy taken up front so the expectation is built from the unsliced ids
    reference_ids = full_batch["input_ids"].clone()

    sharded = apply_sequence_parallelism(
        batch=full_batch,
        local_rank=1,
        local_world_size=2,
        ring_attn_func=RingAttnFunc.BATCH_RING,
    )

    expected_ids = reference_ids[:, seq_len // 2 :]
    assert torch.equal(sharded["input_ids"], expected_ids)
def test_batch_zigzag(self, sequence_parallel_batch):
    """BATCH_ZIGZAG with two ranks: each rank gets one chunk from each end."""
    source = sequence_parallel_batch
    reference_ids = source["input_ids"].clone()
    seq_len = source["input_ids"].size(1)

    def shard_for(rank):
        # Each call works on fresh copies of the tensors
        return apply_sequence_parallelism(
            batch={k: v.clone() for k, v in source.items()},
            local_rank=rank,
            local_world_size=2,
            ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
        )

    result_rank0 = shard_for(0)
    result_rank1 = shard_for(1)

    # Both ranks hold half of the sequence
    assert result_rank0["input_ids"].shape[1] == seq_len // 2
    assert result_rank1["input_ids"].shape[1] == seq_len // 2

    # The exact pattern below is only spelled out for an 8-token sequence
    if seq_len != 8:
        return

    # Two ranks -> four chunks of two tokens each:
    # rank 0 takes positions [0, 1] and [6, 7], rank 1 takes [2, 3] and [4, 5]
    rank0_expected = torch.cat(
        [reference_ids[:, :2], reference_ids[:, 6:8]], dim=1
    )
    rank1_expected = torch.cat(
        [reference_ids[:, 2:4], reference_ids[:, 4:6]], dim=1
    )

    assert torch.equal(result_rank0["input_ids"], rank0_expected)
    assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
    """The function can be pre-configured with functools.partial."""
    full_batch = sequence_parallel_batch
    reference_ids = full_batch["input_ids"].clone()
    half = reference_ids.shape[1] // 2

    # Bind everything except the batch
    shard_rank0 = functools.partial(
        apply_sequence_parallelism,
        local_rank=0,
        local_world_size=2,
        ring_attn_func=RingAttnFunc.BATCH_RING,
    )

    sharded = shard_rank0(batch=full_batch)

    # Rank 0 of 2 keeps the leading half of the sequence
    assert sharded["input_ids"].shape[1] == half
    assert torch.equal(sharded["input_ids"], reference_ids[:, :half])
def test_missing_position_ids(self, sequence_parallel_batch):
    """Slicing works for a batch that carries no position_ids entry."""
    # Copy every entry except position_ids
    trimmed = dict(sequence_parallel_batch)
    del trimmed["position_ids"]
    reference_ids = trimmed["input_ids"].clone()

    # Must not raise even though position_ids is absent
    sharded = apply_sequence_parallelism(
        batch=trimmed,
        local_rank=0,
        local_world_size=2,
        ring_attn_func=RingAttnFunc.BATCH_RING,
    )

    # No position_ids is invented, and input_ids is halved
    assert "position_ids" not in sharded
    assert sharded["input_ids"].shape[1] == reference_ids.shape[1] // 2