From 487abfc7698c2dfc71267dea9f3877a5a9040f5f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 26 Jul 2023 00:06:28 -0400
Subject: [PATCH] pass sample packing efficiency to training args

---
 src/axolotl/utils/trainer.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 3dacdef89..72724d60a 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -47,6 +47,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
     )
+    sample_packing_efficiency: float = field(
+        default=1.0,
+        metadata={"help": "Sample packing efficiency for calculating batch length."},
+    )
     max_seq_length: int = field(
         default=2048,
         metadata={"help": "The maximum sequence length the model can handle"},
@@ -109,6 +113,7 @@ class AxolotlTrainer(Trainer):
                     seq_max_length=self.args.max_seq_length,
                     collate_fn=self.data_collator,
                     sampler=train_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 )
             )
         return super().get_train_dataloader()
@@ -310,6 +315,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.save_safetensors:
         training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
 
+    if cfg.sample_packing_eff_est:
+        training_arguments_kwargs[
+            "sample_packing_efficiency"
+        ] = cfg.sample_packing_eff_est
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         # max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
         max_seq_length=cfg.sequence_len,