Files
axolotl/tests/test_train.py
2025-07-14 10:05:26 -04:00

40 lines
1013 B
Python

"""Test for batch size calculation for multi-gpu training."""
import pytest
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="train_base_cfg")
def fixture_train_base_cfg(min_base_cfg):
    """Build the config used by the training batch-size tests.

    Args:
        min_base_cfg: Base config provided by a fixture defined outside
            this module.

    Returns:
        The result of merging the training-specific settings below with
        ``min_base_cfg`` via the ``|`` operator.
    """
    # NOTE(review): if DictDefault follows plain dict union semantics, keys
    # present in min_base_cfg override the values set here -- confirm.
    return (
        DictDefault(
            micro_batch_size=2,
            gradient_accumulation_steps=4,
            sequence_len=2048,
            sample_packing=True,
            num_epochs=1,
        )
        | min_base_cfg
    )
class TestTrain:
    """Test class for train related tests."""

    @pytest.mark.parametrize(
        "world_size, expected_batch_size",
        [
            (1, 8),
            (4, 32),
        ],
    )
    def test_batch_size_ddp(
        self, train_base_cfg, monkeypatch, world_size, expected_batch_size
    ):
        """Check ``cfg.batch_size`` after normalization for a given world size.

        The parametrized expectations (8 for one process, 32 for four) match
        ``micro_batch_size * gradient_accumulation_steps * world_size`` for
        the values 2 and 4 set in the ``train_base_cfg`` fixture.
        """
        # The world size is communicated through the WORLD_SIZE environment
        # variable; monkeypatch restores the environment after the test.
        monkeypatch.setenv("WORLD_SIZE", str(world_size))
        cfg = validate_config(train_base_cfg)
        # The return value is unused; batch_size is read from cfg afterwards.
        normalize_config(cfg)
        assert cfg.batch_size == expected_batch_size