Dion optimizer support (#3014)

* Add support for Dion optimizer

* dion training kwargs

* fix var names

* no dion 8bit for now

* use updated axolotl-contribs-mit for dion optimizer

* add smoke test for dion optimizer

* add docs

* fix typo during edits

* fix test to not remove load in 8bit
This commit is contained in:
Wing Lian
2025-08-04 16:33:30 -04:00
committed by GitHub
parent 33d094721c
commit ab49d16e34
10 changed files with 145 additions and 2 deletions

View File

@@ -13,6 +13,7 @@ from .utils import (
check_model_output_exists,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
with_temp_dir,
)
@@ -160,6 +161,49 @@ class TestCustomOptimizers(unittest.TestCase):
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
    """Smoke-test a short training run configured with ``optimizer: dion``.

    The run is capped at 5 steps and writes into ``temp_dir``. Afterwards the
    test calls ``check_model_output_exists`` on the output directory and
    asserts that the class name of the trainer's inner optimizer contains
    ``"Dion"``.

    Gated by the ``require_torch_2_7_0`` decorator.
    """
    # pylint: disable=duplicate-code
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    raw_config = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "model_type": "AutoModelForCausalLM",
        "tokenizer_type": "AutoTokenizer",
        "sequence_len": 1024,
        "val_set_size": 0.0,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "datasets": [alpaca_dataset],
        "num_epochs": 1,
        "max_steps": 5,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        # Dion-specific settings, supplied alongside the generic learning_rate.
        "optimizer": "dion",
        "dion_lr": 0.01,
        "dion_momentum": 0.95,
        "lr_scheduler": "cosine",
        "weight_decay": 0.01,
        "save_first_step": False,
    }

    cfg = validate_config(DictDefault(raw_config))
    normalize_config(cfg)

    dataset_meta = load_datasets(cfg=cfg)
    _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
    check_model_output_exists(temp_dir, cfg)

    # trainer.optimizer is a wrapper here; the class of interest is one level down.
    inner_optimizer_name = trainer.optimizer.optimizer.__class__.__name__
    assert "Dion" in inner_optimizer_name
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code