Better handle variance in the multipack dataloader length, and prevent the trainer from hanging when it runs out of data

This commit is contained in:
Wing Lian
2023-07-25 10:22:05 -04:00
parent 32fed7039d
commit df3eb645da
2 changed files with 5 additions and 3 deletions

View File

@@ -193,11 +193,13 @@ class MultipackDistributedDataloader:
def __len__(self):
    """Return a conservative, integer batch count for this dataloader.

    The value is 99% of the number of batches produced by
    ``self.generate_batches()``, truncated toward zero.

    Returns:
        int: the reduced batch count (0 when no batches are generated).
    """
    batches, _ = self.generate_batches()
    # Shave off 1% for dealing with variance in packing and dataset length.
    # The product is a float, and len() raises TypeError when __len__ returns
    # a non-integer, so truncate it explicitly.
    return int(len(batches) * 0.99)
def num_batches(self):
    """Return the batch count reduced by 1%, as an integer.

    Calls ``self.generate_batches()`` and counts the first element of the
    returned pair; the second element is ignored.

    Returns:
        int: 99% of the generated batch count, truncated toward zero.
    """
    batches, _ = self.generate_batches()
    # Apply the same 1% reduction as __len__, and truncate so callers get a
    # whole number of batches rather than a float.
    return int(len(batches) * 0.99)
def efficiency(self):
    """Return the ratio of ``eff_total_used`` to ``eff_total_slots``."""
    used = self.eff_total_used
    slots = self.eff_total_slots
    return used / slots

View File

@@ -320,7 +320,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=(
cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0