fix counts by accounting for num devices
This commit is contained in:
@@ -181,7 +181,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||||
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
@@ -203,7 +203,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
sampler=eval_sampler,
|
sampler=eval_sampler,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||||
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
@@ -299,7 +299,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
/ cfg.sample_packing_eff_est
|
/ cfg.sample_packing_eff_est
|
||||||
/ 2048
|
/ 2048
|
||||||
// cfg.batch_size
|
// cfg.batch_size
|
||||||
// int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
@@ -322,6 +321,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||||
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
|
||||||
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
)
|
)
|
||||||
data_loader_len = data_loader.len_w_stats()
|
data_loader_len = data_loader.len_w_stats()
|
||||||
actual_eff = data_loader.efficiency()
|
actual_eff = data_loader.efficiency()
|
||||||
|
|||||||
Reference in New Issue
Block a user