fix batch size calculation

This commit is contained in:
Wing Lian
2023-05-31 14:11:32 -04:00
parent f94dd626f0
commit 5a631b305b
2 changed files with 6 additions and 3 deletions

View File

@@ -163,15 +163,17 @@ def train(
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = (
cfg.gradient_accumulation_steps // cfg.world_size
)
cfg.batch_size = cfg.batch_size * cfg.world_size
setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False