From ab49d16e34fb9dad838fd9f8a4b0e2781223b20f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 4 Aug 2025 16:33:30 -0400
Subject: [PATCH] Dion optimizer support (#3014)

* Add support for Dion optimizer

* dion training kwargs

* fix var names

* no dion 8bit for now

* use updated axolotl-contribs-mit for dion optimizer

* add smoke test for dion optimizer

* add docs

* fix typo during edits

* fix test to not remove load in 8bit
---
 _quarto.yml                            |  1 +
 docs/nd_parallelism.qmd                |  2 +-
 docs/optimizers.qmd                    | 26 +++++++++++++++++
 requirements.txt                       |  2 +-
 src/axolotl/core/builders/base.py      | 21 ++++++++++++
 src/axolotl/core/training_args_base.py | 15 +++++++++
 src/axolotl/integrations/base.py       | 23 ++++++++++++++
 src/axolotl/utils/schemas/enums.py     |  1 +
 src/axolotl/utils/schemas/training.py  | 20 ++++++++++++
 tests/e2e/test_optimizers.py           | 44 ++++++++++++++++++++++++++
 10 files changed, 153 insertions(+), 2 deletions(-)
 create mode 100644 docs/optimizers.qmd

diff --git a/_quarto.yml b/_quarto.yml
index 738fe5e2f..5bb771c01 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -284,6 +284,7 @@ website:
           - docs/sequence_parallelism.qmd
           - docs/gradient_checkpointing.qmd
           - docs/nd_parallelism.qmd
+          - docs/optimizers.qmd
 
     - section: "Troubleshooting"
       contents:
diff --git a/docs/nd_parallelism.qmd b/docs/nd_parallelism.qmd
index d27a15663..8aebab140 100644
--- a/docs/nd_parallelism.qmd
+++ b/docs/nd_parallelism.qmd
@@ -1,5 +1,5 @@
 ---
-title: "N-D Parallelism"
+title: "N-D Parallelism (Beta)"
 ---
 
 Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
diff --git a/docs/optimizers.qmd b/docs/optimizers.qmd
new file mode 100644
index 000000000..563e9695b
--- /dev/null
+++ b/docs/optimizers.qmd
@@ -0,0 +1,26 @@
+---
+title: Optimizers
+description: Configuring optimizers
+---
+
+### Dion Optimizer
+
+Microsoft's Dion (DIstributed OrthoNormalization) is a scalable, communication-efficient
+orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
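+
+Beyond the basic configuration shown under "Usage" below, two optional settings
+control the dimension of Dion's low-rank approximation (schema defaults shown):
+
+```yaml
+dion_rank_fraction: 1.0 # r/d fraction used to compute the low-rank dimension
+dion_rank_multiple_of: 1 # round the low-rank dimension up to a multiple of this, e.g. for even sharding
+```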
+
+Usage:
+
+```yaml
+optimizer: dion
+dion_lr: 0.01
+dion_momentum: 0.95
+learning_rate: 0.00001 # learning rate for embeddings and parameters that fall back to AdamW
+```
diff --git a/requirements.txt b/requirements.txt
index 4e82dfd89..cd9b2cf62 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -66,6 +66,6 @@ torchao==0.12.0
 schedulefree==1.4.1
 
 axolotl-contribs-lgpl==0.0.6
-axolotl-contribs-mit==0.0.3
+axolotl-contribs-mit==0.0.4
 
 mistral-common==1.8.3
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index dbdda7a7c..5a25c1834 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -267,6 +267,17 @@ class TrainerBuilderBase(abc.ABC):
             optimizer_cls = MuonOptimizerFactory
             optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "dion":
+            from axolotl.contribs.mit.dion import (  # pylint: disable=no-name-in-module
+                DionOptimizerFactory,
+            )
+
+            optimizer_cls = DionOptimizerFactory
+            optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
+            optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
+            optimizer_kwargs.update(adam_kwargs)
+            partial_state = PartialState()
+            optimizer_kwargs["device_mesh"] = partial_state.device_mesh
         elif self.cfg.optimizer == "optimi_adamw":
             from optimi import AdamW
 
@@ -516,10 +527,20 @@ class TrainerBuilderBase(abc.ABC):
             "include_tokens_per_second",
             "weight_decay",
             "seed",
+            "dion_momentum",
+            "dion_rank_fraction",
+            "dion_rank_multiple_of",
         ]:
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                 training_args_kwargs[arg] = getattr(self.cfg, arg)
 
+        arg_map = {
+            "dion_learning_rate": "dion_lr",
+        }
+        for kwarg, cfg_arg in arg_map.items():
+            if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
+                training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
+
         training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
         training_args_kwargs["average_tokens_across_devices"] = False
 
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
index 66649deef..fd0859ae9 100644
--- a/src/axolotl/core/training_args_base.py
+++ b/src/axolotl/core/training_args_base.py
@@ -243,3 +243,18 @@
     )
 
     # end of multi-modal section
+
+    dion_learning_rate: float | None = field(
+        default=None,
+        metadata={"help": "The learning rate for Dion"},
+    )
+    dion_momentum: float | None = field(
+        default=None,
+        metadata={"help": "The momentum for Dion"},
+    )
+    dion_rank_fraction: float | None = field(
+        default=None,
+    )
+    dion_rank_multiple_of: int | None = field(
+        default=None,
+    )
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 7d9b6a6f9..f43031287 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -26,9 +26,11 @@ import traceback
 from typing import TYPE_CHECKING, Callable, OrderedDict, Union
 
 from peft import PeftModel
+from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from transformers import PreTrainedModel, Trainer
+from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
         self, opt_model, training_args, **optimizer_kwargs
     ) -> Optimizer | None:
         pass
+
+    # duplicated from transformers
+    def get_decay_parameter_names(self, model) -> list[str]:
+        """
+        Get all parameter names that weight decay will be applied to.
+
+        This function filters out parameters in two ways:
+        1. By layer type (instances of `nn.LayerNorm`, mirroring transformers' ALL_LAYERNORM_LAYERS)
+        2. By parameter name patterns (containing 'bias' or a variation of 'norm')
+        """
+        forbidden_name_patterns = [
+            r"bias",
+            r"layernorm",
+            r"rmsnorm",
+            r"(?:^|\.)norm(?:$|\.)",
+            r"_norm(?:$|\.)",
+        ]
+        decay_parameters = get_parameter_names(
+            model, [nn.LayerNorm], forbidden_name_patterns
+        )
+        return decay_parameters
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 3c8828396..cf2a8b484 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -79,6 +79,7 @@ class CustomSupportedOptimizers(str, Enum):
     adopt_adamw = "adopt_adamw"
     came_pytorch = "came_pytorch"
     muon = "muon"
+    dion = "dion"
 
 
 class RingAttnFunc(str, Enum):
diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py
index 6ee863397..b1788dcaa 100644
--- a/src/axolotl/utils/schemas/training.py
+++ b/src/axolotl/utils/schemas/training.py
@@ -138,6 +138,26 @@ class HyperparametersConfig(BaseModel):
     adam_beta3: float | None = Field(
         default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
     )
+
+    dion_lr: float | None = Field(
+        default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
+    )
+    dion_momentum: float | None = Field(
+        default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
+    )
+    dion_rank_fraction: float | None = Field(
+        default=1.0,
+        json_schema_extra={
+            "description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
+        },
+    )
+    dion_rank_multiple_of: int | None = Field(
+        default=1,
+        json_schema_extra={
+            "description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
+ }, + ) + max_grad_norm: float | None = Field( default=None, json_schema_extra={"description": "Gradient clipping max norm"} ) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 1d233a201..987d86041 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -13,6 +13,7 @@ from .utils import ( check_model_output_exists, require_torch_2_5_1, require_torch_2_6_0, + require_torch_2_7_0, with_temp_dir, ) @@ -160,6 +161,49 @@ class TestCustomOptimizers(unittest.TestCase): check_model_output_exists(temp_dir, cfg) assert "Muon" in trainer.optimizer.optimizer.__class__.__name__ + @with_temp_dir + @require_torch_2_7_0 + def test_dion(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "val_set_size": 0.0, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "dion", + "dion_lr": 0.01, + "dion_momentum": 0.95, + "lr_scheduler": "cosine", + "weight_decay": 0.01, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + assert "Dion" in trainer.optimizer.optimizer.__class__.__name__ + @with_temp_dir def test_fft_schedule_free_adamw(self, temp_dir): # pylint: disable=duplicate-code