use new upstream branches for nd-parallelism

commit 5c74bebfd0
parent 5f1a4306b0
Author: Wing Lian
Date:   2025-07-22 21:12:22 -04:00

22 changed files with 134 additions and 95 deletions
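
This change tracks the upstream rename of the context-parallelism knob: the config key sequence_parallel_degree becomes context_parallel_size, sitting alongside the existing tensor_parallel_size so that the two axes compose into the N-D parallel layout the commit title refers to. As a rough sketch of how the renamed keys compose (the helper below is hypothetical, not axolotl's API; only the two config keys come from this diff), the leftover data-parallel degree falls out of the world size:

# Hypothetical helper for illustration only -- not axolotl's API.
# Only the config keys context_parallel_size / tensor_parallel_size
# are taken from this diff.
def infer_data_parallel_size(world_size: int, cfg: dict) -> int:
    cp = cfg.get("context_parallel_size", 1)
    tp = cfg.get("tensor_parallel_size", 1)
    if world_size % (cp * tp) != 0:
        raise ValueError(f"world_size={world_size} is not divisible by cp*tp={cp * tp}")
    return world_size // (cp * tp)

# e.g. 8 GPUs with cp=2 and tp=2 leave dp=2 replicas
assert infer_data_parallel_size(8, {"context_parallel_size": 2, "tensor_parallel_size": 2}) == 2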

@@ -64,7 +64,7 @@ def fixture_base_cfg():
         "dataloader_num_workers": 1,
         "dataloader_pin_memory": True,
         "dataloader_prefetch_factor": 2,
-        "sequence_parallel_degree": 1,
+        "context_parallel_size": 1,
         "tensor_parallel_size": 1,
         # Dtype
         "fp16": False,

@@ -67,7 +67,7 @@ class TestSequenceParallelism:
             "logging_steps": 1,
             "weight_decay": 0.0,
             "use_tensorboard": True,
-            "sequence_parallel_degree": 2,
+            "context_parallel_size": 2,
             "ring_attn_func": ring_attn_func,
             "save_first_step": False,
         }

@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         "lora_alpha": 16,
         "lora_dropout": 0.05,
         "lora_target_linear": True,
-        "sequence_parallel_degree": 2,
+        "context_parallel_size": 2,
         "flash_attention": True,
         "sequence_len": 1024,
         "special_tokens": {

@@ -111,7 +111,7 @@ class TestRingAttention:
         # Call register_ring_attn with size 4
         register_ring_attn(
-            sequence_parallel_degree=4,
+            context_parallel_size=4,
             heads_k_stride=1,
             ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
         )
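
For reference, register_ring_attn above now takes context_parallel_size rather than sequence_parallel_degree. A minimal sketch of the kind of process-group setup such a registration implies, assuming ranks are partitioned into ring groups of context_parallel_size consecutive ranks (illustrative only; axolotl's actual grouping may differ):

# Illustrative sketch, assuming consecutive ranks form each ring group;
# not axolotl's internals. Requires torch.distributed to be initialized.
import torch.distributed as dist

def make_ring_groups(world_size: int, cp_size: int) -> list:
    assert world_size % cp_size == 0
    return [
        dist.new_group(ranks=list(range(start, start + cp_size)))
        for start in range(0, world_size, cp_size)
    ]
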
@@ -156,24 +156,24 @@ class TestConfigValidation:
     [
         # Valid configuration
         (
-            {"sequence_parallel_degree": 2, "flash_attention": True},
-            {"sequence_parallel_degree": 2, "flash_attention": True},
+            {"context_parallel_size": 2, "flash_attention": True},
+            {"context_parallel_size": 2, "flash_attention": True},
             True,
             None,
         ),
-        # Default sequence_parallel_degree
-        ({}, {"sequence_parallel_degree": 1}, True, None),
-        # Invalid: sequence_parallel_degree > 1 without flash_attention
+        # Default context_parallel_size
+        ({}, {"context_parallel_size": 1}, True, None),
+        # Invalid: context_parallel_size > 1 without flash_attention
         (
-            {"sequence_parallel_degree": 2, "flash_attention": False},
+            {"context_parallel_size": 2, "flash_attention": False},
             None,
             False,
             "flash_attention: true must be set",
         ),
-        # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
+        # Invalid: context_parallel_size > 1 with sample_packing and micro_batch_size > 1
         (
             {
-                "sequence_parallel_degree": 2,
+                "context_parallel_size": 2,
                 "flash_attention": True,
                 "sample_packing": True,
                 "micro_batch_size": 2,
@@ -186,13 +186,13 @@ class TestConfigValidation:
         # Valid: Basic GRPO config
         (
             {
-                "sequence_parallel_degree": 2,
+                "context_parallel_size": 2,
                 "flash_attention": True,
                 "micro_batch_size": 2,
                 "trl": {"use_liger_loss": True},
             },
             {
-                "sequence_parallel_degree": 2,
+                "context_parallel_size": 2,
                 "flash_attention": True,
                 "micro_batch_size": 2,
                 "trl": TRLConfig(use_liger_loss=True),
@@ -204,7 +204,7 @@ class TestConfigValidation:
         (
             {
                 "rl": "grpo",
-                "sequence_parallel_degree": 2,
+                "context_parallel_size": 2,
                 "flash_attention": True,
                 "micro_batch_size": 2,
                 "trl": {"use_liger_loss": True},
@@ -262,7 +262,7 @@ class TestConfigValidation:
         # Apply updates to base config
         cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
+            "context_parallel_size": 2,
             "flash_attention": True,
             "sample_packing": sample_packing,
         }
@@ -282,7 +282,7 @@ class TestConfigValidation:
         # Invalid configuration with invalid ring_attn_func
         cfg = base_cfg | {
-            "sequence_parallel_degree": 2,
+            "context_parallel_size": 2,
             "flash_attention": True,
             "ring_attn_func": "INVALID_FUNC",
         }
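
Taken together, the parametrized cases above pin down the validation rules for the renamed key: context_parallel_size defaults to 1; a value above 1 requires flash_attention: true; combining it with sample_packing and micro_batch_size > 1 is rejected; and ring_attn_func must name a known function. A sketch of a validator enforcing the same rules, inferred from the expected outcomes in the tests (not axolotl's actual implementation):

# Sketch inferred from the test expectations above; not axolotl's actual
# validator. The flash_attention message is the string the tests match on;
# the sample_packing message is illustrative.
def validate_context_parallel(cfg: dict) -> None:
    cp = cfg.setdefault("context_parallel_size", 1)  # default when unset
    if cp <= 1:
        return
    if not cfg.get("flash_attention"):
        raise ValueError("flash_attention: true must be set")
    if cfg.get("sample_packing") and cfg.get("micro_batch_size", 1) > 1:
        raise ValueError("micro_batch_size must be 1 when sample_packing is enabled with context parallelism")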