Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
"save_first_step": False,
|
||||
}
|
||||
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
|
||||
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, True, "batch_zigzag", 2.5),
|
||||
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
],
|
||||
ids=[
|
||||
"sample_packing, varlen_llama3 ring_attn_func",
|
||||
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
|
||||
Reference in New Issue
Block a user