Compare commits

...

1 Commit

Author:  Wing Lian
SHA1:    41664c7c4c
Date:    2025-07-14 07:51:16 -04:00
Message: fix ddp for incorrect steps (#2915)

  * fix ddp for incorrect steps
  * add test
2 changed files with 45 additions and 0 deletions


@@ -115,6 +115,7 @@ def normalize_config(cfg):
        "chrf",
    ]
    choose_device(cfg)
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.world_size != 1:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
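
The one-line change in this hunk appears to be the cfg.ddp default: when ddp is not set explicitly in the config, it now falls back to world_size != 1, so multi-GPU runs are treated as DDP. A minimal sketch of why that matters for step counts, assuming the effective batch size is micro_batch_size * gradient_accumulation_steps, scaled by world_size when DDP or FSDP is active (the helper below is illustrative, not axolotl's actual API):

import os


def effective_batch_size(micro_batch_size, gradient_accumulation_steps, ddp=None):
    """Illustrative sketch: how a missing ddp default can under-count the batch size."""
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # The fix: default ddp from world_size instead of leaving it None (falsy).
    ddp = ddp if ddp is not None else world_size != 1
    batch_size = micro_batch_size * gradient_accumulation_steps
    if ddp:
        # Under DDP each rank processes its own micro-batches, so the
        # effective (global) batch size scales with the number of ranks.
        batch_size *= world_size
    return batch_size


# With WORLD_SIZE=4: returns 32 once ddp defaults to True. If ddp were left as
# None (falsy), a plain multi-GPU run would report 8, and the trainer would
# derive too many optimizer steps per epoch from that undersized batch size.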

tests/test_train.py (new file, +44 lines)

@@ -0,0 +1,44 @@
"""Test for batch size calculation for multi-gpu training."""
import pytest
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="train_base_cfg")
def fixture_train_base_cfg():
return DictDefault(
base_model="gpt2",
learning_rate=1e-3,
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
micro_batch_size=2,
gradient_accumulation_steps=4,
sequence_len=2048,
sample_packing=True,
num_epochs=1,
)
class TestTrain:
"""test class for train related tests"""
@pytest.mark.parametrize(
"world_size, expected_batch_size",
[
(1, 8),
(4, 32),
],
)
def test_batch_size_ddp(
self, train_base_cfg, monkeypatch, world_size, expected_batch_size
):
monkeypatch.setenv("WORLD_SIZE", str(world_size))
cfg = validate_config(train_base_cfg)
normalize_config(cfg)
assert cfg.batch_size == expected_batch_size
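
The parametrized expectations follow the same arithmetic as the sketch above: 2 micro-batches x 4 gradient-accumulation steps x 1 rank = 8, and 2 x 4 x 4 ranks = 32. Assuming a standard pytest setup in the repo, the test can be run in isolation with:

pytest tests/test_train.py -k test_batch_size_ddp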