SP GRPO support + batch SP fixes (#2643)

* ctx manager for SP

* updates

* update

* further simplifying

* simplifying

* simplifying

* reorg

* batch api HF adapter for ring-flash-attn; cleanup and improvements

* update

* adding all batch ring-flash-attn methods via single adapter

* fix

* fixes for batch API funcs, simplify

* fix

* grpo sp support

* progress

* stronger subclassing of TRL GRPO trainer; custom distributed sampler

* subclassing constructor

* progress

* finalizing SP + GRPO trainer

* minimize diffs to GRPO trainer

* remove (most of) the custom GRPO trainer logic

* debug

* debug

* update

* update

* update

* progress

* cleanup

* cleanup

* minor changes

* update

* update

* update

* small changes

* updates

* cleanup; torch.compile ring_flash_attn functions to prevent numerical instability; lint

* spacing

* cleanup; log in pydantic model config only on main process

* remove comment

* fix sp sampler, update to latest upstream code, doc

* add docs

* update quartodoc autodoc contents

* fix, simplifications

* fixes + simplifications

* review comments

* lint

* removing main process only logs in favor of #2608

* fixes, additional smoke test

* updatse

* more tests

* update

* fix grad accum bug (sort of)

* lint, tests

* todo
This commit is contained in:
Dan Saunders
2025-05-12 17:52:40 -04:00
committed by GitHub
parent 67c4ea9c7c
commit 80304c26a7
27 changed files with 1448 additions and 455 deletions

View File

@@ -25,6 +25,7 @@ class TestSequenceParallelism:
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
threshold=2.0,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault(
@@ -93,22 +94,22 @@ class TestSequenceParallelism:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
)
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func, threshold",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
@@ -118,6 +119,7 @@ class TestSequenceParallelism:
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
threshold,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
@@ -126,4 +128,5 @@ class TestSequenceParallelism:
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
threshold=threshold,
)