fix counts by accounting for num devices

This commit is contained in:
Wing Lian
2023-08-08 04:13:10 -04:00
parent 58e9dee204
commit 21d307b15b

View File

@@ -181,7 +181,7 @@ class AxolotlTrainer(Trainer):
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_train_dataloader()
@@ -203,7 +203,7 @@ class AxolotlTrainer(Trainer):
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_eval_dataloader(eval_dataset)
@@ -299,7 +299,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
/ cfg.sample_packing_eff_est
/ 2048
// cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
)
- 1
)
@@ -322,6 +321,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()