From 4b8bc52424682faccc37774cc4460f142043339b Mon Sep 17 00:00:00 2001
From: Gilles Turpin
Date: Thu, 5 Mar 2026 18:33:28 +0100
Subject: [PATCH] fix: correct total_num_steps and batch_size calculation with
 context parallelism (#3444)

* fix: correct total_num_steps and batch_size calculation with context parallelism

* feat: add test for CP batch size

---------

Co-authored-by: NanoCode012

---
 src/axolotl/utils/config/__init__.py      |  3 +-
 src/axolotl/utils/trainer.py              |  9 +---
 tests/test_context_parallel_batch_size.py | 56 +++++++++++++++++++++++
 3 files changed, 59 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_context_parallel_batch_size.py

diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 8b35ed406..1c0e93e03 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -119,7 +119,8 @@ def normalize_config(cfg):
     if cfg.world_size != 1:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
         if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
-            cfg.batch_size = cfg.batch_size * cfg.world_size
+            effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
+            cfg.batch_size = cfg.batch_size * effective_world_size
 
     if not cfg.use_ray:
         # delay resolving dtype until on worker node when launching with ray
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index d97a74f6f..320e59a90 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -457,7 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                     - 1
                 )
                 * cfg.num_epochs
-                * cfg.context_parallel_size
                 * cfg.tensor_parallel_size
             )
             LOG.debug(
@@ -498,12 +497,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             # FIXME: is there a bug here somewhere? the total num steps depends
             # on the agreed on value for sample_packing_eff_est
             total_num_steps = int(
-                math.floor(
-                    data_loader_len
-                    * cfg.num_epochs
-                    * cfg.context_parallel_size
-                    * cfg.tensor_parallel_size
-                )
+                math.floor(data_loader_len * cfg.num_epochs * cfg.tensor_parallel_size)
             )
         if cfg.dataloader_drop_last:
             # drop the last batch for each epoch
@@ -528,7 +522,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             math.ceil(
                 len(train_dataset)
                 * cfg.num_epochs
-                * cfg.context_parallel_size
                 * cfg.tensor_parallel_size
                 / cfg.batch_size
             )
diff --git a/tests/test_context_parallel_batch_size.py b/tests/test_context_parallel_batch_size.py
new file mode 100644
index 000000000..8f6ed7b28
--- /dev/null
+++ b/tests/test_context_parallel_batch_size.py
@@ -0,0 +1,56 @@
+"""Tests for batch_size calculation with context parallelism."""
+
+import sys
+import types
+
+import pytest
+
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="cp_base_cfg")
+def fixture_cp_base_cfg(min_base_cfg):
+    return (
+        DictDefault(
+            micro_batch_size=2,
+            gradient_accumulation_steps=4,
+            sequence_len=2048,
+            num_epochs=1,
+            flash_attention=True,
+        )
+        | min_base_cfg
+    )
+
+
+class TestContextParallelBatchSize:
+    """Verify batch_size scales by effective dp world_size when using context parallelism."""
+
+    @pytest.mark.parametrize(
+        "world_size, context_parallel_size, expected_batch_size",
+        [
+            (4, 1, 32),  # no CP: 2*4*4 = 32
+            (4, 2, 16),  # CP=2: 2*4*(4//2) = 16
+            (4, 4, 8),  # CP=4: 2*4*(4//4) = 8
+            (2, 2, 8),  # CP=ws: 2*4*(2//2) = 8 (no scaling)
+        ],
+    )
+    def test_batch_size_with_context_parallelism(
+        self,
+        cp_base_cfg,
+        monkeypatch,
+        world_size,
+        context_parallel_size,
+        expected_batch_size,
+    ):
+        monkeypatch.setenv("WORLD_SIZE", str(world_size))
+        # Mock ring_flash_attn since it's not installable on CPU,
+        # but required by schema validation when context_parallel_size > 1.
+        if "ring_flash_attn" not in sys.modules:
+            monkeypatch.setitem(
+                sys.modules, "ring_flash_attn", types.ModuleType("ring_flash_attn")
+            )
+        cp_base_cfg["context_parallel_size"] = context_parallel_size
+        cfg = validate_config(cp_base_cfg)
+        normalize_config(cfg)
+        assert cfg.batch_size == expected_batch_size
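
Note: for reference, a minimal standalone sketch of the scaling rule this patch
applies in normalize_config. The helper name effective_batch_size is hypothetical,
introduced here only for illustration; it is not part of the axolotl API. The idea
mirrors the diff above: ranks in the same context-parallel group process the same
batch (split along the sequence dimension), so only world_size // context_parallel_size
ranks contribute distinct batches.

    # Hypothetical helper for illustration; not part of the patch or of axolotl.
    def effective_batch_size(
        micro_batch_size, gradient_accumulation_steps, world_size, context_parallel_size=1
    ):
        # Per-rank batch before any cross-rank scaling.
        per_rank = micro_batch_size * gradient_accumulation_steps
        # CP ranks share one batch, so only the data-parallel replicas
        # multiply the global batch size (matches the normalize_config change).
        effective_world_size = world_size // (context_parallel_size or 1)
        return per_rank * effective_world_size

    # Reproduces the parametrized cases in the new test:
    assert effective_batch_size(2, 4, world_size=4, context_parallel_size=1) == 32
    assert effective_batch_size(2, 4, world_size=4, context_parallel_size=2) == 16
    assert effective_batch_size(2, 4, world_size=4, context_parallel_size=4) == 8
    assert effective_batch_size(2, 4, world_size=2, context_parallel_size=2) == 8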