SP GRPO support + batch SP fixes (#2643)

* ctx manager for SP

* updates

* update

* further simplifying

* simplifying

* simplifying

* reorg

* batch api HF adapter for ring-flash-attn; cleanup and improvements

* update

* adding all batch ring-flash-attn methods via single adapter

* fix

* fixes for batch API funcs, simplify

* fix

* grpo sp support

* progress

* stronger subclassing of TRL GRPO trainer; custom distributed sampler

* subclassing constructor

* progress

* finalizing SP + GRPO trainer

* minimize diffs to GRPO trainer

* remove (most of) the custom GRPO trainer logic

* debug

* debug

* update

* update

* update

* progress

* cleanup

* cleanup

* minor changes

* update

* update

* update

* small changes

* updates

* cleanup; torch.compile ring_flash_attn functions to prevent numerical instability; lint

* spacing

* cleanup; log in pydantic model config only on main process

* remove comment

* fix sp sampler, update to latest upstream code, doc

* add docs

* update quartodoc autodoc contents

* fix, simplifications

* fixes + simplifications

* review comments

* lint

* removing main process only logs in favor of #2608

* fixes, additional smoke test

* updatse

* more tests

* update

* fix grad accum bug (sort of)

* lint, tests

* todo
This commit is contained in:
Dan Saunders
2025-05-12 17:52:40 -04:00
committed by GitHub
parent 67c4ea9c7c
commit 80304c26a7
27 changed files with 1448 additions and 455 deletions

View File

@@ -25,6 +25,7 @@ class TestSequenceParallelism:
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
threshold=2.0,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault(
@@ -93,22 +94,22 @@ class TestSequenceParallelism:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
)
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func, threshold",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
@@ -118,6 +119,7 @@ class TestSequenceParallelism:
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
threshold,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
@@ -126,4 +128,5 @@ class TestSequenceParallelism:
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
threshold=threshold,
)

View File

@@ -10,14 +10,15 @@ import pytest
import torch
from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@pytest.fixture
@@ -62,12 +63,14 @@ def sequence_parallel_batch():
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
labels = input_ids.clone()
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"labels": labels,
}
return batch
@@ -179,12 +182,44 @@ class TestConfigValidation:
False,
"micro_batch_size must be set to 1",
),
# Valid: Basic GRPO config
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": TRLConfig(use_liger_loss=True),
},
True,
"GRPO + SP + Liger not currently supported",
),
# Invalid: GRPO config with Liger loss
(
{
"rl": "grpo",
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
None,
False,
"GRPO + SP + Liger not currently supported",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
"valid_grpo",
"grpo_with_liger_loss",
],
)
def test_sequence_parallel_config_validation(
@@ -256,7 +291,7 @@ class TestConfigValidation:
AxolotlInputConfig(**cfg)
# Verify error message
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
class TestApplySequenceParallelism:
@@ -290,10 +325,11 @@ class TestApplySequenceParallelism:
def test_world_size_one(self, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
@@ -305,10 +341,11 @@ class TestApplySequenceParallelism:
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
@@ -328,57 +365,59 @@ class TestApplySequenceParallelism:
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
def test_batch_zigzag(self, sequence_parallel_batch):
"""Test BATCH_ZIGZAG sharding pattern."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
seq_len = batch["input_ids"].size(1)
# TODO(djsaunde): add back once implemented.
# def test_batch_zigzag(self, sequence_parallel_batch):
# """Test BATCH_ZIGZAG sharding pattern."""
# batch = sequence_parallel_batch
# original_input_ids = batch["input_ids"].clone()
# seq_len = batch["input_ids"].size(1)
# Test rank 0
result_rank0 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# # Test rank 0
# result_rank0 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=0,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# Test rank 1
result_rank1 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# # Test rank 1
# result_rank1 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=1,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# Checks for both ranks
assert result_rank0["input_ids"].shape[1] == seq_len // 2
assert result_rank1["input_ids"].shape[1] == seq_len // 2
# # Checks for both ranks
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
# For a 2-rank system with 8 tokens, check specific zigzag pattern
# Rank 0 should get chunks [0, 1] and [6, 7]
# Rank 1 should get chunks [2, 3] and [4, 5]
if seq_len == 8:
# Create expected tensors for comparison
rank0_expected = torch.cat(
[original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
)
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
# # Rank 0 should get chunks [0, 1] and [6, 7]
# # Rank 1 should get chunks [2, 3] and [4, 5]
# if seq_len == 8:
# # Create expected tensors for comparison
# rank0_expected = torch.cat(
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
# )
rank1_expected = torch.cat(
[original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
)
# rank1_expected = torch.cat(
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
# )
assert torch.equal(result_rank0["input_ids"], rank0_expected)
assert torch.equal(result_rank1["input_ids"], rank1_expected)
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
"""Test that we can create a partially applied version of the function."""
@@ -390,11 +429,12 @@ class TestApplySequenceParallelism:
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result = rank0_ring_parallel(batch=batch)
result, _, _ = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
@@ -412,13 +452,15 @@ class TestApplySequenceParallelism:
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result = apply_sequence_parallelism(
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" not in result
assert "position_ids" in result
assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2