diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 907de056b..a334fd29b 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -422,6 +422,9 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.torch_compile_mode:
             training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
 
+        if self.cfg.compile_optimizer:
+            training_args_kwargs["compile_optimizer"] = True
+
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (
diff --git a/src/axolotl/core/trainers/mixins/optimizer.py b/src/axolotl/core/trainers/mixins/optimizer.py
index a9a9a3992..56301859c 100644
--- a/src/axolotl/core/trainers/mixins/optimizer.py
+++ b/src/axolotl/core/trainers/mixins/optimizer.py
@@ -1,5 +1,6 @@
 """Module for Axolotl trainer optimizer mixin"""
 
+import torch
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
 from transformers.trainer import Trainer
@@ -185,12 +186,12 @@ class OptimizerMixin(Trainer):
                                 p.data_ptr(): p.numel() for p in module.parameters()
                             }.values()
                         )
-                        LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                        LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
                         manager.register_module_override(
                             module, "weight", {"optim_bits": 32}
                         )
                         LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
-            LOG.info(f"skipped: {skipped/2**20}M params")
+            LOG.info(f"skipped: {skipped / 2 ** 20}M params")
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -199,6 +200,11 @@ class OptimizerMixin(Trainer):
 
         return self.optimizer
 
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        super().create_optimizer_and_scheduler(num_training_steps)
+        if self.args.compile_optimizer:
+            self.optimizer.step = torch.compile(self.optimizer.step)
+
 
 class OptimizerInitMixin:
     """
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 42488e643..e0832eace 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -141,6 +141,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
     )
+    compile_optimizer: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Whether to compile the optimizer for faster training."},
+    )
     qlora: bool = field(
         default=False,
         metadata={"help": "whether this is a qlora training"},
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index e5f105053..0802868d6 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -275,6 +275,7 @@ class AxolotlInputConfig(
     torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
         None
     )
+    compile_optimizer: bool | None = None
     max_steps: int | None = None
     warmup_steps: int | None = None
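For context (not part of the patch): the new `compile_optimizer` flag flows from the user-facing `AxolotlInputConfig` schema, through `TrainerBuilderBase`, into `AxolotlTrainingMixins`, and finally `OptimizerMixin.create_optimizer_and_scheduler` wraps the optimizer's bound `step` method with `torch.compile`. The sketch below shows that same pattern in isolation; the model, optimizer, and hyperparameters are arbitrary placeholders chosen for illustration.

```python
import torch
from torch import nn

# Toy model and optimizer, purely for illustration.
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Same pattern as the OptimizerMixin.create_optimizer_and_scheduler override:
# replace the bound `step` method with a torch.compile-wrapped version so the
# per-parameter update loop can be traced and fused.
optimizer.step = torch.compile(optimizer.step)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()  # first call triggers compilation; later calls reuse the compiled graph
optimizer.zero_grad()
```

With this change applied, the feature would be switched on from an axolotl YAML config via `compile_optimizer: true`. Compiling only `optimizer.step`, rather than the whole training step, keeps any graph breaks in the model out of scope and targets the optimizer update loop, at the cost of a one-time compilation hit on the first step.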