feat: add complete optimizer docs (#3017) [skip ci]
* feat: add complete optimizer docs * fix: deprecate old torchao adamw low bit
This commit is contained in:
@@ -274,6 +274,7 @@ website:
|
|||||||
- docs/dataset_preprocessing.qmd
|
- docs/dataset_preprocessing.qmd
|
||||||
- docs/multipack.qmd
|
- docs/multipack.qmd
|
||||||
- docs/mixed_precision.qmd
|
- docs/mixed_precision.qmd
|
||||||
|
- docs/optimizers.qmd
|
||||||
|
|
||||||
- section: "Advanced Features"
|
- section: "Advanced Features"
|
||||||
contents:
|
contents:
|
||||||
@@ -284,7 +285,6 @@ website:
|
|||||||
- docs/sequence_parallelism.qmd
|
- docs/sequence_parallelism.qmd
|
||||||
- docs/gradient_checkpointing.qmd
|
- docs/gradient_checkpointing.qmd
|
||||||
- docs/nd_parallelism.qmd
|
- docs/nd_parallelism.qmd
|
||||||
- docs/optimizers.qmd
|
|
||||||
|
|
||||||
- section: "Troubleshooting"
|
- section: "Troubleshooting"
|
||||||
contents:
|
contents:
|
||||||
|
|||||||
@@ -3,12 +3,123 @@ title: Optimizers
|
|||||||
description: Configuring optimizers
|
description: Configuring optimizers
|
||||||
---
|
---
|
||||||
|
|
||||||
### Dion Optimizer
|
## Overview
|
||||||
|
|
||||||
|
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
|
||||||
|
|
||||||
|
Here is a list of optimizers supported by transformers as of `v4.54.0`:
|
||||||
|
|
||||||
|
- `adamw_torch`
|
||||||
|
- `adamw_torch_fused`
|
||||||
|
- `adamw_torch_xla`
|
||||||
|
- `adamw_torch_npu_fused`
|
||||||
|
- `adamw_apex_fused`
|
||||||
|
- `adafactor`
|
||||||
|
- `adamw_anyprecision`
|
||||||
|
- `adamw_torch_4bit`
|
||||||
|
- `adamw_torch_8bit`
|
||||||
|
- `ademamix`
|
||||||
|
- `sgd`
|
||||||
|
- `adagrad`
|
||||||
|
- `adamw_bnb_8bit`
|
||||||
|
- `adamw_8bit` # alias for adamw_bnb_8bit
|
||||||
|
- `ademamix_8bit`
|
||||||
|
- `lion_8bit`
|
||||||
|
- `lion_32bit`
|
||||||
|
- `paged_adamw_32bit`
|
||||||
|
- `paged_adamw_8bit`
|
||||||
|
- `paged_ademamix_32bit`
|
||||||
|
- `paged_ademamix_8bit`
|
||||||
|
- `paged_lion_32bit`
|
||||||
|
- `paged_lion_8bit`
|
||||||
|
- `rmsprop`
|
||||||
|
- `rmsprop_bnb`
|
||||||
|
- `rmsprop_bnb_8bit`
|
||||||
|
- `rmsprop_bnb_32bit`
|
||||||
|
- `galore_adamw`
|
||||||
|
- `galore_adamw_8bit`
|
||||||
|
- `galore_adafactor`
|
||||||
|
- `galore_adamw_layerwise`
|
||||||
|
- `galore_adamw_8bit_layerwise`
|
||||||
|
- `galore_adafactor_layerwise`
|
||||||
|
- `lomo`
|
||||||
|
- `adalomo`
|
||||||
|
- `grokadamw`
|
||||||
|
- `schedule_free_radam`
|
||||||
|
- `schedule_free_adamw`
|
||||||
|
- `schedule_free_sgd`
|
||||||
|
- `apollo_adamw`
|
||||||
|
- `apollo_adamw_layerwise`
|
||||||
|
- `stable_adamw`
|
||||||
|
|
||||||
|
|
||||||
|
## Custom Optimizers
|
||||||
|
|
||||||
|
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
|
||||||
|
|
||||||
|
### optimi_adamw
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: optimi_adamw
|
||||||
|
```
|
||||||
|
|
||||||
|
### ao_adamw_4bit
|
||||||
|
|
||||||
|
Deprecated: Please use `adamw_torch_4bit`.
|
||||||
|
|
||||||
|
### ao_adamw_8bit
|
||||||
|
|
||||||
|
Deprecated: Please use `adamw_torch_8bit`.
|
||||||
|
|
||||||
|
### ao_adamw_fp8
|
||||||
|
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: ao_adamw_fp8
|
||||||
|
```
|
||||||
|
|
||||||
|
### adopt_adamw
|
||||||
|
|
||||||
|
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
|
||||||
|
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: adopt_adamw
|
||||||
|
```
|
||||||
|
|
||||||
|
### came_pytorch
|
||||||
|
|
||||||
|
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
|
||||||
|
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: came_pytorch
|
||||||
|
|
||||||
|
# optional args (defaults below)
|
||||||
|
adam_beta1: 0.9
|
||||||
|
adam_beta2: 0.999
|
||||||
|
adam_beta3: 0.9999
|
||||||
|
adam_epsilon: 1e-30
|
||||||
|
adam_epsilon2: 1e-16
|
||||||
|
```
|
||||||
|
|
||||||
|
### muon
|
||||||
|
|
||||||
|
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
|
||||||
|
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: muon
|
||||||
|
```
|
||||||
|
|
||||||
|
### dion
|
||||||
|
|
||||||
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
|
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
|
||||||
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
|
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
|
||||||
|
|
||||||
Usage:
|
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
|
||||||
|
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
|
||||||
|
Note: Implementation written for PyTorch 2.7+ for DTensor
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
optimizer: dion
|
optimizer: dion
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from transformers import (
|
|||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
from transformers.training_args import OptimizerNames
|
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||||
@@ -284,21 +283,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
optimizer_kwargs["foreach"] = False
|
optimizer_kwargs["foreach"] = False
|
||||||
optimizer_cls = AdamW
|
optimizer_cls = AdamW
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
|
||||||
# TODO remove 20250401
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
|
||||||
|
|
||||||
optimizer_cls = AdamW4bit
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
|
||||||
)
|
|
||||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
|
||||||
|
|
||||||
optimizer_cls = AdamW8bit
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user