Dion optimizer support (#3014)

* Add support for Dion optimizer

* dion training kwargs

* fix var names

* no dion 8bit for now

* use updated axolotl-contribs-mit for dion optimizer

* add smoke test for dion optimizer

* add docs

* fix typo during edits

* fix test to not remove load in 8bit
This commit is contained in:
Wing Lian
2025-08-04 16:33:30 -04:00
committed by GitHub
parent 33d094721c
commit ab49d16e34
10 changed files with 145 additions and 2 deletions

View File

@@ -267,6 +267,17 @@ class TrainerBuilderBase(abc.ABC):
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
optimizer_cls = DionOptimizerFactory
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
partial_state = PartialState()
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
@@ -516,10 +527,20 @@ class TrainerBuilderBase(abc.ABC):
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
arg_map = {
"dion_learning_rate": "dion_lr",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False

View File

@@ -243,3 +243,18 @@ class AxolotlTrainingMixins:
)
# end of multi-modal section
dion_learning_rate: float | None = field(
default=None,
metadata={"help": "The learning rate for Dion"},
)
dion_momentum: float | None = field(
default=None,
metadata={"help": "The momentum for Dion"},
)
dion_rank_fraction: float | None = field(
default=None,
)
dion_rank_multiple_of: int | None = field(
default=None,
)

View File

@@ -26,9 +26,11 @@ import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
self, opt_model, training_args, **optimizer_kwargs
) -> Optimizer | None:
pass
# duplicated from transformers
def get_decay_parameter_names(self, model) -> list[str]:
    """
    Return the names of all parameters that weight decay should apply to.

    Parameters are excluded from decay in two ways:
    1. By layer type — parameters belonging to ``nn.LayerNorm`` instances.
    2. By name — parameters whose names match any of the forbidden regex
       patterns below (bias terms and the common norm-layer name variants).
    """
    # NOTE: these patterns are passed verbatim to transformers'
    # get_parameter_names; keep them in sync with upstream.
    forbidden = [
        r"bias",
        r"layernorm",
        r"rmsnorm",
        r"(?:^|\.)norm(?:$|\.)",
        r"_norm(?:$|\.)",
    ]
    return get_parameter_names(model, [nn.LayerNorm], forbidden)

View File

@@ -79,6 +79,7 @@ class CustomSupportedOptimizers(str, Enum):
adopt_adamw = "adopt_adamw"
came_pytorch = "came_pytorch"
muon = "muon"
dion = "dion"
class RingAttnFunc(str, Enum):

View File

@@ -138,6 +138,26 @@ class HyperparametersConfig(BaseModel):
adam_beta3: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)
dion_lr: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
)
dion_momentum: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
)
dion_rank_fraction: float | None = Field(
default=1.0,
json_schema_extra={
"description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
},
)
dion_rank_multiple_of: int | None = Field(
default=1,
json_schema_extra={
"description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
},
)
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)