fix: correct total_num_steps and batch_size calculation with context parallelism (#3444)
* fix: correct total_num_steps and batch_size calculation with context parallelism
* feat: add test for CP batch size

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -119,7 +119,8 @@ def normalize_config(cfg):
|
||||
if cfg.world_size != 1:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
|
||||
cfg.batch_size = cfg.batch_size * effective_world_size
|
||||
|
||||
if not cfg.use_ray:
|
||||
# delay resolving dtype until on worker node when launching with ray
|
||||
|
||||
@@ -457,7 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_size
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
LOG.debug(
|
||||
@@ -498,12 +497,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_size
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
math.floor(data_loader_len * cfg.num_epochs * cfg.tensor_parallel_size)
|
||||
)
|
||||
if cfg.dataloader_drop_last:
|
||||
# drop the last batch for each epoch
|
||||
@@ -528,7 +522,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.context_parallel_size
|
||||
* cfg.tensor_parallel_size
|
||||
/ cfg.batch_size
|
||||
)
|
||||
|
||||
56
tests/test_context_parallel_batch_size.py
Normal file
56
tests/test_context_parallel_batch_size.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tests for batch_size calculation with context parallelism."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="cp_base_cfg")
def fixture_cp_base_cfg(min_base_cfg):
    """Return the base config shared by the context-parallel batch-size tests.

    The training settings below (micro batch size 2, gradient accumulation 4,
    flash attention enabled, ...) are combined with ``min_base_cfg`` through
    the ``|`` operator.
    """
    training_overrides = DictDefault(
        micro_batch_size=2,
        gradient_accumulation_steps=4,
        sequence_len=2048,
        num_epochs=1,
        flash_attention=True,
    )
    return training_overrides | min_base_cfg
|
||||
|
||||
|
||||
class TestContextParallelBatchSize:
    """Check that batch_size is scaled by world_size // context_parallel_size."""

    @pytest.mark.parametrize(
        "world_size, context_parallel_size, expected_batch_size",
        [
            (4, 1, 32),  # no CP: 2*4*4 = 32
            (4, 2, 16),  # CP=2: 2*4*(4//2) = 16
            (4, 4, 8),  # CP=4: 2*4*(4//4) = 8
            (2, 2, 8),  # CP=ws: 2*4*(2//2) = 8 (no scaling)
        ],
    )
    def test_batch_size_with_context_parallelism(
        self,
        cp_base_cfg,
        monkeypatch,
        world_size,
        context_parallel_size,
        expected_batch_size,
    ):
        """Validate and normalize the config, then compare the resulting batch_size."""
        # The world size is supplied through the WORLD_SIZE environment variable.
        monkeypatch.setenv("WORLD_SIZE", str(world_size))

        # ring_flash_attn cannot be installed on CPU, yet schema validation
        # requires it when context_parallel_size > 1, so register an empty
        # stand-in module unless a real one is already loaded.
        stub_name = "ring_flash_attn"
        if stub_name not in sys.modules:
            monkeypatch.setitem(sys.modules, stub_name, types.ModuleType(stub_name))

        cp_base_cfg["context_parallel_size"] = context_parallel_size
        validated_cfg = validate_config(cp_base_cfg)
        normalize_config(validated_cfg)

        assert validated_cfg.batch_size == expected_batch_size
|
||||
Reference in New Issue
Block a user