more fixes for sample packing
This commit is contained in:
@@ -141,7 +141,10 @@ class MultipackDistributedDataloader:
|
|||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
indices = [idx for idx in self.sampler]
|
if self.sampler:
|
||||||
|
indices = [idx for idx in self.sampler]
|
||||||
|
else:
|
||||||
|
indices = range(0, len(self.dataset))
|
||||||
|
|
||||||
lengths = self.lengths[indices]
|
lengths = self.lengths[indices]
|
||||||
lengths_cumsum = np.cumsum(lengths)
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
return self.accelerator.prepare(
|
return self.accelerator.prepare(
|
||||||
MultipackDistributedDataloader(
|
MultipackDistributedDataloader(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
batch_size=self.args.per_device_eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
seq_max_length=self.args.max_seq_length,
|
seq_max_length=self.args.max_seq_length,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=self.data_collator,
|
||||||
sampler=eval_sampler,
|
sampler=eval_sampler,
|
||||||
@@ -162,12 +162,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
# eval_dataset = eval_dataset.map(add_position_ids)
|
# eval_dataset = eval_dataset.map(add_position_ids)
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
total_num_steps = math.ceil(
|
total_num_steps = (
|
||||||
total_num_tokens
|
math.ceil(
|
||||||
/ cfg.sample_packing_eff_est
|
total_num_tokens
|
||||||
/ 2048
|
/ cfg.sample_packing_eff_est
|
||||||
|
/ 2048
|
||||||
|
/ cfg.batch_size
|
||||||
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
/ cfg.batch_size
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sampler = RandomSampler(train_dataset)
|
sampler = RandomSampler(train_dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user