diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 4de606565..4e26a257d 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -115,6 +115,7 @@ def normalize_config(cfg): "chrf", ] choose_device(cfg) + cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 if cfg.world_size != 1: cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} if cfg.fsdp or cfg.fsdp_config or cfg.ddp: diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 000000000..291e9136b --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,44 @@ +"""Test for batch size calculation for multi-gpu training.""" + +import pytest + +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="train_base_cfg") +def fixture_train_base_cfg(): + return DictDefault( + base_model="gpt2", + learning_rate=1e-3, + datasets=[ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + micro_batch_size=2, + gradient_accumulation_steps=4, + sequence_len=2048, + sample_packing=True, + num_epochs=1, + ) + + +class TestTrain: + """test class for train related tests""" + + @pytest.mark.parametrize( + "world_size, expected_batch_size", + [ + (1, 8), + (4, 32), + ], + ) + def test_batch_size_ddp( + self, train_base_cfg, monkeypatch, world_size, expected_batch_size + ): + monkeypatch.setenv("WORLD_SIZE", str(world_size)) + cfg = validate_config(train_base_cfg) + normalize_config(cfg) + assert cfg.batch_size == expected_batch_size