Compare commits
1 commit
fused-mlp-...torch_tens
| Author | SHA1 | Date |
|---|---|---|
|  | 41664c7c4c |  |
```diff
@@ -115,6 +115,7 @@ def normalize_config(cfg):
         "chrf",
     ]
     choose_device(cfg)
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.world_size != 1:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
     if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
```
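The added line defaults `ddp` to on whenever the run spans more than one process, unless the user set it explicitly. The new test file below asserts the batch size arithmetic that this multi-process path feeds into. As a minimal sketch of that arithmetic, assuming (as the test expectations imply) an effective batch size of `micro_batch_size * gradient_accumulation_steps * world_size`; `effective_batch_size` here is a hypothetical helper for illustration, not part of axolotl's API:

```python
import os

def effective_batch_size(micro_batch_size: int, gradient_accumulation_steps: int) -> int:
    # Hypothetical helper, not axolotl code: reproduces the arithmetic the
    # new test asserts. WORLD_SIZE is the process count exported by the
    # launcher (e.g. torchrun) and defaults to 1 for single-process runs.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return micro_batch_size * gradient_accumulation_steps * world_size

os.environ.pop("WORLD_SIZE", None)
assert effective_batch_size(2, 4) == 8   # single process: 2 * 4 * 1

os.environ["WORLD_SIZE"] = "4"
assert effective_batch_size(2, 4) == 32  # 4-way DDP: 2 * 4 * 4
```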
tests/test_train.py (new file, +44 lines)
```diff
@@ -0,0 +1,44 @@
+"""Test for batch size calculation for multi-gpu training."""
+
+import pytest
+
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="train_base_cfg")
+def fixture_train_base_cfg():
+    return DictDefault(
+        base_model="gpt2",
+        learning_rate=1e-3,
+        datasets=[
+            {
+                "path": "mhenrichsen/alpaca_2k_test",
+                "type": "alpaca",
+            },
+        ],
+        micro_batch_size=2,
+        gradient_accumulation_steps=4,
+        sequence_len=2048,
+        sample_packing=True,
+        num_epochs=1,
+    )
+
+
+class TestTrain:
+    """test class for train related tests"""
+
+    @pytest.mark.parametrize(
+        "world_size, expected_batch_size",
+        [
+            (1, 8),
+            (4, 32),
+        ],
+    )
+    def test_batch_size_ddp(
+        self, train_base_cfg, monkeypatch, world_size, expected_batch_size
+    ):
+        monkeypatch.setenv("WORLD_SIZE", str(world_size))
+        cfg = validate_config(train_base_cfg)
+        normalize_config(cfg)
+        assert cfg.batch_size == expected_batch_size
```
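Both parametrized cases run on a single machine without GPUs: `monkeypatch.setenv` fakes the `WORLD_SIZE` variable a launcher such as torchrun would export and restores it after each case, so `pytest tests/test_train.py` should exercise both the single-process and the 4-way DDP expectations.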
|
||||