diff --git a/_quarto.yml b/_quarto.yml
index 5bb771c01..934d393cb 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -274,6 +274,7 @@ website:
           - docs/dataset_preprocessing.qmd
           - docs/multipack.qmd
           - docs/mixed_precision.qmd
+          - docs/optimizers.qmd
 
       - section: "Advanced Features"
         contents:
@@ -284,7 +285,6 @@ website:
           - docs/sequence_parallelism.qmd
           - docs/gradient_checkpointing.qmd
           - docs/nd_parallelism.qmd
-          - docs/optimizers.qmd
 
       - section: "Troubleshooting"
         contents:
diff --git a/docs/optimizers.qmd b/docs/optimizers.qmd
index 563e9695b..45eea1d3a 100644
--- a/docs/optimizers.qmd
+++ b/docs/optimizers.qmd
@@ -3,12 +3,123 @@
 title: Optimizers
 description: Configuring optimizers
 ---
 
-### Dion Optimizer
+## Overview
+
+Axolotl supports all optimizers provided by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187).
+
+Here is a list of optimizers supported by transformers as of `v4.54.0`:
+
+- `adamw_torch`
+- `adamw_torch_fused`
+- `adamw_torch_xla`
+- `adamw_torch_npu_fused`
+- `adamw_apex_fused`
+- `adafactor`
+- `adamw_anyprecision`
+- `adamw_torch_4bit`
+- `adamw_torch_8bit`
+- `ademamix`
+- `sgd`
+- `adagrad`
+- `adamw_bnb_8bit`
+- `adamw_8bit` (alias for `adamw_bnb_8bit`)
+- `ademamix_8bit`
+- `lion_8bit`
+- `lion_32bit`
+- `paged_adamw_32bit`
+- `paged_adamw_8bit`
+- `paged_ademamix_32bit`
+- `paged_ademamix_8bit`
+- `paged_lion_32bit`
+- `paged_lion_8bit`
+- `rmsprop`
+- `rmsprop_bnb`
+- `rmsprop_bnb_8bit`
+- `rmsprop_bnb_32bit`
+- `galore_adamw`
+- `galore_adamw_8bit`
+- `galore_adafactor`
+- `galore_adamw_layerwise`
+- `galore_adamw_8bit_layerwise`
+- `galore_adafactor_layerwise`
+- `lomo`
+- `adalomo`
+- `grokadamw`
+- `schedule_free_radam`
+- `schedule_free_adamw`
+- `schedule_free_sgd`
+- `apollo_adamw`
+- `apollo_adamw_layerwise`
+- `stable_adamw`
+
+
+## Custom Optimizers
+
+Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer receives the beta and epsilon args; however, some accept additional args, which are detailed below.
+
+### optimi_adamw
+
+```yaml
+optimizer: optimi_adamw
+```
+
+### ao_adamw_4bit
+
+Deprecated: Please use `adamw_torch_4bit`.
+
+### ao_adamw_8bit
+
+Deprecated: Please use `adamw_torch_8bit`.
+
+### ao_adamw_fp8
+
+
+```yaml
+optimizer: ao_adamw_fp8
+```
+
+### adopt_adamw
+
+GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
+Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
+
+```yaml
+optimizer: adopt_adamw
+```
+
+### came_pytorch
+
+GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
+Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
+
+```yaml
+optimizer: came_pytorch
+
+# optional args (defaults below)
+adam_beta1: 0.9
+adam_beta2: 0.999
+adam_beta3: 0.9999
+adam_epsilon: 1e-30
+adam_epsilon2: 1e-16
+```
+
+### muon
+
+Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
+Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
+
+```yaml
+optimizer: muon
+```
+
+### dion
 
 Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
 
-Usage:
+GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
+Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
+Note: The implementation is written for PyTorch 2.7+ for DTensor support.
 
 ```yaml
 optimizer: dion
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 5a25c1834..0472acee9 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -29,7 +29,6 @@ from transformers import (
     TrainerCallback,
 )
 from transformers.trainer_pt_utils import AcceleratorConfig
-from transformers.training_args import OptimizerNames
 
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
@@ -284,21 +283,6 @@ class TrainerBuilderBase(abc.ABC):
             optimizer_kwargs["foreach"] = False
             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
-        elif self.cfg.optimizer == "ao_adamw_4bit":
-            # TODO remove 20250401
-            from torchao.prototype.low_bit_optim import AdamW4bit
-
-            optimizer_cls = AdamW4bit
-            optimizer_kwargs.update(adam_kwargs)
-
-            LOG.warning(
-                f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
-            )
-        elif self.cfg.optimizer == "ao_adamw_8bit":
-            from torchao.prototype.low_bit_optim import AdamW8bit
-
-            optimizer_cls = AdamW8bit
-            optimizer_kwargs.update(adam_kwargs)
         elif self.cfg.optimizer == "ao_adamw_fp8":
             from torchao.prototype.low_bit_optim import AdamWFp8