Compare commits


1 Commit

Author:  Wing Lian
SHA1:    a6056e35de
Message: enable torch compile on the optimizer step
         make optimizer compile independent of torch compile on the model
Date:    2025-06-10 00:07:49 -07:00
4 changed files with 16 additions and 2 deletions
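
The change threads a new compile_optimizer option from the user config through the training arguments and into the trainer, where only the optimizer's step method is wrapped with torch.compile, independent of whether the model itself is compiled via torch_compile / torch_compile_mode. Below is a minimal sketch of the underlying technique in plain PyTorch; it mirrors what the diff does but uses illustrative names, not axolotl's API.

    import torch
    from torch import nn

    model = nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # torch.compile accepts any callable, including a bound method; wrapping
    # optimizer.step compiles only the parameter update while the model stays eager.
    # (The speedup mainly shows up on CUDA, where the update kernels can be fused.)
    optimizer.step = torch.compile(optimizer.step)

    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    optimizer.step()       # first call triggers compilation; later calls reuse it
    optimizer.zero_grad()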

View File

@@ -422,6 +422,9 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.torch_compile_mode:
             training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
+        if self.cfg.compile_optimizer:
+            training_args_kwargs["compile_optimizer"] = True
 
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (

View File

@@ -1,5 +1,6 @@
"""Module for Axolotl trainer optimizer mixin""" """Module for Axolotl trainer optimizer mixin"""
import torch
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from transformers.trainer import Trainer from transformers.trainer import Trainer
@@ -185,12 +186,12 @@ class OptimizerMixin(Trainer):
                                 p.data_ptr(): p.numel() for p in module.parameters()
                             }.values()
                         )
-                        LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                        LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
                         manager.register_module_override(
                             module, "weight", {"optim_bits": 32}
                         )
                         LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
-                LOG.info(f"skipped: {skipped/2**20}M params")
+                LOG.info(f"skipped: {skipped / 2 ** 20}M params")
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -199,6 +200,11 @@ class OptimizerMixin(Trainer):
         return self.optimizer
 
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        super().create_optimizer_and_scheduler(num_training_steps)
+        if self.args.compile_optimizer:
+            self.optimizer.step = torch.compile(self.optimizer.step)
+
 
 class OptimizerInitMixin:
     """

View File

@@ -141,6 +141,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
     )
+    compile_optimizer: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Whether to compile the optimizer for faster training."},
+    )
     qlora: bool = field(
         default=False,
         metadata={"help": "whether this is a qlora training"},

View File

@@ -275,6 +275,7 @@ class AxolotlInputConfig(
     torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
         None
     )
+    compile_optimizer: bool | None = None
     max_steps: int | None = None
     warmup_steps: int | None = None
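
For reference, a stripped-down stand-in for the input-config model with just the two relevant fields, assuming a pydantic-style model as the field syntax above suggests (this is a sketch, not axolotl's actual class):

    from typing import Literal
    from pydantic import BaseModel

    class MiniInputConfig(BaseModel):  # stand-in for AxolotlInputConfig
        torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = None
        compile_optimizer: bool | None = None

    # a YAML config containing compile_optimizer: true parses to:
    cfg = MiniInputConfig(compile_optimizer=True)
    assert cfg.compile_optimizer is True and cfg.torch_compile_mode is None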