feat: add complete optimizer docs (#3017) [skip ci]
* feat: add complete optimizer docs * fix: deprecate old torchao adamw low bit
This commit is contained in:
@@ -274,6 +274,7 @@ website:
|
||||
- docs/dataset_preprocessing.qmd
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/optimizers.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
@@ -284,7 +285,6 @@ website:
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/nd_parallelism.qmd
|
||||
- docs/optimizers.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -3,12 +3,123 @@ title: Optimizers
|
||||
description: Configuring optimizers
|
||||
---
|
||||
|
||||
### Dion Optimizer
|
||||
## Overview
|
||||
|
||||
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
|
||||
|
||||
Here is a list of optimizers supported by transformers as of `v4.54.0`:
|
||||
|
||||
- `adamw_torch`
|
||||
- `adamw_torch_fused`
|
||||
- `adamw_torch_xla`
|
||||
- `adamw_torch_npu_fused`
|
||||
- `adamw_apex_fused`
|
||||
- `adafactor`
|
||||
- `adamw_anyprecision`
|
||||
- `adamw_torch_4bit`
|
||||
- `adamw_torch_8bit`
|
||||
- `ademamix`
|
||||
- `sgd`
|
||||
- `adagrad`
|
||||
- `adamw_bnb_8bit`
|
||||
- `adamw_8bit` # alias for adamw_bnb_8bit
|
||||
- `ademamix_8bit`
|
||||
- `lion_8bit`
|
||||
- `lion_32bit`
|
||||
- `paged_adamw_32bit`
|
||||
- `paged_adamw_8bit`
|
||||
- `paged_ademamix_32bit`
|
||||
- `paged_ademamix_8bit`
|
||||
- `paged_lion_32bit`
|
||||
- `paged_lion_8bit`
|
||||
- `rmsprop`
|
||||
- `rmsprop_bnb`
|
||||
- `rmsprop_bnb_8bit`
|
||||
- `rmsprop_bnb_32bit`
|
||||
- `galore_adamw`
|
||||
- `galore_adamw_8bit`
|
||||
- `galore_adafactor`
|
||||
- `galore_adamw_layerwise`
|
||||
- `galore_adamw_8bit_layerwise`
|
||||
- `galore_adafactor_layerwise`
|
||||
- `lomo`
|
||||
- `adalomo`
|
||||
- `grokadamw`
|
||||
- `schedule_free_radam`
|
||||
- `schedule_free_adamw`
|
||||
- `schedule_free_sgd`
|
||||
- `apollo_adamw`
|
||||
- `apollo_adamw_layerwise`
|
||||
- `stable_adamw`
|
||||
|
||||
|
||||
## Custom Optimizers
|
||||
|
||||
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
|
||||
|
||||
### optimi_adamw
|
||||
|
||||
```yaml
|
||||
optimizer: optimi_adamw
|
||||
```
|
||||
|
||||
### ao_adamw_4bit
|
||||
|
||||
Deprecated: Please use `adamw_torch_4bit`.
|
||||
|
||||
### ao_adamw_8bit
|
||||
|
||||
Deprecated: Please use `adamw_torch_8bit`.
|
||||
|
||||
### ao_adamw_fp8
|
||||
|
||||
|
||||
```yaml
|
||||
optimizer: ao_adamw_fp8
|
||||
```
|
||||
|
||||
### adopt_adamw
|
||||
|
||||
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
|
||||
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
|
||||
|
||||
```yaml
|
||||
optimizer: adopt_adamw
|
||||
```
|
||||
|
||||
### came_pytorch
|
||||
|
||||
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
|
||||
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
|
||||
|
||||
```yaml
|
||||
optimizer: came_pytorch
|
||||
|
||||
# optional args (defaults below)
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.999
|
||||
adam_beta3: 0.9999
|
||||
adam_epsilon: 1e-30
|
||||
adam_epsilon2: 1e-16
|
||||
```
|
||||
|
||||
### muon
|
||||
|
||||
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
|
||||
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
|
||||
|
||||
```yaml
|
||||
optimizer: muon
|
||||
```
|
||||
|
||||
### dion
|
||||
|
||||
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
|
||||
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
|
||||
|
||||
Usage:
|
||||
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
|
||||
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
|
||||
Note: Implementation written for PyTorch 2.7+ for DTensor
|
||||
|
||||
```yaml
|
||||
optimizer: dion
|
||||
|
||||
@@ -29,7 +29,6 @@ from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
@@ -284,21 +283,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
|
||||
Reference in New Issue
Block a user