Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -64,7 +64,7 @@ def fixture_base_cfg():
|
||||
"dataloader_num_workers": 1,
|
||||
"dataloader_pin_memory": True,
|
||||
"dataloader_prefetch_factor": 2,
|
||||
"sequence_parallel_degree": 1,
|
||||
"context_parallel_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
# Dtype
|
||||
"fp16": False,
|
||||
|
||||
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
"save_first_step": False,
|
||||
}
|
||||
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
|
||||
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, True, "batch_zigzag", 2.5),
|
||||
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
],
|
||||
ids=[
|
||||
"sample_packing, varlen_llama3 ring_attn_func",
|
||||
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
|
||||
@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
|
||||
@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
|
||||
"""Test class for FP8 mixed precision with FSDP2 functionality."""
|
||||
|
||||
@require_torch_2_7_0
|
||||
@require_hopper
|
||||
def test_fp8_fsdp2_smoke(self, temp_dir):
|
||||
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
|
||||
cfg = DictDefault(
|
||||
|
||||
69
tests/e2e/multigpu/test_tp.py
Normal file
69
tests/e2e/multigpu/test_tp.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""multigpu e2e test for tensor parallelism."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
|
||||
|
||||
|
||||
class TestTensorParallel:
|
||||
"""Test class for Tensor Parallel functionality."""
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="TP doesn't work with models with tied weights (embeddings)"
|
||||
)
|
||||
@require_torch_2_7_0
|
||||
def test_fft_sft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"tensor_parallel_size": 2,
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
@@ -1,481 +0,0 @@
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
import functools
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def partial_state():
|
||||
"""Create a real PartialState instance for testing."""
|
||||
state = PartialState()
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
|
||||
def fixture_cfg():
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-3,
|
||||
"output_dir": "./model-out",
|
||||
"sequence_len": 512,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sequence_parallel_batch():
|
||||
"""Create a test batch for sequence parallelism tests."""
|
||||
batch_size = 1
|
||||
seq_len = 8
|
||||
|
||||
# Create test tensors
|
||||
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
|
||||
attention_mask = torch.ones(batch_size, seq_len)
|
||||
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
|
||||
labels = input_ids.clone()
|
||||
|
||||
# Create test batch
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class TestRingAttention:
|
||||
"""Tests for the ring attention functionality."""
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Verify that RuntimeError is raised when no group is registered
|
||||
with pytest.raises(
|
||||
RuntimeError, match="register_ring_attn\\(\\) not yet called"
|
||||
):
|
||||
get_ring_attn_group()
|
||||
|
||||
@patch("torch.distributed.new_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_register_ring_attn(
|
||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||
):
|
||||
"""Test that ring attention groups are created correctly."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 8 # 8 GPUs total
|
||||
mock_rank.return_value = 3 # GPU #3
|
||||
mock_group = MagicMock()
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=4,
|
||||
heads_k_stride=1,
|
||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||
)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
# Verify that new_group was called
|
||||
mock_new_group.assert_called()
|
||||
|
||||
# Clean up
|
||||
set_ring_attn_group(None)
|
||||
|
||||
|
||||
class TestConfigValidation:
|
||||
"""Tests for validating sequence parallelism configurations."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_mocks(self, monkeypatch):
|
||||
"""Set up mocks for all tests in this class."""
|
||||
# Mock the ring_flash_attn module
|
||||
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
|
||||
|
||||
@pytest.fixture
|
||||
def base_cfg(self):
|
||||
"""Create a base configuration for testing."""
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-3,
|
||||
"output_dir": "./model-out",
|
||||
"sequence_len": 512,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_updates, expected_values, should_pass, error_msg",
|
||||
[
|
||||
# Valid configuration
|
||||
(
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
True,
|
||||
None,
|
||||
),
|
||||
# Default sequence_parallel_degree
|
||||
({}, {"sequence_parallel_degree": 1}, True, None),
|
||||
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
||||
(
|
||||
{"sequence_parallel_degree": 2, "flash_attention": False},
|
||||
None,
|
||||
False,
|
||||
"flash_attention: true must be set",
|
||||
),
|
||||
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||
(
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"micro_batch_size": 2,
|
||||
"pad_to_sequence_len": True,
|
||||
},
|
||||
None,
|
||||
False,
|
||||
"micro_batch_size must be set to 1",
|
||||
),
|
||||
# Valid: Basic GRPO config
|
||||
(
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": TRLConfig(use_liger_loss=True),
|
||||
},
|
||||
True,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
# Invalid: GRPO config with Liger loss
|
||||
(
|
||||
{
|
||||
"rl": "grpo",
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
None,
|
||||
False,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"valid_config",
|
||||
"default_sp_degree",
|
||||
"without_flash_attention",
|
||||
"sample_packing_with_large_batch",
|
||||
"valid_grpo",
|
||||
"grpo_with_liger_loss",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_config_validation(
|
||||
self, base_cfg, config_updates, expected_values, should_pass, error_msg
|
||||
):
|
||||
"""Test various sequence parallelism configuration scenarios."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Apply updates to base config
|
||||
cfg = base_cfg
|
||||
cfg.update(config_updates)
|
||||
|
||||
if should_pass:
|
||||
# Should validate without errors
|
||||
config = AxolotlInputConfig(**cfg)
|
||||
|
||||
# Check expected values
|
||||
for key, value in expected_values.items():
|
||||
assert getattr(config, key) == value
|
||||
else:
|
||||
# Should raise exception
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AxolotlInputConfig(**cfg)
|
||||
assert error_msg in str(excinfo.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ring_attn_func, sample_packing, expected_func",
|
||||
[
|
||||
(None, True, RingAttnFunc.VARLEN_LLAMA3),
|
||||
(None, False, RingAttnFunc.BATCH_RING),
|
||||
],
|
||||
ids=["default_with_sample_packing", "default_without_sample_packing"],
|
||||
)
|
||||
def test_ring_attn_func_validation(
|
||||
self, base_cfg, ring_attn_func, sample_packing, expected_func
|
||||
):
|
||||
"""Test ring_attn_func validation and defaults."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Apply updates to base config
|
||||
cfg = base_cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": sample_packing,
|
||||
}
|
||||
|
||||
if ring_attn_func is not None:
|
||||
cfg["ring_attn_func"] = ring_attn_func
|
||||
|
||||
# Should validate without errors
|
||||
config = AxolotlInputConfig(**cfg)
|
||||
|
||||
# Check ring_attn_func value
|
||||
assert config.ring_attn_func.value == expected_func
|
||||
|
||||
def test_invalid_ring_attn_func(self, base_cfg):
|
||||
"""Test that an invalid ring_attn_func is rejected."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Invalid configuration with invalid ring_attn_func
|
||||
cfg = base_cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"ring_attn_func": "INVALID_FUNC",
|
||||
}
|
||||
|
||||
# Should raise ValidationError
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AxolotlInputConfig(**cfg)
|
||||
|
||||
# Verify error message
|
||||
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
|
||||
|
||||
|
||||
class TestApplySequenceParallelism:
|
||||
"""Tests for the apply_sequence_parallelism function."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_distributed(self, monkeypatch):
|
||||
"""Mock torch.distributed functions for testing."""
|
||||
# Mock is_initialized to return True
|
||||
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
|
||||
|
||||
# Mock get_rank to return 0 by default
|
||||
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
|
||||
|
||||
# Mock get_world_size to return 2 by default
|
||||
monkeypatch.setattr(
|
||||
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
|
||||
)
|
||||
|
||||
# Mock the process group
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
# Mock update_ring_attn_params
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
local_rank=0,
|
||||
local_world_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Should return the original batch unchanged
|
||||
assert result == sequence_parallel_batch
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Check that sequence dimension was sharded correctly
|
||||
assert result["input_ids"].shape[1] == seq_len // 2
|
||||
assert result["attention_mask"].shape[1] == seq_len // 2
|
||||
|
||||
# Verify content: rank 0 should get the first half of the sequence
|
||||
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
|
||||
assert torch.equal(
|
||||
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=1,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Verify content: rank 1 should get the second half of the sequence
|
||||
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
|
||||
|
||||
# TODO(djsaunde): add back once implemented.
|
||||
# def test_batch_zigzag(self, sequence_parallel_batch):
|
||||
# """Test BATCH_ZIGZAG sharding pattern."""
|
||||
# batch = sequence_parallel_batch
|
||||
# original_input_ids = batch["input_ids"].clone()
|
||||
# seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# # Test rank 0
|
||||
# result_rank0 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=0,
|
||||
# local_world_size=2,
|
||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
# )
|
||||
|
||||
# # Test rank 1
|
||||
# result_rank1 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=1,
|
||||
# local_world_size=2,
|
||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
# )
|
||||
|
||||
# # Checks for both ranks
|
||||
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
|
||||
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
|
||||
|
||||
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
|
||||
# # Rank 0 should get chunks [0, 1] and [6, 7]
|
||||
# # Rank 1 should get chunks [2, 3] and [4, 5]
|
||||
# if seq_len == 8:
|
||||
# # Create expected tensors for comparison
|
||||
# rank0_expected = torch.cat(
|
||||
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
|
||||
# )
|
||||
|
||||
# rank1_expected = torch.cat(
|
||||
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
|
||||
# )
|
||||
|
||||
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
|
||||
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_partial_application(
|
||||
self, mock_get_ring_attn_group, sequence_parallel_batch
|
||||
):
|
||||
"""Test that we can create a partially applied version of the function."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# Create a partially applied function
|
||||
rank0_ring_parallel = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Use the partially applied function
|
||||
result, _, _ = rank0_ring_parallel(batch=batch)
|
||||
|
||||
# Verify it works as expected
|
||||
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
|
||||
assert torch.equal(
|
||||
result["input_ids"],
|
||||
original_input_ids[:, : original_input_ids.shape[1] // 2],
|
||||
)
|
||||
|
||||
def test_missing_position_ids(self, sequence_parallel_batch):
|
||||
"""Test handling of batch without position_ids."""
|
||||
# Create a batch without position_ids
|
||||
batch = {
|
||||
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
|
||||
}
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# This should run without error even though position_ids is missing
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Verification should pass
|
||||
assert "position_ids" in result
|
||||
assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
|
||||
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
|
||||
@@ -52,6 +52,8 @@ class TestLoadModelUtils:
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"tensor_parallel_size": 1,
|
||||
"context_parallel_size": 1,
|
||||
}
|
||||
)
|
||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||
|
||||
@@ -142,6 +142,10 @@ def is_hopper():
|
||||
return compute_capability == (9, 0)
|
||||
|
||||
|
||||
def require_hopper(test_case):
|
||||
return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
|
||||
|
||||
|
||||
def check_tensorboard(
|
||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||
) -> None:
|
||||
|
||||
@@ -171,3 +171,44 @@ class TestModelsUtils:
|
||||
message_property_mappings={"content": "different_content"},
|
||||
)
|
||||
assert "Conflicting message content fields" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
|
||||
[
|
||||
(16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
|
||||
(16, 1, 1, None, None, True, (0, 0, 16, 1)),
|
||||
(16, 2, 2, 2, None, True, (2, 2, 2, 2)),
|
||||
(16, 2, 2, None, 2, True, (2, 2, 2, 2)),
|
||||
(16, 1, 1, None, 2, True, (0, 0, 8, 2)),
|
||||
(2, 1, 1, None, None, True, (0, 0, 2, 1)),
|
||||
],
|
||||
)
|
||||
def test_get_parallel_config_kwargs(
|
||||
self,
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
expected,
|
||||
):
|
||||
res = (
|
||||
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
)
|
||||
)
|
||||
|
||||
if expected[0] > 1:
|
||||
assert res["tp_size"] == expected[0]
|
||||
if expected[1] > 1:
|
||||
assert res["cp_size"] == expected[1]
|
||||
if expected[2] > 1:
|
||||
assert res["dp_shard_size"] == expected[2]
|
||||
if expected[3] > 1:
|
||||
assert res["dp_replicate_size"] == expected[3]
|
||||
|
||||
@@ -26,32 +26,6 @@ class TestFSDPValidation:
|
||||
assert cfg.fsdp_version == 2
|
||||
assert cfg.fsdp_config.fsdp_version is None
|
||||
|
||||
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
},
|
||||
save_safetensors=True,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
# test w/o prefix too
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"state_dict_type": "SHARDED_STATE_DICT",
|
||||
},
|
||||
save_safetensors=True,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
|
||||
Reference in New Issue
Block a user