From 21d307b15bc10de0965ccd5cee06260a3d8159f2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 8 Aug 2023 04:13:10 -0400 Subject: [PATCH] fix counts by accounting for num devices --- src/axolotl/utils/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f97618a73..9af8674d3 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -181,7 +181,7 @@ class AxolotlTrainer(Trainer): sampler=train_sampler, packing_efficiency_estimate=self.args.sample_packing_efficiency, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, - # device_count=int(os.environ.get("WORLD_SIZE", 1)), + device_count=int(os.environ.get("WORLD_SIZE", 1)), ) ) return super().get_train_dataloader() @@ -203,7 +203,7 @@ class AxolotlTrainer(Trainer): sampler=eval_sampler, packing_efficiency_estimate=self.args.sample_packing_efficiency, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, - # device_count=int(os.environ.get("WORLD_SIZE", 1)), + device_count=int(os.environ.get("WORLD_SIZE", 1)), ) ) return super().get_eval_dataloader(eval_dataset) @@ -299,7 +299,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): / cfg.sample_packing_eff_est / 2048 // cfg.batch_size - // int(os.environ.get("WORLD_SIZE", 1)) ) - 1 ) @@ -322,6 +321,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer): sampler=sampler, packing_efficiency_estimate=cfg.sample_packing_eff_est, sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier, + device_count=int(os.environ.get("WORLD_SIZE", 1)), ) data_loader_len = data_loader.len_w_stats() actual_eff = data_loader.efficiency()