fix: use dp_world_size instead of world_size for batch_size with tensor parallelism (#3462) [skip ci]
This commit is contained in:
@@ -119,7 +119,11 @@ def normalize_config(cfg):
|
||||
if cfg.world_size != 1:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
|
||||
effective_world_size = cfg.world_size // (cfg.context_parallel_size or 1)
|
||||
effective_world_size = (
|
||||
cfg.world_size
|
||||
// (cfg.context_parallel_size or 1)
|
||||
// (cfg.tensor_parallel_size or 1)
|
||||
)
|
||||
cfg.batch_size = cfg.batch_size * effective_world_size
|
||||
|
||||
if not cfg.use_ray:
|
||||
|
||||
@@ -457,7 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||
@@ -496,9 +495,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}")
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
# on the agreed-on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(data_loader_len * cfg.num_epochs * cfg.tensor_parallel_size)
|
||||
)
|
||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||
if cfg.dataloader_drop_last:
|
||||
# drop the last batch for each epoch
|
||||
total_num_steps -= int(math.ceil(cfg.num_epochs))
|
||||
@@ -519,12 +516,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.tensor_parallel_size
|
||||
/ cfg.batch_size
|
||||
)
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
LOG.debug(f"total_num_steps: {total_num_steps}")
|
||||
return total_num_steps
|
||||
|
||||
Reference in New Issue
Block a user