diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index d507dace8..d8ba7f567 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -128,7 +128,7 @@ class MultipackDistributedDataloader:
         batch_size: int = 1,
         sampler: Union[Sampler, DistributedSampler] = None,
         packing_efficiency_estimate: float = 1.0,
-        seq_len_multiple: int = 1,
+        sample_packing_seq_len_multiplier: int = 1,
     ):
         # Dataset
         self.dataset = dataset
@@ -136,10 +136,11 @@ class MultipackDistributedDataloader:
             [len(sample["input_ids"]) for sample in self.dataset]
         )
         assert isinstance(self.lengths, np.ndarray)
-        assert batch_size % seq_len_multiple == 0
+        assert batch_size % sample_packing_seq_len_multiplier == 0
+        assert batch_size >= sample_packing_seq_len_multiplier
         self.sampler = sampler
         self.batch_size = batch_size
-        self.seq_len_multiple = seq_len_multiple
+        self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
         self.seq_max_length = seq_max_length
         self.batch_max_length = batch_size * seq_max_length
         self.collate_fn = collate_fn
@@ -166,7 +167,7 @@ class MultipackDistributedDataloader:
             lengths_cumsum=lengths_cumsum,
             rank=self.rank,
             # c=self.batch_max_length,
-            c=self.seq_max_length * self.seq_len_multiple,
+            c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
             n=self.num_replicas,
         )

@@ -183,7 +184,9 @@ class MultipackDistributedDataloader:
         all_batches, _ = self.generate_batches(set_stats=True)
         features = self.dataset.features.keys()
         len_remaining = self._len_est()
-        for batches in chunk(all_batches, self.batch_size // self.seq_len_multiple):
+        for batches in chunk(
+            all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
+        ):
             chunked_data = []
             attn_mask_cum_idx = 0
             for batch in batches:
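Taken together, the dataloader changes keep the per-step token budget fixed while letting the multiplier trade bin size against bin count: each packed bin now holds up to seq_max_length * sample_packing_seq_len_multiplier tokens (the `c` handed to the allocator), and chunk() groups batch_size // sample_packing_seq_len_multiplier bins per step. A minimal standalone sketch of that invariant (illustrative names, not axolotl code):

    def packing_geometry(batch_size: int, seq_max_length: int, multiplier: int):
        # Mirrors the two new assertions in MultipackDistributedDataloader.__init__
        assert batch_size % multiplier == 0
        assert batch_size >= multiplier
        bin_capacity = seq_max_length * multiplier  # capacity `c` passed to the packer
        bins_per_step = batch_size // multiplier    # how many bins chunk() groups per step
        # The token budget per optimizer step is unchanged by the multiplier:
        assert bin_capacity * bins_per_step == batch_size * seq_max_length
        return bin_capacity, bins_per_step

    print(packing_geometry(batch_size=8, seq_max_length=2048, multiplier=2))
    # -> (4096, 4): four bins of up to 4096 tokens == 8 * 2048 tokens per step

This is also why the hard-coded seq_len_multiple=2 in the trainer below could silently double the bin size; surfacing it as a config value makes the trade-off explicit.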
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 1cc6ba57b..29b944542 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=2048,
         metadata={"help": "The maximum sequence length the model can handle"},
     )
+    sample_packing_seq_len_multiplier: int = field(
+        default=1,
+        metadata={"help": "the multiplier for the max len for packed sequences"},
+    )


 class AxolotlTrainer(Trainer):
@@ -176,7 +180,7 @@ class AxolotlTrainer(Trainer):
                     collate_fn=self.data_collator,
                     sampler=train_sampler,
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    seq_len_multiple=2,
+                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                 )
             )
         return super().get_train_dataloader()
@@ -197,7 +201,7 @@ class AxolotlTrainer(Trainer):
                     collate_fn=self.data_collator,
                     sampler=eval_sampler,
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
-                    seq_len_multiple=2,
+                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                 )
             )
         return super().get_eval_dataloader(eval_dataset)
@@ -295,7 +299,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             ),
             sampler=sampler,
             packing_efficiency_estimate=cfg.sample_packing_eff_est,
-            seq_len_multiple=2,
+            sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
         )
         data_loader_len = len(data_loader)
         LOG.info(f"data_loader_len: {data_loader_len}")
@@ -430,6 +434,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         sample_packing=cfg.sample_packing if cfg.sample_packing else False,
+        sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier or 1,
         **training_arguments_kwargs,
     )

@@ -523,7 +528,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.collator_pad_to_longest:
         data_collator_kwargs["padding"] = "longest"
     else:
-        data_collator_kwargs["pad_to_multiple_of"] = 8
+        # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
+        # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
+        data_collator_kwargs["pad_to_multiple_of"] = 64

     if cfg.is_llama_derived_model and cfg.landmark_attention:
         from axolotl.monkeypatch.llama_landmark_attn import (
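The pad_to_multiple_of bump follows the NVIDIA matmul guidance linked in the comment: fp16 Tensor Cores want dimensions divisible by 8, while A100 prefers 64, so padding everything to 64 covers both at the cost of a few extra pad tokens. The effect is easy to see with a stock transformers collator (illustrative only; axolotl builds its collator from data_collator_kwargs, and the exact collator class is outside this diff):

    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

    collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=64)
    batch = collator([tokenizer("short"), tokenizer("a somewhat longer example")])

    # Both sequences are padded up to the next multiple of 64:
    print(batch["input_ids"].shape)  # torch.Size([2, 64])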