diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 3c0ca77de..8ded23661 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -418,6 +418,9 @@ class TrainerBuilderBase(abc.ABC):
             torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
                 True
             )
+            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
+                256
+            )
             training_args_kwargs["torch_compile"] = self.cfg.torch_compile
             if self.cfg.torch_compile_backend:
                 training_args_kwargs["torch_compile_backend"] = (
@@ -426,6 +429,10 @@ class TrainerBuilderBase(abc.ABC):
             if self.cfg.torch_compile_mode:
                 training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
 
+    def _configure_accelerator_config(self, training_args_kwargs: dict):
+        if self.cfg.accelerator_config:
+            training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
+
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (
@@ -510,5 +517,6 @@ class TrainerBuilderBase(abc.ABC):
         self._configure_scheduler(training_args_kwargs)
         self._configure_optimizer(training_args_kwargs, trainer_kwargs)
         self._configure_torch_compile(training_args_kwargs)
+        self._configure_accelerator_config(training_args_kwargs)
 
         return training_args_kwargs, trainer_kwargs
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 9fcd51c1d..00cee35a7 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -310,11 +310,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.neftune_noise_alpha
             )
 
-        if self.cfg.accelerator_config:
-            training_arguments_kwargs["accelerator_config"] = (
-                self.cfg.accelerator_config
-            )
-
         if self.cfg.image_size:
             training_arguments_kwargs["image_size"] = self.cfg.image_size
         if self.cfg.image_resize_algorithm:
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 81a2f5a45..6b2d30709 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -75,18 +75,6 @@ class AxolotlTrainer(
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
-    def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.torch_compile:
-            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
-                256
-            )
-            model = torch.compile(
-                model,
-                backend=self.args.torch_compile_backend,
-                mode=self.args.torch_compile_mode,
-            )
-        return super()._wrap_model(model, training=training, dataloader=dataloader)
-
     def _create_multipack_sampler(
         self, base_sampler: Sampler, dataset: Dataset
     ) -> MultipackBatchSampler:
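For context on where the relocated accelerator_config handling lands: transformers.TrainingArguments accepts an accelerator_config mapping and parses it into accelerate's AcceleratorConfig, which is why forwarding self.cfg.accelerator_config through training_args_kwargs is sufficient. A minimal sketch, assuming a recent transformers release; the specific keys and values below are illustrative assumptions, not taken from the diff:

    from transformers import TrainingArguments

    # Stands in for what self.cfg.accelerator_config would carry after the
    # move into _configure_accelerator_config; these keys/values are
    # assumptions for illustration.
    training_args_kwargs = {
        "accelerator_config": {
            "split_batches": True,
            "even_batches": True,
        }
    }

    args = TrainingArguments(output_dir="./out", **training_args_kwargs)
    print(args.accelerator_config)  # parsed into an AcceleratorConfig instance

Because the base builder now owns this key, every trainer builder that goes through _create_training_args_kwargs-style assembly picks it up, instead of only HFCausalTrainerBuilder.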
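On the accumulated_cache_size_limit line that moved into the builder: with the custom _wrap_model override removed, compilation is left to the HF Trainer via training_args_kwargs["torch_compile"], so the dynamo cache knob is now set where those kwargs are built. torch._dynamo caps recompilations, and accumulated_cache_size_limit bounds the total number of cached compiled graphs accumulated over the run (the per-frame bound is the separate cache_size_limit). A minimal sketch, assuming PyTorch 2.2 or newer:

    import torch

    # Same setting the diff applies in _configure_torch_compile; 256 matches
    # the value used there. Raising it helps large models whose many layers
    # would otherwise exhaust the accumulated cache and fall back to eager.
    torch._dynamo.config.accumulated_cache_size_limit = 256  # pylint: disable=protected-access

    model = torch.nn.Sequential(*[torch.nn.Linear(8, 8) for _ in range(3)])
    compiled = torch.compile(model)
    print(compiled(torch.randn(2, 8)).shape)  # torch.Size([2, 8])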