fix ddp for incorrect steps (#2915)

* fix ddp for incorrect steps

* add test
This commit is contained in:
Wing Lian
2025-07-14 07:51:16 -04:00
committed by GitHub
parent 9a8073e73d
commit 41664c7c4c
2 changed files with 45 additions and 0 deletions

44
tests/test_train.py Normal file
View File

@@ -0,0 +1,44 @@
"""Test for batch size calculation for multi-gpu training."""
import pytest
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="train_base_cfg")
def fixture_train_base_cfg():
return DictDefault(
base_model="gpt2",
learning_rate=1e-3,
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
micro_batch_size=2,
gradient_accumulation_steps=4,
sequence_len=2048,
sample_packing=True,
num_epochs=1,
)
class TestTrain:
"""test class for train related tests"""
@pytest.mark.parametrize(
"world_size, expected_batch_size",
[
(1, 8),
(4, 32),
],
)
def test_batch_size_ddp(
self, train_base_cfg, monkeypatch, world_size, expected_batch_size
):
monkeypatch.setenv("WORLD_SIZE", str(world_size))
cfg = validate_config(train_base_cfg)
normalize_config(cfg)
assert cfg.batch_size == expected_batch_size