more fixes for sample packing

Wing Lian
2023-07-18 22:27:37 -04:00
parent 58045f0816
commit b02484a83e
2 changed files with 12 additions and 7 deletions


@@ -141,7 +141,10 @@ class MultipackDistributedDataloader:
         self.epoch = epoch
 
     def generate_batches(self, set_stats=False):
-        indices = [idx for idx in self.sampler]
+        if self.sampler:
+            indices = [idx for idx in self.sampler]
+        else:
+            indices = range(0, len(self.dataset))
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)
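The guard added in the hunk above falls back to the dataset's natural order when no sampler is supplied. A minimal standalone sketch of that behavior (resolve_indices is a hypothetical helper for illustration, not part of the codebase):

def resolve_indices(sampler, dataset):
    # With a sampler, take the index order it yields; otherwise fall back
    # to 0..len(dataset)-1, mirroring the if/else added in the diff.
    if sampler:
        return [idx for idx in sampler]
    return range(0, len(dataset))

# resolve_indices(None, ["a", "b", "c"]) -> range(0, 3)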


@@ -113,7 +113,7 @@ class AxolotlTrainer(Trainer):
         return self.accelerator.prepare(
             MultipackDistributedDataloader(
                 eval_dataset,
-                batch_size=self.args.per_device_eval_batch_size,
+                batch_size=self.args.eval_batch_size,
                 seq_max_length=self.args.max_seq_length,
                 collate_fn=self.data_collator,
                 sampler=eval_sampler,
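A note on the argument swap above: transformers' TrainingArguments exposes eval_batch_size as a derived value (the per-device eval batch size scaled by the number of devices in use), so the packed eval dataloader now sizes its batches from the effective eval batch rather than the per-device one. A small sketch with made-up values, assuming transformers is installed:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", per_device_eval_batch_size=4)
# On a single-device run these coincide; with multiple GPUs, eval_batch_size
# is the per-device value multiplied by the device count.
print(args.per_device_eval_batch_size, args.eval_batch_size)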
@@ -162,12 +162,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        # eval_dataset = eval_dataset.map(add_position_ids)
        if cfg.sample_packing_eff_est:
            total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
-            total_num_steps = math.ceil(
-                total_num_tokens
-                / cfg.sample_packing_eff_est
-                / 2048
-                / cfg.batch_size
+            total_num_steps = (
+                math.ceil(
+                    total_num_tokens
+                    / cfg.sample_packing_eff_est
+                    / 2048
+                    / cfg.batch_size
+                )
+                * cfg.num_epochs
            )
    else:
        sampler = RandomSampler(train_dataset)
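The reshuffled expression above also multiplies the per-epoch estimate by cfg.num_epochs, which the removed lines did not include. A worked example of the arithmetic with made-up numbers, keeping the 2048-token packing window from the diff:

import math

total_num_tokens = 10_000_000     # hypothetical token count of the packed train set
sample_packing_eff_est = 0.9      # hypothetical packing efficiency estimate
batch_size = 8                    # hypothetical cfg.batch_size
num_epochs = 3                    # hypothetical cfg.num_epochs

steps_per_epoch = math.ceil(
    total_num_tokens / sample_packing_eff_est / 2048 / batch_size
)
total_num_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_num_steps)  # 679 2037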