more fixes for sample packing

This commit is contained in:
Wing Lian
2023-07-18 22:27:37 -04:00
parent 58045f0816
commit b02484a83e
2 changed files with 12 additions and 7 deletions

View File

@@ -141,7 +141,10 @@ class MultipackDistributedDataloader:
 self.epoch = epoch
 def generate_batches(self, set_stats=False):
-indices = [idx for idx in self.sampler]
+if self.sampler:
+indices = [idx for idx in self.sampler]
+else:
+indices = range(0, len(self.dataset))
 lengths = self.lengths[indices]
 lengths_cumsum = np.cumsum(lengths)

View File

@@ -113,7 +113,7 @@ class AxolotlTrainer(Trainer):
 return self.accelerator.prepare(
 MultipackDistributedDataloader(
 eval_dataset,
-batch_size=self.args.per_device_eval_batch_size,
+batch_size=self.args.eval_batch_size,
 seq_max_length=self.args.max_seq_length,
 collate_fn=self.data_collator,
 sampler=eval_sampler,
@@ -162,12 +162,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 # eval_dataset = eval_dataset.map(add_position_ids)
 if cfg.sample_packing_eff_est:
 total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
-total_num_steps = math.ceil(
-total_num_tokens
-/ cfg.sample_packing_eff_est
-/ 2048
+total_num_steps = (
+math.ceil(
+total_num_tokens
+/ cfg.sample_packing_eff_est
+/ 2048
+/ cfg.batch_size
+)
 * cfg.num_epochs
-/ cfg.batch_size
 )
 else:
 sampler = RandomSampler(train_dataset)