Dion optimizer support (#3014)
* Add support for Dion optimizer * dion training kwargs * fix var names * no dion 8bit for now * use updated axolotl-contribs-mit for dion optimizer * add smoke test for dion optimizer * add docs * fix typo during edits * fix test to not remove load in 8bit
This commit is contained in:
@@ -284,6 +284,7 @@ website:
|
|||||||
- docs/sequence_parallelism.qmd
|
- docs/sequence_parallelism.qmd
|
||||||
- docs/gradient_checkpointing.qmd
|
- docs/gradient_checkpointing.qmd
|
||||||
- docs/nd_parallelism.qmd
|
- docs/nd_parallelism.qmd
|
||||||
|
- docs/optimizers.qmd
|
||||||
|
|
||||||
- section: "Troubleshooting"
|
- section: "Troubleshooting"
|
||||||
contents:
|
contents:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "N-D Parallelism"
|
title: "N-D Parallelism (Beta)"
|
||||||
---
|
---
|
||||||
|
|
||||||
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
|
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
|
||||||
|
|||||||
18
docs/optimizers.qmd
Normal file
18
docs/optimizers.qmd
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
---
|
||||||
|
title: Optimizers
|
||||||
|
description: Configuring optimizers
|
||||||
|
---
|
||||||
|
|
||||||
|
### Dion Optimizer
|
||||||
|
|
||||||
|
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
|
||||||
|
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: dion
|
||||||
|
dion_lr: 0.01
|
||||||
|
dion_momentum: 0.95
|
||||||
|
learning_rate: 0.00001 # learning rate for embeddings and parameters that fall back to AdamW
|
||||||
|
```
|
||||||
@@ -66,6 +66,6 @@ torchao==0.12.0
|
|||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.4
|
||||||
|
|
||||||
mistral-common==1.8.3
|
mistral-common==1.8.3
|
||||||
|
|||||||
@@ -267,6 +267,17 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
optimizer_cls = MuonOptimizerFactory
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "dion":
|
||||||
|
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
|
||||||
|
DionOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = DionOptimizerFactory
|
||||||
|
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
||||||
|
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
partial_state = PartialState()
|
||||||
|
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
|
||||||
elif self.cfg.optimizer == "optimi_adamw":
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -516,10 +527,20 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
"include_tokens_per_second",
|
"include_tokens_per_second",
|
||||||
"weight_decay",
|
"weight_decay",
|
||||||
"seed",
|
"seed",
|
||||||
|
"dion_momentum",
|
||||||
|
"dion_rank_fraction",
|
||||||
|
"dion_rank_multiple_of",
|
||||||
]:
|
]:
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||||
|
|
||||||
|
arg_map = {
|
||||||
|
"dion_learning_rate": "dion_lr",
|
||||||
|
}
|
||||||
|
for kwarg, cfg_arg in arg_map.items():
|
||||||
|
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
|
||||||
|
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
|
||||||
|
|
||||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||||
training_args_kwargs["average_tokens_across_devices"] = False
|
training_args_kwargs["average_tokens_across_devices"] = False
|
||||||
|
|
||||||
|
|||||||
@@ -243,3 +243,18 @@ class AxolotlTrainingMixins:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# end of multi-modal section
|
# end of multi-modal section
|
||||||
|
|
||||||
|
dion_learning_rate: float | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The learning rate for Dion"},
|
||||||
|
)
|
||||||
|
dion_momentum: float | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The momentum for Dion"},
|
||||||
|
)
|
||||||
|
dion_rank_fraction: float | None = field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
dion_rank_multiple_of: int | None = field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|||||||
@@ -26,9 +26,11 @@ import traceback
|
|||||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||||
|
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
from torch import nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from transformers import PreTrainedModel, Trainer
|
from transformers import PreTrainedModel, Trainer
|
||||||
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
|
|||||||
self, opt_model, training_args, **optimizer_kwargs
|
self, opt_model, training_args, **optimizer_kwargs
|
||||||
) -> Optimizer | None:
|
) -> Optimizer | None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# duplicated from transformers
|
||||||
|
def get_decay_parameter_names(self, model) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get all parameter names that weight decay will be applied to.
|
||||||
|
|
||||||
|
This function filters out parameters in two ways:
|
||||||
|
1. By layer type (instances of nn.LayerNorm)
|
||||||
|
2. By parameter name patterns (containing 'bias' or a variation of 'norm')
|
||||||
|
"""
|
||||||
|
forbidden_name_patterns = [
|
||||||
|
r"bias",
|
||||||
|
r"layernorm",
|
||||||
|
r"rmsnorm",
|
||||||
|
r"(?:^|\.)norm(?:$|\.)",
|
||||||
|
r"_norm(?:$|\.)",
|
||||||
|
]
|
||||||
|
decay_parameters = get_parameter_names(
|
||||||
|
model, [nn.LayerNorm], forbidden_name_patterns
|
||||||
|
)
|
||||||
|
return decay_parameters
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ class CustomSupportedOptimizers(str, Enum):
|
|||||||
adopt_adamw = "adopt_adamw"
|
adopt_adamw = "adopt_adamw"
|
||||||
came_pytorch = "came_pytorch"
|
came_pytorch = "came_pytorch"
|
||||||
muon = "muon"
|
muon = "muon"
|
||||||
|
dion = "dion"
|
||||||
|
|
||||||
|
|
||||||
class RingAttnFunc(str, Enum):
|
class RingAttnFunc(str, Enum):
|
||||||
|
|||||||
@@ -138,6 +138,26 @@ class HyperparametersConfig(BaseModel):
|
|||||||
adam_beta3: float | None = Field(
|
adam_beta3: float | None = Field(
|
||||||
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
|
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dion_lr: float | None = Field(
|
||||||
|
default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
|
||||||
|
)
|
||||||
|
dion_momentum: float | None = Field(
|
||||||
|
default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
|
||||||
|
)
|
||||||
|
dion_rank_fraction: float | None = Field(
|
||||||
|
default=1.0,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dion_rank_multiple_of: int | None = Field(
|
||||||
|
default=1,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
max_grad_norm: float | None = Field(
|
max_grad_norm: float | None = Field(
|
||||||
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
|
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from .utils import (
|
|||||||
check_model_output_exists,
|
check_model_output_exists,
|
||||||
require_torch_2_5_1,
|
require_torch_2_5_1,
|
||||||
require_torch_2_6_0,
|
require_torch_2_6_0,
|
||||||
|
require_torch_2_7_0,
|
||||||
with_temp_dir,
|
with_temp_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -160,6 +161,49 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_dion(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"model_type": "AutoModelForCausalLM",
|
||||||
|
"tokenizer_type": "AutoTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "dion",
|
||||||
|
"dion_lr": 0.01,
|
||||||
|
"dion_momentum": 0.95,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
"save_first_step": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert "Dion" in trainer.optimizer.optimizer.__class__.__name__
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
Reference in New Issue
Block a user