Dion optimizer support (#3014)

* Add support for Dion optimizer

* dion training kwargs

* fix var names

* no dion 8bit for now

* use updated axolotl-contribs-mit for dion optimizer

* add smoke test for dion optimizer

* add docs

* fix typo during edits

* fix test to not remove load in 8bit
This commit is contained in:
Wing Lian
2025-08-04 16:33:30 -04:00
committed by GitHub
parent 33d094721c
commit ab49d16e34
10 changed files with 145 additions and 2 deletions

View File

@@ -284,6 +284,7 @@ website:
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- docs/optimizers.qmd
- section: "Troubleshooting"
contents:

View File

@@ -1,5 +1,5 @@
---
title: "N-D Parallelism"
title: "N-D Parallelism (Beta)"
---
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:

18
docs/optimizers.qmd Normal file
View File

@@ -0,0 +1,18 @@
---
title: Optimizers
description: Configuring optimizers
---
### Dion Optimizer
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
Usage:
```yaml
optimizer: dion
dion_lr: 0.01
dion_momentum: 0.95
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
```

View File

@@ -66,6 +66,6 @@ torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
axolotl-contribs-mit==0.0.4
mistral-common==1.8.3

View File

@@ -267,6 +267,17 @@ class TrainerBuilderBase(abc.ABC):
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
optimizer_cls = DionOptimizerFactory
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
partial_state = PartialState()
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
@@ -516,10 +527,20 @@ class TrainerBuilderBase(abc.ABC):
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
arg_map = {
"dion_learning_rate": "dion_lr",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False

View File

@@ -243,3 +243,18 @@ class AxolotlTrainingMixins:
)
# end of multi-modal section
dion_learning_rate: float | None = field(
default=None,
metadata={"help": "The learning rate for Dion"},
)
dion_momentum: float | None = field(
default=None,
metadata={"help": "The momentum for Dion"},
)
dion_rank_fraction: float | None = field(
default=None,
)
dion_rank_multiple_of: int | None = field(
default=None,
)

View File

@@ -26,9 +26,11 @@ import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
self, opt_model, training_args, **optimizer_kwargs
) -> Optimizer | None:
pass
# duplicated from transformers
def get_decay_parameter_names(self, model) -> list[str]:
"""
Get all parameter names that weight decay will be applied to.
This function filters out parameters in two ways:
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
2. By parameter name patterns (containing 'bias', or variation of 'norm')
"""
forbidden_name_patterns = [
r"bias",
r"layernorm",
r"rmsnorm",
r"(?:^|\.)norm(?:$|\.)",
r"_norm(?:$|\.)",
]
decay_parameters = get_parameter_names(
model, [nn.LayerNorm], forbidden_name_patterns
)
return decay_parameters

View File

@@ -79,6 +79,7 @@ class CustomSupportedOptimizers(str, Enum):
adopt_adamw = "adopt_adamw"
came_pytorch = "came_pytorch"
muon = "muon"
dion = "dion"
class RingAttnFunc(str, Enum):

View File

@@ -138,6 +138,26 @@ class HyperparametersConfig(BaseModel):
adam_beta3: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)
dion_lr: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
)
dion_momentum: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
)
dion_rank_fraction: float | None = Field(
default=1.0,
json_schema_extra={
"description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
},
)
dion_rank_multiple_of: int | None = Field(
default=1,
json_schema_extra={
"description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
},
)
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)

View File

@@ -13,6 +13,7 @@ from .utils import (
check_model_output_exists,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
with_temp_dir,
)
@@ -160,6 +161,49 @@ class TestCustomOptimizers(unittest.TestCase):
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "dion",
"dion_lr": 0.01,
"dion_momentum": 0.95,
"lr_scheduler": "cosine",
"weight_decay": 0.01,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Dion" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code