more fixes for sample packing
```diff
@@ -141,7 +141,10 @@ class MultipackDistributedDataloader:
         self.epoch = epoch
 
     def generate_batches(self, set_stats=False):
-        indices = [idx for idx in self.sampler]
+        if self.sampler:
+            indices = [idx for idx in self.sampler]
+        else:
+            indices = range(0, len(self.dataset))
 
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)
```
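This guard lets the dataloader be built without a distributed sampler, falling back to the dataset's natural order. A minimal sketch of the pattern; the class and method here are a hypothetical, stripped-down stand-in, not axolotl's actual implementation:

```python
import numpy as np

class PackingIndexer:
    """Hypothetical simplification of the dataloader's index step."""

    def __init__(self, dataset, lengths, sampler=None):
        self.dataset = dataset
        self.lengths = np.array(lengths)
        self.sampler = sampler  # e.g. a torch DistributedSampler, or None

    def generate_batches(self):
        # Prefer the sampler's (possibly shuffled, rank-sharded) ordering;
        # fall back to the dataset's natural order when no sampler is set.
        if self.sampler:
            indices = [idx for idx in self.sampler]
        else:
            indices = list(range(len(self.dataset)))
        lengths = self.lengths[indices]      # per-sample token counts
        lengths_cumsum = np.cumsum(lengths)  # running total used for packing
        return indices, lengths, lengths_cumsum

# Usage: no sampler supplied, so the natural order is used.
idx, lens, cums = PackingIndexer(["a", "bb", "ccc"], [1, 2, 3]).generate_batches()
print(idx, cums.tolist())  # [0, 1, 2] [1, 3, 6]
```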
```diff
@@ -113,7 +113,7 @@ class AxolotlTrainer(Trainer):
             return self.accelerator.prepare(
                 MultipackDistributedDataloader(
                     eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
+                    batch_size=self.args.eval_batch_size,
                     seq_max_length=self.args.max_seq_length,
                     collate_fn=self.data_collator,
                     sampler=eval_sampler,
```
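Worth noting about this one-line fix: in stock transformers, `TrainingArguments.eval_batch_size` is a derived property (the per-device value scaled by `max(1, n_gpu)`), while `per_device_eval_batch_size` is the raw configured value; since the packing dataloader assembles batches itself rather than per device, the aggregate value is presumably what it should see. A quick illustration, assuming standard transformers behavior:

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", per_device_eval_batch_size=4)

# per_device_eval_batch_size is the raw configured value; eval_batch_size
# scales it by the number of visible GPUs, so the two coincide on a single
# device but diverge under multi-GPU DataParallel.
print(args.per_device_eval_batch_size)  # 4
print(args.eval_batch_size)             # 4 on one device, 4 * n_gpu otherwise
```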
```diff
@@ -162,12 +162,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     # eval_dataset = eval_dataset.map(add_position_ids)
     if cfg.sample_packing_eff_est:
         total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
-        total_num_steps = math.ceil(
-            total_num_tokens
-            / cfg.sample_packing_eff_est
-            / 2048
-            / cfg.batch_size
-        )
+        total_num_steps = (
+            math.ceil(
+                total_num_tokens
+                / cfg.sample_packing_eff_est
+                / 2048
+                / cfg.batch_size
+            )
+            * cfg.num_epochs
+        )
     else:
         sampler = RandomSampler(train_dataset)
```
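The reshaped expression fixes the step estimate to cover the whole run: the old version counted optimizer steps for a single pass over the packed data, while the new one multiplies by `cfg.num_epochs`. Dividing by the efficiency estimate reflects that each 2048-token packed window carries roughly `eff * 2048` real tokens. A worked example with made-up numbers:

```python
import math

# Illustrative values only.
total_num_tokens = 10_000_000  # tokens across the training set
sample_packing_eff_est = 0.9   # estimated packing efficiency
batch_size = 8
num_epochs = 3

# Each step consumes roughly 0.9 * 2048 * 8 useful tokens, so
# total / 0.9 / 2048 / 8 == total / (0.9 * 2048 * 8) steps per epoch.
steps_per_epoch = math.ceil(
    total_num_tokens / sample_packing_eff_est / 2048 / batch_size
)
total_num_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_num_steps)  # 679 2037
```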