Distributed/ND-Parallel (#2977)

This commit is contained in:
salman
2025-07-31 20:25:02 +01:00
committed by GitHub
parent 7b68dfafd7
commit 294c7fe7a6
49 changed files with 712 additions and 835 deletions

View File

@@ -64,7 +64,7 @@ def fixture_base_cfg():
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
"context_parallel_size": 1,
"tensor_parallel_size": 1,
# Dtype
"fp16": False,

View File

@@ -67,7 +67,7 @@ class TestSequenceParallelism:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"context_parallel_size": 2,
"ring_attn_func": ring_attn_func,
"save_first_step": False,
}
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
],
)
def test_sequence_parallel_training(

View File

@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_parallel_degree": 2,
"context_parallel_size": 2,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {

View File

@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
"""Test class for FP8 mixed precision with FSDP2 functionality."""
@require_torch_2_7_0
@require_hopper
def test_fp8_fsdp2_smoke(self, temp_dir):
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
cfg = DictDefault(

View File

@@ -0,0 +1,69 @@
"""multigpu e2e test for tensor parallelism."""
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
class TestTensorParallel:
    """Multi-GPU end-to-end checks for tensor-parallel training."""

    @pytest.mark.skip(
        reason="TP doesn't work with models with tied weights (embeddings)"
    )
    @require_torch_2_7_0
    def test_fft_sft(self, temp_dir):
        """Launch a two-step full fine-tune with ``tensor_parallel_size=2``.

        The config is serialized to ``config.yaml`` inside ``temp_dir``,
        training is started through the ``axolotl train`` CLI on two
        processes, and the logged ``train/train_loss`` is handed to
        ``check_tensorboard`` with ``1.0`` as its ``lt_val`` argument.
        """
        # pylint: disable=duplicate-code
        alpaca_subset = {
            "path": "tatsu-lab/alpaca",
            "type": "alpaca",
            "split": "train[:10%]",
        }
        settings = {
            "base_model": "Qwen/Qwen2.5-0.5B",
            "sequence_len": 2048,
            "val_set_size": 0.01,
            "datasets": [alpaca_subset],
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch",
            "tensor_parallel_size": 2,
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "use_tensorboard": True,
            "bf16": True,
        }
        cfg = DictDefault(settings)

        # Persist the config so the CLI subprocess can read it.
        work_dir = Path(temp_dir)
        work_dir.mkdir(parents=True, exist_ok=True)
        config_path = work_dir / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        launch_cmd = [
            "axolotl",
            "train",
            str(config_path),
            "--num-processes",
            "2",
            "--main-process-port",
            str(get_torch_dist_unique_port()),
        ]
        execute_subprocess_async(launch_cmd)

        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
        )

View File

@@ -1,481 +0,0 @@
"""Tests for sequence parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@pytest.fixture
def partial_state():
    """Provide a real (non-mocked) ``PartialState`` instance to a test."""
    return PartialState()
@pytest.fixture(name="cfg")
def fixture_cfg():
    """Return the minimal training config exposed to tests as ``cfg``."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "datasets": [alpaca_dataset],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-3,
        "output_dir": "./model-out",
        "sequence_len": 512,
        "special_tokens": {"pad_token": "<|endoftext|>"},
        "save_first_step": False,
    }
    return DictDefault(settings)
@pytest.fixture
def sequence_parallel_batch():
    """Build a one-row, eight-token batch for sequence parallelism tests.

    ``input_ids`` and ``labels`` hold the values ``0..7`` (``labels`` is an
    independent clone), ``attention_mask`` is all ones, and ``position_ids``
    counts ``0..7`` along the sequence dimension.
    """
    n_rows, n_tokens = 1, 8
    token_ids = torch.arange(n_rows * n_tokens).reshape(n_rows, n_tokens)
    return {
        "input_ids": token_ids,
        "attention_mask": torch.ones(n_rows, n_tokens),
        "position_ids": torch.arange(n_tokens).expand(n_rows, n_tokens),
        "labels": token_ids.clone(),
    }
class TestRingAttention:
    """Tests for the ring attention functionality."""

    # NOTE: stacked ``@patch`` decorators inject their mocks bottom-up, so the
    # decorator closest to ``def`` supplies the first mock argument.

    @patch("torch.distributed.get_rank")
    @patch("torch.distributed.get_world_size")
    def test_get_ring_attn_group_no_registration(
        self, mock_world_size, mock_rank, partial_state
    ):
        """Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
        # Simulate rank 0 of a 4-process job.
        mock_world_size.return_value = 4
        mock_rank.return_value = 0

        # Without a prior register_ring_attn() call the getter must fail loudly.
        with pytest.raises(
            RuntimeError, match="register_ring_attn\\(\\) not yet called"
        ):
            get_ring_attn_group()

    @patch("torch.distributed.new_group")
    @patch("torch.distributed.get_rank")
    @patch("torch.distributed.get_world_size")
    def test_register_ring_attn(
        self, mock_world_size, mock_rank, mock_new_group, partial_state
    ):
        """Test that ring attention groups are created correctly."""
        # Simulate GPU #3 of an 8-GPU job.
        mock_world_size.return_value = 8
        mock_rank.return_value = 3
        mock_new_group.return_value = MagicMock()

        try:
            register_ring_attn(
                sequence_parallel_degree=4,
                heads_k_stride=1,
                ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
            )

            # Only the number of new_group calls is checked, not their
            # arguments. Presumably one group per block of 4 ranks (8 / 4 = 2)
            # -- confirm against register_ring_attn.
            assert mock_new_group.call_count == 2
        finally:
            # Always reset the module-level group so a failure here cannot
            # leak a registered (mock) group into later tests.
            set_ring_attn_group(None)
class TestConfigValidation:
    """Tests for validating sequence parallelism configurations."""

    @pytest.fixture(autouse=True)
    def setup_mocks(self, monkeypatch):
        """Install a ``MagicMock`` as the ``ring_flash_attn`` module for every test."""
        monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())

    @pytest.fixture
    def base_cfg(self):
        """Create a base configuration for testing."""
        return DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "learning_rate": 1e-3,
                "output_dir": "./model-out",
                "sequence_len": 512,
                "special_tokens": {"pad_token": "<|endoftext|>"},
            }
        )

    @pytest.mark.parametrize(
        "config_updates, expected_values, should_pass, error_msg",
        [
            # Valid configuration
            (
                {"sequence_parallel_degree": 2, "flash_attention": True},
                {"sequence_parallel_degree": 2, "flash_attention": True},
                True,
                None,
            ),
            # Default sequence_parallel_degree
            ({}, {"sequence_parallel_degree": 1}, True, None),
            # Invalid: sequence_parallel_degree > 1 without flash_attention
            (
                {"sequence_parallel_degree": 2, "flash_attention": False},
                None,
                False,
                "flash_attention: true must be set",
            ),
            # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
            (
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "sample_packing": True,
                    "micro_batch_size": 2,
                    "pad_to_sequence_len": True,
                },
                None,
                False,
                "micro_batch_size must be set to 1",
            ),
            # Valid: SP + Liger loss while ``rl`` is NOT set to "grpo".
            # No error is expected, so error_msg is None like the other
            # passing cases.
            (
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": TRLConfig(use_liger_loss=True),
                },
                True,
                None,
            ),
            # Invalid: GRPO config with Liger loss
            (
                {
                    "rl": "grpo",
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                None,
                False,
                "GRPO + SP + Liger not currently supported",
            ),
        ],
        ids=[
            "valid_config",
            "default_sp_degree",
            "without_flash_attention",
            "sample_packing_with_large_batch",
            "valid_grpo",
            "grpo_with_liger_loss",
        ],
    )
    def test_sequence_parallel_config_validation(
        self, base_cfg, config_updates, expected_values, should_pass, error_msg
    ):
        """Test various sequence parallelism configuration scenarios.

        Passing cases must validate and expose every ``expected_values`` entry
        as an attribute of the validated config. Failing cases must raise
        ``ValueError`` whose text contains ``error_msg``.
        """
        # Imported inside the test, i.e. after the autouse fixture has
        # replaced ``ring_flash_attn`` -- presumably required for the import
        # to succeed without the real package; confirm before hoisting.
        from axolotl.utils.schemas.config import AxolotlInputConfig

        # The fixture instance is fresh per test, so updating it in place is safe.
        cfg = base_cfg
        cfg.update(config_updates)

        if should_pass:
            config = AxolotlInputConfig(**cfg)
            for key, value in expected_values.items():
                assert getattr(config, key) == value
        else:
            with pytest.raises(ValueError) as excinfo:
                AxolotlInputConfig(**cfg)
            # Checked after the ``with`` block: code following the raising
            # call inside it would never run.
            assert error_msg in str(excinfo.value)

    @pytest.mark.parametrize(
        "ring_attn_func, sample_packing, expected_func",
        [
            (None, True, RingAttnFunc.VARLEN_LLAMA3),
            (None, False, RingAttnFunc.BATCH_RING),
        ],
        ids=["default_with_sample_packing", "default_without_sample_packing"],
    )
    def test_ring_attn_func_validation(
        self, base_cfg, ring_attn_func, sample_packing, expected_func
    ):
        """Test ring_attn_func validation and defaults."""
        from axolotl.utils.schemas.config import AxolotlInputConfig

        cfg = base_cfg | {
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "sample_packing": sample_packing,
        }
        # Only set the key when a value is supplied, so the ``None`` cases
        # exercise the validator's default selection.
        if ring_attn_func is not None:
            cfg["ring_attn_func"] = ring_attn_func

        config = AxolotlInputConfig(**cfg)

        assert config.ring_attn_func.value == expected_func

    def test_invalid_ring_attn_func(self, base_cfg):
        """Test that an invalid ring_attn_func is rejected."""
        from axolotl.utils.schemas.config import AxolotlInputConfig

        cfg = base_cfg | {
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "ring_attn_func": "INVALID_FUNC",
        }

        with pytest.raises(ValueError) as excinfo:
            AxolotlInputConfig(**cfg)

        assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
class TestApplySequenceParallelism:
    """Unit tests covering ``apply_sequence_parallelism`` batch sharding."""

    @pytest.fixture(autouse=True)
    def mock_distributed(self, monkeypatch):
        """Stub ``torch.distributed`` and the ring-attention hooks for every test."""
        # Report an initialized process group in which we are rank 0 of 2.
        monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
        monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
        monkeypatch.setattr(
            torch.distributed, "get_world_size", lambda *args, **kwargs: 2
        )

        # The group getter becomes the MagicMock class itself, so calling it
        # yields a fresh mock object.
        monkeypatch.setattr(
            "axolotl.monkeypatch.ring_attn.get_ring_attn_group",
            MagicMock,
        )

        # Replace the ring-attention parameter update with a no-op.
        monkeypatch.setattr(
            "axolotl.monkeypatch.ring_attn.update_ring_attn_params",
            lambda **kwargs: None,
        )

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """A local world size of 1 must hand the batch back as-is."""
        mock_get_ring_attn_group.return_value = 0

        sharded, _, _ = apply_sequence_parallelism(
            batch=sequence_parallel_batch,
            local_rank=0,
            local_world_size=1,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        assert sharded == sequence_parallel_batch

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """BATCH_RING: rank 0 of 2 receives the first half of the sequence."""
        mock_get_ring_attn_group.return_value = 0

        batch = sequence_parallel_batch
        half = batch["input_ids"].size(1) // 2

        sharded, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        # The sequence dimension must have been halved ...
        assert sharded["input_ids"].shape[1] == half
        assert sharded["attention_mask"].shape[1] == half

        # ... and its content must be the leading half.
        assert torch.equal(sharded["input_ids"], batch["input_ids"][:, :half])
        assert torch.equal(sharded["position_ids"], batch["position_ids"][:, :half])

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """BATCH_RING: rank 1 of 2 receives the second half of the sequence."""
        mock_get_ring_attn_group.return_value = 0

        batch = sequence_parallel_batch
        half = batch["input_ids"].size(1) // 2
        # Snapshot taken before the call so the comparison uses the
        # pre-call token ids.
        reference_ids = batch["input_ids"].clone()

        sharded, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=1,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        assert torch.equal(sharded["input_ids"], reference_ids[:, half:])

    # TODO(djsaunde): add back once implemented.
    # def test_batch_zigzag(self, sequence_parallel_batch):
    #     """Test BATCH_ZIGZAG sharding pattern."""
    #     batch = sequence_parallel_batch
    #     original_input_ids = batch["input_ids"].clone()
    #     seq_len = batch["input_ids"].size(1)
    #
    #     # Test rank 0
    #     result_rank0 = apply_sequence_parallelism(
    #         batch={k: v.clone() for k, v in batch.items()},
    #         local_rank=0,
    #         local_world_size=2,
    #         ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
    #     )
    #
    #     # Test rank 1
    #     result_rank1 = apply_sequence_parallelism(
    #         batch={k: v.clone() for k, v in batch.items()},
    #         local_rank=1,
    #         local_world_size=2,
    #         ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
    #     )
    #
    #     # Checks for both ranks
    #     assert result_rank0["input_ids"].shape[1] == seq_len // 2
    #     assert result_rank1["input_ids"].shape[1] == seq_len // 2
    #
    #     # For a 2-rank system with 8 tokens, check specific zigzag pattern
    #     # Rank 0 should get chunks [0, 1] and [6, 7]
    #     # Rank 1 should get chunks [2, 3] and [4, 5]
    #     if seq_len == 8:
    #         # Create expected tensors for comparison
    #         rank0_expected = torch.cat(
    #             [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
    #         )
    #         rank1_expected = torch.cat(
    #             [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
    #         )
    #         assert torch.equal(result_rank0["input_ids"], rank0_expected)
    #         assert torch.equal(result_rank1["input_ids"], rank1_expected)

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_partial_application(
        self, mock_get_ring_attn_group, sequence_parallel_batch
    ):
        """The function can be pre-bound with ``functools.partial`` and still shard."""
        mock_get_ring_attn_group.return_value = 0

        batch = sequence_parallel_batch
        reference_ids = batch["input_ids"].clone()
        half = reference_ids.shape[1] // 2

        # Bind everything except the batch itself.
        shard_for_rank0 = functools.partial(
            apply_sequence_parallelism,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        sharded, _, _ = shard_for_rank0(batch=batch)

        assert sharded["input_ids"].shape[1] == half
        assert torch.equal(sharded["input_ids"], reference_ids[:, :half])

    def test_missing_position_ids(self, sequence_parallel_batch):
        """A batch lacking ``position_ids`` is still sharded and gains that key."""
        batch = dict(sequence_parallel_batch)
        del batch["position_ids"]
        reference_ids = batch["input_ids"].clone()

        sharded, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        # position_ids must be present and as long as the sharded input_ids.
        assert "position_ids" in sharded
        assert sharded["input_ids"].shape[1] == sharded["position_ids"].shape[1]
        assert sharded["input_ids"].shape[1] == reference_ids.shape[1] // 2

View File

@@ -52,6 +52,8 @@ class TestLoadModelUtils:
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"tensor_parallel_size": 1,
"context_parallel_size": 1,
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init

View File

@@ -142,6 +142,10 @@ def is_hopper():
return compute_capability == (9, 0)
def require_hopper(test_case):
    """Decorator that skips ``test_case`` unless ``is_hopper()`` returns true."""
    skip_unless_hopper = unittest.skipUnless(
        is_hopper(), "test requires h100/hopper GPU"
    )
    return skip_unless_hopper(test_case)
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
) -> None:

View File

@@ -171,3 +171,44 @@ class TestModelsUtils:
message_property_mappings={"content": "different_content"},
)
assert "Conflicting message content fields" in str(exc_info.value)
@pytest.mark.parametrize(
    "world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
    [
        (16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
        (16, 1, 1, None, None, True, (0, 0, 16, 1)),
        (16, 2, 2, 2, None, True, (2, 2, 2, 2)),
        (16, 2, 2, None, 2, True, (2, 2, 2, 2)),
        (16, 1, 1, None, 2, True, (0, 0, 8, 2)),
        (2, 1, 1, None, None, True, (0, 0, 2, 1)),
    ],
)
def test_get_parallel_config_kwargs(
    self,
    world_size,
    tensor_parallel_size,
    context_parallel_size,
    dp_shard_size,
    dp_replicate_size,
    is_fsdp,
    expected,
):
    """Check the parallelism sizes derived by ``_get_parallel_config_kwargs``.

    ``expected`` is ordered ``(tp, cp, dp_shard, dp_replicate)``. Only entries
    greater than 1 are asserted; entries of 0 or 1 are not checked at all.
    """
    # pylint: disable=protected-access
    res = ModelLoader._get_parallel_config_kwargs(
        world_size,
        tensor_parallel_size,
        context_parallel_size,
        dp_shard_size,
        dp_replicate_size,
        is_fsdp,
    )

    result_keys = ("tp_size", "cp_size", "dp_shard_size", "dp_replicate_size")
    for key, wanted in zip(result_keys, expected):
        if wanted > 1:
            assert res[key] == wanted

View File

@@ -26,32 +26,6 @@ class TestFSDPValidation:
assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version is None
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
    """SHARDED_STATE_DICT together with ``save_safetensors`` must be rejected.

    The check is run for both spellings of the FSDP config key: with the
    ``fsdp_`` prefix and without it.
    """
    expected_error = "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"

    for state_dict_key in ("fsdp_state_dict_type", "state_dict_type"):
        cfg = min_base_cfg | DictDefault(
            fsdp_config={
                state_dict_key: "SHARDED_STATE_DICT",
            },
            save_safetensors=True,
        )
        with pytest.raises(ValueError, match=expected_error):
            validate_config(cfg)
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={