diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2e3728cc8..fdf86e567 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -330,7 +330,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                         / cfg.sample_packing_eff_est
                         / cfg.sequence_len
                         // cfg.batch_size
-                        // int(os.environ.get("WORLD_SIZE", 1))
                     )
                     - 1
                 )
@@ -359,18 +358,14 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 train_dataset.remove_columns(["length"]),
                 batch_sampler=sampler,
             )
-            data_loader_len = len(data_loader) // cfg.batch_size
+            data_loader_len = len(data_loader) // (
+                cfg.world_size * cfg.gradient_accumulation_steps
+            )
             actual_eff = sampler.efficiency()
             LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
             # FIXME: is there a bug here somewhere? the total num steps depends
             # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(
-                math.floor(
-                    data_loader_len
-                    * cfg.num_epochs
-                    / int(os.environ.get("WORLD_SIZE", 1))
-                )
-            )
+            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
             def calc_sample_packing_eff_est(estimates: List[float]):
                 LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -391,12 +386,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             )
         else:
             total_num_steps = int(
-                math.ceil(
-                    len(train_dataset)
-                    * cfg.num_epochs
-                    / int(os.environ.get("WORLD_SIZE", 1))
-                    / cfg.batch_size
-                )
+                math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
             )
         LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
         return total_num_steps