Fix total step counts by accounting for the number of devices (WORLD_SIZE)

This commit is contained in:
Wing Lian
2023-08-08 04:13:10 -04:00
parent 58e9dee204
commit 21d307b15b

View File

@@ -181,7 +181,7 @@ class AxolotlTrainer(Trainer):
sampler=train_sampler, sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
) )
) )
return super().get_train_dataloader() return super().get_train_dataloader()
@@ -203,7 +203,7 @@ class AxolotlTrainer(Trainer):
sampler=eval_sampler, sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
) )
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -299,7 +299,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
/ cfg.sample_packing_eff_est / cfg.sample_packing_eff_est
/ 2048 / 2048
// cfg.batch_size // cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
) )
- 1 - 1
) )
@@ -322,6 +321,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
sampler=sampler, sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est, packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
) )
data_loader_len = data_loader.len_w_stats() data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency() actual_eff = data_loader.efficiency()