Distributed/ND-Parallel (#2977)

This commit is contained in:
salman
2025-07-31 20:25:02 +01:00
committed by GitHub
parent 7b68dfafd7
commit 294c7fe7a6
49 changed files with 712 additions and 835 deletions

View File

@@ -67,7 +67,7 @@ class TestSequenceParallelism:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"context_parallel_size": 2,
"ring_attn_func": ring_attn_func,
"save_first_step": False,
}
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
],
)
def test_sequence_parallel_training(