finish basic impl; rename SP -> CP to match torch

Dan Saunders
2025-06-13 09:51:06 -04:00
parent aced809989
commit 7a88de4fa8
25 changed files with 525 additions and 488 deletions
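
At a glance, the rename is a one-for-one swap of the config key (and of the sequence_parallel identifiers derived from it). A minimal before/after sketch in the style of the test configs below, with the surrounding keys omitted:

    # before: sequence parallelism naming
    cfg = {"sequence_parallel_degree": 2, "flash_attention": True}

    # after: context parallelism naming, matching torch
    cfg = {"context_parallel_degree": 2, "flash_attention": True}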

View File

@@ -64,7 +64,7 @@ def fixture_base_cfg():
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
"context_parallel_degree": 1,
# Dtype
"fp16": False,
"bf16": False,

View File

@@ -1,4 +1,4 @@
"""E2E tests for sequence parallelism"""
"""E2E tests for context parallelism"""
from pathlib import Path
@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
class TestContextParallelism:
"""Test case for training with context parallelism enabled"""
def _run_sequence_parallel_test(
def _run_context_parallel_test(
self,
temp_dir,
sample_packing=True,
@@ -24,7 +24,7 @@ class TestSequenceParallelism:
ring_attn_func=None,
threshold=2.0,
):
"""Helper method to run sequence parallel tests with different configurations"""
"""Helper method to run context parallel tests with different configurations"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -66,7 +66,7 @@ class TestSequenceParallelism:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
}
)
@@ -109,7 +109,7 @@ class TestSequenceParallelism:
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
],
)
def test_sequence_parallel_training(
def test_context_parallel_training(
self,
temp_dir,
sample_packing,
@@ -118,8 +118,8 @@ class TestSequenceParallelism:
ring_attn_func,
threshold,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
"""Test context parallel training with different configurations"""
self._run_context_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,
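
These E2E tests set context_parallel_degree: 2, which presumes a distributed run whose world size is divisible by the degree (two processes here, forming one CP group of two ranks). A sketch of that constraint; the standalone check below is illustrative, not axolotl's own validation:

    import torch.distributed as dist

    # Each CP group shards every sequence across `context_parallel_degree` ranks,
    # so the degree must divide the total number of ranks evenly.
    context_parallel_degree = 2
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    assert world_size % context_parallel_degree == 0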

View File

@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {

View File

@@ -1,4 +1,4 @@
"""Tests for sequence parallelism functionality."""
"""Tests for context parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism
from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@@ -54,8 +54,8 @@ def fixture_cfg():
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
def context_parallel_batch():
"""Create a test batch for context parallelism tests."""
batch_size = 1
seq_len = 8
@@ -110,7 +110,7 @@ class TestRingAttention:
# Call register_ring_attn with size 4
register_ring_attn(
sequence_parallel_degree=4,
context_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
@@ -126,7 +126,7 @@ class TestRingAttention:
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
"""Tests for validating context parallelism configurations."""
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
@@ -155,24 +155,24 @@ class TestConfigValidation:
[
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
{"context_parallel_degree": 2, "flash_attention": True},
{"context_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
# Default context_parallel_degree
({}, {"context_parallel_degree": 1}, True, None),
# Invalid: context_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
{"context_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
# Invalid: context_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
@@ -185,32 +185,32 @@ class TestConfigValidation:
# Valid: Basic GRPO config
(
{
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
{
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": TRLConfig(use_liger_loss=True),
},
True,
"GRPO + SP + Liger not currently supported",
"GRPO + CP + Liger not currently supported",
),
# Invalid: GRPO config with Liger loss
(
{
"rl": "grpo",
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
None,
False,
"GRPO + SP + Liger not currently supported",
"GRPO + CP + Liger not currently supported",
),
],
ids=[
@@ -222,10 +222,10 @@ class TestConfigValidation:
"grpo_with_liger_loss",
],
)
def test_sequence_parallel_config_validation(
def test_context_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
"""Test various context parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
@@ -261,7 +261,7 @@ class TestConfigValidation:
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
@@ -281,7 +281,7 @@ class TestConfigValidation:
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"context_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
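
For reference, the accepted ring_attn_func values checked by the assertion in the next hunk map onto the RingAttnFunc enum imported earlier in this file; a quick sketch, assuming the member names mirror the string values:

    from axolotl.utils.schemas.enums import RingAttnFunc

    # 'varlen_llama3' and 'batch_ring' are the two accepted values per the validation error.
    valid_funcs = (RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING)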
@@ -294,8 +294,8 @@ class TestConfigValidation:
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
class TestApplyContextParallelism:
"""Tests for the apply_context_parallelism function."""
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
@@ -324,12 +324,12 @@ class TestApplySequenceParallelism:
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
def test_world_size_one(self, mock_get_ring_attn_group, context_parallel_batch):
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_context_parallelism(
batch=sequence_parallel_batch,
batch=context_parallel_batch,
local_rank=0,
local_world_size=1,
gradient_accumulation_steps=1,
@@ -337,14 +337,14 @@ class TestApplySequenceParallelism:
)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
assert result == context_parallel_batch
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
def test_batch_ring_rank0(self, mock_get_ring_attn_group, context_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
batch = context_parallel_batch
seq_len = batch["input_ids"].size(1)
result, _, _ = apply_context_parallelism(
@@ -366,11 +366,11 @@ class TestApplySequenceParallelism:
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
def test_batch_ring_rank1(self, mock_get_ring_attn_group, context_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
batch = context_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
@@ -386,14 +386,14 @@ class TestApplySequenceParallelism:
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
# TODO(djsaunde): add back once implemented.
# def test_batch_zigzag(self, sequence_parallel_batch):
# def test_batch_zigzag(self, context_parallel_batch):
# """Test BATCH_ZIGZAG sharding pattern."""
# batch = sequence_parallel_batch
# batch = context_parallel_batch
# original_input_ids = batch["input_ids"].clone()
# seq_len = batch["input_ids"].size(1)
# # Test rank 0
# result_rank0 = apply_sequence_parallelism(
# result_rank0 = apply_context_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=0,
# local_world_size=2,
@@ -401,7 +401,7 @@ class TestApplySequenceParallelism:
# )
# # Test rank 1
# result_rank1 = apply_sequence_parallelism(
# result_rank1 = apply_context_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=1,
# local_world_size=2,
@@ -430,12 +430,12 @@ class TestApplySequenceParallelism:
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_partial_application(
self, mock_get_ring_attn_group, sequence_parallel_batch
self, mock_get_ring_attn_group, context_parallel_batch
):
"""Test that we can create a partially applied version of the function."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
batch = context_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
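
A rough sketch of the kind of partial application this test exercises, built from the keyword arguments visible in the hunks above; functools.partial is an assumption here, and the sketch presumes register_ring_attn has already set up the ring-attention process group:

    from functools import partial

    from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism

    # Bind the static sharding arguments once, then call per batch.
    shard_fn = partial(
        apply_context_parallelism,
        local_rank=0,
        local_world_size=2,
        gradient_accumulation_steps=1,
    )
    sharded_batch, _, _ = shard_fn(batch=batch)  # `batch` as in the fixture above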
@@ -457,12 +457,10 @@ class TestApplySequenceParallelism:
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
def test_missing_position_ids(self, context_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
batch = {k: v for k, v in context_parallel_batch.items() if k != "position_ids"}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
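
The rank-0 and rank-1 assertions above pin down what BATCH_RING sharding does to the sequence dimension: each rank keeps a contiguous 1/world_size slice. A self-contained sketch of that slicing (the helper name is illustrative, not axolotl's):

    import torch

    def shard_batch_ring(input_ids: torch.Tensor, local_rank: int, local_world_size: int) -> torch.Tensor:
        # Contiguous sharding along the sequence dimension: rank r keeps tokens
        # [r * shard, (r + 1) * shard), where shard = seq_len // world_size.
        seq_len = input_ids.size(1)
        shard = seq_len // local_world_size
        return input_ids[:, local_rank * shard : (local_rank + 1) * shard]

    batch = torch.arange(8).reshape(1, 8)
    assert torch.equal(shard_batch_ring(batch, 0, 2), batch[:, :4])  # rank 0: first half
    assert torch.equal(shard_batch_ring(batch, 1, 2), batch[:, 4:])  # rank 1: second half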