diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 1c0e93e03..07c4d175f 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -119,7 +119,11 @@ def normalize_config(cfg):
     if cfg.world_size != 1:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
     if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
-        effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
+        effective_world_size = (
+            cfg.world_size
+            // (cfg.context_parallel_size or 1)
+            // (cfg.tensor_parallel_size or 1)
+        )
         cfg.batch_size = cfg.batch_size * effective_world_size
 
     if not cfg.use_ray:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 320e59a90..91982137b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -457,7 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                     - 1
                 )
                 * cfg.num_epochs
-                * cfg.tensor_parallel_size
             )
             LOG.debug(
                 f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -496,9 +495,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             LOG.debug(f"data_loader_len: {data_loader_len}")
             # FIXME: is there a bug here somewhere? the total num steps depends
             # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(
-                math.floor(data_loader_len * cfg.num_epochs * cfg.tensor_parallel_size)
-            )
+            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
             if cfg.dataloader_drop_last:
                 # drop the last batch for each epoch
                 total_num_steps -= int(math.ceil(cfg.num_epochs))
@@ -519,12 +516,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
     else:
         total_num_steps = int(
-            math.ceil(
-                len(train_dataset)
-                * cfg.num_epochs
-                * cfg.tensor_parallel_size
-                / cfg.batch_size
-            )
+            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )
         LOG.debug(f"total_num_steps: {total_num_steps}")
     return total_num_steps
diff --git a/tests/test_tensor_parallel_batch_size.py b/tests/test_tensor_parallel_batch_size.py
new file mode 100644
index 000000000..f0b27a8eb
--- /dev/null
+++ b/tests/test_tensor_parallel_batch_size.py
@@ -0,0 +1,54 @@
+"""Tests for batch_size calculation with tensor parallelism."""
+
+from unittest.mock import patch
+
+import addict
+import pytest
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="tp_base_cfg")
+def fixture_tp_base_cfg(min_base_cfg):
+    return (
+        DictDefault(
+            micro_batch_size=2,
+            gradient_accumulation_steps=4,
+            sequence_len=2048,
+            num_epochs=1,
+        )
+        | min_base_cfg
+    )
+
+
+class TestTensorParallelBatchSize:
+    """Verify batch_size scales by effective dp world_size when using tensor parallelism."""
+
+    @pytest.mark.parametrize(
+        "world_size, tensor_parallel_size, expected_batch_size",
+        [
+            (4, 1, 32),  # no TP: 2*4*4 = 32
+            (4, 2, 16),  # TP=2: 2*4*(4//2) = 16
+            (4, 4, 8),  # TP=4: 2*4*(4//4) = 8
+            (2, 2, 8),  # TP=ws: 2*4*(2//2) = 8 (no scaling)
+        ],
+    )
+    def test_batch_size_with_tensor_parallelism(
+        self,
+        tp_base_cfg,
+        monkeypatch,
+        world_size,
+        tensor_parallel_size,
+        expected_batch_size,
+    ):
+        monkeypatch.setenv("WORLD_SIZE", str(world_size))
+        tp_base_cfg["tensor_parallel_size"] = tensor_parallel_size
+        cfg = validate_config(tp_base_cfg)
+        # Mock load_model_config to avoid downloading the model and to bypass
+        # the tie_word_embeddings validation that blocks TP > 1.
+        with patch(
+            "axolotl.utils.config.load_model_config",
+            return_value=addict.Dict({"model_type": "llama"}),
+        ):
+            normalize_config(cfg)
+        assert cfg.batch_size == expected_batch_size