finish basic impl; change naming from SP -> CP to match torch
@@ -64,7 +64,7 @@ def fixture_base_cfg():
             "dataloader_num_workers": 1,
             "dataloader_pin_memory": True,
             "dataloader_prefetch_factor": 2,
-            "sequence_parallel_degree": 1,
+            "context_parallel_degree": 1,
             # Dtype
             "fp16": False,
             "bf16": False,

@@ -1,4 +1,4 @@
-"""E2E tests for sequence parallelism"""
+"""E2E tests for context parallelism"""
 
 from pathlib import Path
 
@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
 from ...utils import check_tensorboard
 
 
-class TestSequenceParallelism:
-    """Test case for training with sequence parallelism enabled"""
+class TestContextParallelism:
+    """Test case for training with context parallelism enabled"""
 
-    def _run_sequence_parallel_test(
+    def _run_context_parallel_test(
         self,
         temp_dir,
         sample_packing=True,
@@ -24,7 +24,7 @@ class TestSequenceParallelism:
         ring_attn_func=None,
         threshold=2.0,
     ):
-        """Helper method to run sequence parallel tests with different configurations"""
+        """Helper method to run context parallel tests with different configurations"""
         cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -66,7 +66,7 @@ class TestSequenceParallelism:
                 "logging_steps": 1,
                 "weight_decay": 0.0,
                 "use_tensorboard": True,
-                "sequence_parallel_degree": 2,
+                "context_parallel_degree": 2,
                 "ring_attn_func": ring_attn_func,
             }
         )
@@ -109,7 +109,7 @@ class TestSequenceParallelism:
             "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
         ],
     )
-    def test_sequence_parallel_training(
+    def test_context_parallel_training(
         self,
         temp_dir,
         sample_packing,
@@ -118,8 +118,8 @@ class TestSequenceParallelism:
         ring_attn_func,
         threshold,
     ):
-        """Test sequence parallel training with different configurations"""
-        self._run_sequence_parallel_test(
+        """Test context parallel training with different configurations"""
+        self._run_context_parallel_test(
             temp_dir,
             sample_packing=sample_packing,
             micro_batch_size=micro_batch_size,

@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
-                "sequence_parallel_degree": 2,
+                "context_parallel_degree": 2,
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "special_tokens": {

@@ -1,4 +1,4 @@
-"""Tests for sequence parallelism functionality."""
+"""Tests for context parallelism functionality."""
 
 # pylint: disable=redefined-outer-name,unused-argument
 
@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
     register_ring_attn,
     set_ring_attn_group,
 )
-from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism
+from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.schemas.enums import RingAttnFunc
 from axolotl.utils.schemas.trl import TRLConfig
@@ -54,8 +54,8 @@ def fixture_cfg():
 
 
 @pytest.fixture
-def sequence_parallel_batch():
-    """Create a test batch for sequence parallelism tests."""
+def context_parallel_batch():
+    """Create a test batch for context parallelism tests."""
     batch_size = 1
     seq_len = 8
 
@@ -110,7 +110,7 @@ class TestRingAttention:
 
         # Call register_ring_attn with size 4
         register_ring_attn(
-            sequence_parallel_degree=4,
+            context_parallel_degree=4,
             heads_k_stride=1,
             ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
         )
@@ -126,7 +126,7 @@ class TestRingAttention:
 
 
 class TestConfigValidation:
-    """Tests for validating sequence parallelism configurations."""
+    """Tests for validating context parallelism configurations."""
 
     @pytest.fixture(autouse=True)
     def setup_mocks(self, monkeypatch):
@@ -155,24 +155,24 @@ class TestConfigValidation:
         [
             # Valid configuration
             (
-                {"sequence_parallel_degree": 2, "flash_attention": True},
-                {"sequence_parallel_degree": 2, "flash_attention": True},
+                {"context_parallel_degree": 2, "flash_attention": True},
+                {"context_parallel_degree": 2, "flash_attention": True},
                 True,
                 None,
             ),
-            # Default sequence_parallel_degree
-            ({}, {"sequence_parallel_degree": 1}, True, None),
-            # Invalid: sequence_parallel_degree > 1 without flash_attention
+            # Default context_parallel_degree
+            ({}, {"context_parallel_degree": 1}, True, None),
+            # Invalid: context_parallel_degree > 1 without flash_attention
             (
-                {"sequence_parallel_degree": 2, "flash_attention": False},
+                {"context_parallel_degree": 2, "flash_attention": False},
                 None,
                 False,
                 "flash_attention: true must be set",
             ),
-            # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
+            # Invalid: context_parallel_degree > 1 with sample_packing and micro_batch_size > 1
             (
                 {
-                    "sequence_parallel_degree": 2,
+                    "context_parallel_degree": 2,
                     "flash_attention": True,
                     "sample_packing": True,
                     "micro_batch_size": 2,
@@ -185,32 +185,32 @@ class TestConfigValidation:
             # Valid: Basic GRPO config
             (
                 {
-                    "sequence_parallel_degree": 2,
+                    "context_parallel_degree": 2,
                     "flash_attention": True,
                     "micro_batch_size": 2,
                     "trl": {"use_liger_loss": True},
                 },
                 {
-                    "sequence_parallel_degree": 2,
+                    "context_parallel_degree": 2,
                     "flash_attention": True,
                     "micro_batch_size": 2,
                     "trl": TRLConfig(use_liger_loss=True),
                 },
                 True,
-                "GRPO + SP + Liger not currently supported",
+                "GRPO + CP + Liger not currently supported",
             ),
             # Invalid: GRPO config with Liger loss
             (
                 {
                     "rl": "grpo",
-                    "sequence_parallel_degree": 2,
+                    "context_parallel_degree": 2,
                     "flash_attention": True,
                     "micro_batch_size": 2,
                     "trl": {"use_liger_loss": True},
                 },
                 None,
                 False,
-                "GRPO + SP + Liger not currently supported",
+                "GRPO + CP + Liger not currently supported",
             ),
         ],
         ids=[
@@ -222,10 +222,10 @@ class TestConfigValidation:
             "grpo_with_liger_loss",
         ],
     )
-    def test_sequence_parallel_config_validation(
+    def test_context_parallel_config_validation(
        self, base_cfg, config_updates, expected_values, should_pass, error_msg
     ):
-        """Test various sequence parallelism configuration scenarios."""
+        """Test various context parallelism configuration scenarios."""
         from axolotl.utils.schemas.config import AxolotlInputConfig
 
         # Apply updates to base config
@@ -261,7 +261,7 @@ class TestConfigValidation:
 
         # Apply updates to base config
         cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
+            "context_parallel_degree": 2,
             "flash_attention": True,
             "sample_packing": sample_packing,
         }
@@ -281,7 +281,7 @@ class TestConfigValidation:
 
         # Invalid configuration with invalid ring_attn_func
         cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
+            "context_parallel_degree": 2,
             "flash_attention": True,
             "ring_attn_func": "INVALID_FUNC",
         }
@@ -294,8 +294,8 @@ class TestConfigValidation:
         assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
 
 
-class TestApplySequenceParallelism:
-    """Tests for the apply_sequence_parallelism function."""
+class TestApplyContextParallelism:
+    """Tests for the apply_context_parallelism function."""
 
     @pytest.fixture(autouse=True)
     def mock_distributed(self, monkeypatch):
@@ -324,12 +324,12 @@ class TestApplySequenceParallelism:
         )
 
     @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
-    def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
+    def test_world_size_one(self, mock_get_ring_attn_group, context_parallel_batch):
         """Test that function returns original batch when world size is 1."""
         mock_get_ring_attn_group.return_value = 0
 
         result, _, _ = apply_context_parallelism(
-            batch=sequence_parallel_batch,
+            batch=context_parallel_batch,
             local_rank=0,
             local_world_size=1,
             gradient_accumulation_steps=1,
@@ -337,14 +337,14 @@ class TestApplySequenceParallelism:
         )
 
         # Should return the original batch unchanged
-        assert result == sequence_parallel_batch
+        assert result == context_parallel_batch
 
     @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
-    def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
+    def test_batch_ring_rank0(self, mock_get_ring_attn_group, context_parallel_batch):
         """Test BATCH_RING sharding for rank 0 in a 2-process group."""
         mock_get_ring_attn_group.return_value = 0
 
-        batch = sequence_parallel_batch
+        batch = context_parallel_batch
         seq_len = batch["input_ids"].size(1)
 
         result, _, _ = apply_context_parallelism(
@@ -366,11 +366,11 @@ class TestApplySequenceParallelism:
         )
 
     @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
-    def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
+    def test_batch_ring_rank1(self, mock_get_ring_attn_group, context_parallel_batch):
         """Test BATCH_RING sharding for rank 1 in a 2-process group."""
         mock_get_ring_attn_group.return_value = 0
 
-        batch = sequence_parallel_batch
+        batch = context_parallel_batch
         seq_len = batch["input_ids"].size(1)
         original_input_ids = batch["input_ids"].clone()
 
@@ -386,14 +386,14 @@ class TestApplySequenceParallelism:
         assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
 
     # TODO(djsaunde): add back once implemented.
-    # def test_batch_zigzag(self, sequence_parallel_batch):
+    # def test_batch_zigzag(self, context_parallel_batch):
     #     """Test BATCH_ZIGZAG sharding pattern."""
-    #     batch = sequence_parallel_batch
+    #     batch = context_parallel_batch
     #     original_input_ids = batch["input_ids"].clone()
     #     seq_len = batch["input_ids"].size(1)
 
     #     # Test rank 0
-    #     result_rank0 = apply_sequence_parallelism(
+    #     result_rank0 = apply_context_parallelism(
     #         batch={k: v.clone() for k, v in batch.items()},
     #         local_rank=0,
     #         local_world_size=2,
@@ -401,7 +401,7 @@ class TestApplySequenceParallelism:
     #     )
 
     #     # Test rank 1
-    #     result_rank1 = apply_sequence_parallelism(
+    #     result_rank1 = apply_context_parallelism(
     #         batch={k: v.clone() for k, v in batch.items()},
     #         local_rank=1,
     #         local_world_size=2,
@@ -430,12 +430,12 @@ class TestApplySequenceParallelism:
 
     @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
     def test_partial_application(
-        self, mock_get_ring_attn_group, sequence_parallel_batch
+        self, mock_get_ring_attn_group, context_parallel_batch
     ):
         """Test that we can create a partially applied version of the function."""
         mock_get_ring_attn_group.return_value = 0
 
-        batch = sequence_parallel_batch
+        batch = context_parallel_batch
         original_input_ids = batch["input_ids"].clone()
 
         # Create a partially applied function
@@ -457,12 +457,10 @@ class TestApplySequenceParallelism:
             original_input_ids[:, : original_input_ids.shape[1] // 2],
         )
 
-    def test_missing_position_ids(self, sequence_parallel_batch):
+    def test_missing_position_ids(self, context_parallel_batch):
         """Test handling of batch without position_ids."""
         # Create a batch without position_ids
-        batch = {
-            k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
-        }
+        batch = {k: v for k, v in context_parallel_batch.items() if k != "position_ids"}
         original_input_ids = batch["input_ids"].clone()
 
         # This should run without error even though position_ids is missing
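Note: the BATCH_RING tests above assert that each rank keeps one contiguous slice of the sequence dimension (rank 0 the first half, rank 1 the second half). The following is a minimal standalone sketch of that sharding for orientation only; shard_batch_for_rank is a hypothetical helper, not axolotl's apply_context_parallelism.

import torch

def shard_batch_for_rank(batch, rank, world_size):
    # Hypothetical illustration: split every tensor along the sequence
    # dimension (dim=1) into equal contiguous chunks, one per rank.
    sharded = {}
    for key, tensor in batch.items():
        chunk = tensor.size(1) // world_size
        sharded[key] = tensor[:, rank * chunk : (rank + 1) * chunk]
    return sharded

# Mirrors the rank-1 assertion in test_batch_ring_rank1: rank 1 receives
# the second half of input_ids.
batch = {"input_ids": torch.arange(8).unsqueeze(0)}
assert torch.equal(
    shard_batch_for_rank(batch, rank=1, world_size=2)["input_ids"],
    batch["input_ids"][:, 4:],
)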