fix: correct total_num_steps and batch_size calculation with context parallelism (#3444)

* fix: correct total_num_steps and batch_size calculation with context parallelism

* feat: add test for CP batch size

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Gilles Turpin
2026-03-05 18:33:28 +01:00
committed by GitHub
parent 28cc085283
commit 4b8bc52424
3 changed files with 59 additions and 9 deletions

View File

@@ -119,7 +119,8 @@ def normalize_config(cfg):
if cfg.world_size != 1:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
cfg.batch_size = cfg.batch_size * cfg.world_size
effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
cfg.batch_size = cfg.batch_size * effective_world_size
if not cfg.use_ray:
# delay resolving dtype until on worker node when launching with ray

View File

@@ -457,7 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
LOG.debug(
@@ -498,12 +497,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
math.floor(
data_loader_len
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
math.floor(data_loader_len * cfg.num_epochs * cfg.tensor_parallel_size)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
@@ -528,7 +522,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
/ cfg.batch_size
)