From ead34c516a22bf2ef6d8e35f8211b763ff0fdb58 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 9 Jan 2024 22:16:24 -0500
Subject: [PATCH] swap the data collator for evals if not using sample packing (#1076)

* swap the data collator for evals if not using sample packing

* drop last from dataloader to help with issues with evals
---
 src/axolotl/core/trainer_builder.py | 46 ++++++++++++++++++++++++++---
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 7798ca455..18dc353a2 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1,3 +1,4 @@
+# pylint: disable=too-many-lines
 """
 Builder for the training args and trainer
 """
@@ -137,10 +138,19 @@ class AxolotlTrainer(Trainer):
     args = None  # type: AxolotlTrainingArguments
     tag_names = ["axolotl"]
 
-    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+    def __init__(
+        self,
+        *_args,
+        num_epochs=1,
+        bench_data_collator=None,
+        eval_data_collator=None,
+        **kwargs
+    ):
         self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
-        super().__init__(*args, **kwargs)
+        self.eval_data_collator = eval_data_collator
+        super().__init__(*_args, **kwargs)
+        self.train_data_collator = self.data_collator
 
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -239,6 +249,16 @@ class AxolotlTrainer(Trainer):
         return super().get_train_dataloader()
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+        if self.args.sample_packing and self.args.eval_sample_packing is False:
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.eval_data_collator
+            )
+            dataloader = super().get_eval_dataloader(eval_dataset)
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.train_data_collator
+            )
+            return dataloader
+
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
@@ -269,6 +289,7 @@ class AxolotlTrainer(Trainer):
             return self.accelerator.prepare_data_loader(
                 DataLoader(eval_dataset, **dataloader_params)
             )
+
         return super().get_eval_dataloader(eval_dataset)
 
     def _get_bench_sampler(
@@ -651,6 +672,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
+        if self.cfg.dataloader_drop_last is not None:
+            training_arguments_kwargs[
+                "dataloader_drop_last"
+            ] = self.cfg.dataloader_drop_last
+        elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
+            training_arguments_kwargs["dataloader_drop_last"] = True
 
         if self.cfg.val_set_size == 0:
             # no eval set, so don't eval
@@ -831,6 +858,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             eval_dataset=self.eval_dataset,
             args=training_args,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
+            eval_data_collator=self.build_collator(
+                training_args, is_eval=True, **data_collator_kwargs
+            ),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -851,14 +881,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         return trainer
 
-    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+    def build_collator(
+        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
+    ):
         if training_args.pretraining:
             return None
 
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
 
-        if training_args.sample_packing:
+        use_batch_sampler_collator = False
+        if is_eval is False and training_args.sample_packing:
+            use_batch_sampler_collator = True
+        if is_eval and training_args.eval_sample_packing:
+            use_batch_sampler_collator = True
+
+        if use_batch_sampler_collator:
            return BatchSamplerDataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
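
Note: the heart of the patch is the temporary collator swap in get_eval_dataloader(). The
packing collator is stashed as train_data_collator at init time, the padding-oriented
eval_data_collator is installed while the parent Trainer builds the eval dataloader, and
the training collator is then restored. Below is a minimal, self-contained sketch of that
pattern; BaseTrainer and the string "collators" are illustrative stand-ins, not the real
transformers.Trainer API.

    class BaseTrainer:
        """Stand-in for transformers.Trainer: builds loaders from self.data_collator."""

        def __init__(self, data_collator):
            self.data_collator = data_collator

        def get_eval_dataloader(self):
            # The parent reads whatever collator is installed at call time.
            return f"DataLoader(collate_fn={self.data_collator})"


    class SwappingTrainer(BaseTrainer):
        def __init__(self, data_collator, eval_data_collator):
            super().__init__(data_collator)
            # Mirrors the patch: remember the training collator so it can be restored.
            self.train_data_collator = self.data_collator
            self.eval_data_collator = eval_data_collator

        def get_eval_dataloader(self):
            # Install the eval collator only while the parent builds the loader,
            # then put the training (packing) collator back.
            self.data_collator = self.eval_data_collator
            dataloader = super().get_eval_dataloader()
            self.data_collator = self.train_data_collator
            return dataloader


    trainer = SwappingTrainer("packing_collator", "padding_collator")
    print(trainer.get_eval_dataloader())  # DataLoader(collate_fn=padding_collator)
    print(trainer.data_collator)          # packing_collator (restored)

The patch uses a plain assign-and-restore as sketched; wrapping the super() call in
try/finally would additionally restore the training collator if dataloader construction
raised. The companion change in the builder defaults dataloader_drop_last to True whenever
sample_packing is on but eval_sample_packing is off, unless dataloader_drop_last is set
explicitly in the config, matching the second bullet of the commit message.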