fix: use dp_world_size instead of world_size for batch_size with tensor parallelism (#3462) [skip ci]
This commit is contained in:
@@ -119,7 +119,11 @@ def normalize_config(cfg):
|
|||||||
if cfg.world_size != 1:
|
if cfg.world_size != 1:
|
||||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||||
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
|
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
|
||||||
effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
|
effective_world_size = (
|
||||||
|
cfg.world_size
|
||||||
|
// (cfg.context_parallel_size or 1)
|
||||||
|
// (cfg.tensor_parallel_size or 1)
|
||||||
|
)
|
||||||
cfg.batch_size = cfg.batch_size * effective_world_size
|
cfg.batch_size = cfg.batch_size * effective_world_size
|
||||||
|
|
||||||
if not cfg.use_ray:
|
if not cfg.use_ray:
|
||||||
|
|||||||
@@ -457,7 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.tensor_parallel_size
|
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||||
@@ -496,9 +495,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
LOG.debug(f"data_loader_len: {data_loader_len}")
|
LOG.debug(f"data_loader_len: {data_loader_len}")
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(
|
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||||
math.floor(data_loader_len * cfg.num_epochs * cfg.tensor_parallel_size)
|
|
||||||
)
|
|
||||||
if cfg.dataloader_drop_last:
|
if cfg.dataloader_drop_last:
|
||||||
# drop the last batch for each epoch
|
# drop the last batch for each epoch
|
||||||
total_num_steps -= int(math.ceil(cfg.num_epochs))
|
total_num_steps -= int(math.ceil(cfg.num_epochs))
|
||||||
@@ -519,12 +516,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
|
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
len(train_dataset)
|
|
||||||
* cfg.num_epochs
|
|
||||||
* cfg.tensor_parallel_size
|
|
||||||
/ cfg.batch_size
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}")
|
LOG.debug(f"total_num_steps: {total_num_steps}")
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|||||||
54
tests/test_tensor_parallel_batch_size.py
Normal file
54
tests/test_tensor_parallel_batch_size.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Tests for batch_size calculation with tensor parallelism."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import addict
|
||||||
|
import pytest
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="tp_base_cfg")
|
||||||
|
def fixture_tp_base_cfg(min_base_cfg):
|
||||||
|
return (
|
||||||
|
DictDefault(
|
||||||
|
micro_batch_size=2,
|
||||||
|
gradient_accumulation_steps=4,
|
||||||
|
sequence_len=2048,
|
||||||
|
num_epochs=1,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTensorParallelBatchSize:
|
||||||
|
"""Verify batch_size scales by effective dp world_size when using tensor parallelism."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"world_size, tensor_parallel_size, expected_batch_size",
|
||||||
|
[
|
||||||
|
(4, 1, 32), # no TP: 2*4*4 = 32
|
||||||
|
(4, 2, 16), # TP=2: 2*4*(4//2) = 16
|
||||||
|
(4, 4, 8), # TP=4: 2*4*(4//4) = 8
|
||||||
|
(2, 2, 8), # TP=ws: 2*4*(2//2) = 8 (no scaling)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_batch_size_with_tensor_parallelism(
|
||||||
|
self,
|
||||||
|
tp_base_cfg,
|
||||||
|
monkeypatch,
|
||||||
|
world_size,
|
||||||
|
tensor_parallel_size,
|
||||||
|
expected_batch_size,
|
||||||
|
):
|
||||||
|
monkeypatch.setenv("WORLD_SIZE", str(world_size))
|
||||||
|
tp_base_cfg["tensor_parallel_size"] = tensor_parallel_size
|
||||||
|
cfg = validate_config(tp_base_cfg)
|
||||||
|
# Mock load_model_config to avoid downloading the model and to bypass
|
||||||
|
# the tie_word_embeddings validation that blocks TP > 1.
|
||||||
|
with patch(
|
||||||
|
"axolotl.utils.config.load_model_config",
|
||||||
|
return_value=addict.Dict({"model_type": "llama"}),
|
||||||
|
):
|
||||||
|
normalize_config(cfg)
|
||||||
|
assert cfg.batch_size == expected_batch_size
|
||||||
Reference in New Issue
Block a user