From 1bcfc08c9008cd62acae146cf0872e5e0bdde09f Mon Sep 17 00:00:00 2001
From: Avaya Aggarwal <119044997+OnePunchMonk@users.noreply.github.com>
Date: Fri, 20 Mar 2026 17:54:44 +0530
Subject: [PATCH] =?UTF-8?q?feat:=20add=20support=20and=20end-to-end=20test?=
 =?UTF-8?q?s=20for=20multiple=20custom=20optimizers=E2=80=A6=20(#3457)=20[?=
 =?UTF-8?q?skip=20ci]?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: add support and end-to-end tests for multiple custom optimizers
  including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW,
  CAME PyTorch, and Flash AdamW.

* feat: Add standalone flashoptim integration test and E2E tests for various
  custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW,
  FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.

* feat: introduce Pydantic schema validation for dataset, attention, and
  training configurations.

* feat: add e2e tests for custom optimizers including optimi_adamw,
  adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash
  optimizers.

* test: add e2e tests for custom optimizers including optimi_adamw,
  adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash
  optimizers.

* test: fix assertion in flash optimizers test to compare class names directly

* fix: address PR review - reuse require_torch_2_7_0 decorator, remove
  fsdp_config.version check, extract shared FSDP version helper, remove
  unused imports and optim_args

* chore: lint

---------

Co-authored-by: NanoCode012
---
 src/axolotl/core/builders/base.py       | 24 ++++++++++
 src/axolotl/utils/schemas/enums.py      |  5 +++
 src/axolotl/utils/schemas/validation.py | 31 +++++++++++---
 tests/e2e/test_optimizers.py            | 58 ++++++++++++++++++++++++++
 4 files changed, 115 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index a149566b3..5752a0584 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
                 adam_kwargs["eps"] = (eps1, eps2)
 
             optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "flash_adamw":
+            from flashoptim import FlashAdamW
+
+            optimizer_cls = FlashAdamW
+            optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "flash_adam":
+            from flashoptim import FlashAdam
+
+            optimizer_cls = FlashAdam
+            optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "flash_sgd":
+            from flashoptim import FlashSGD
+
+            optimizer_cls = FlashSGD
+        elif self.cfg.optimizer == "flash_sgdw":
+            from flashoptim import FlashSGDW
+
+            optimizer_cls = FlashSGDW
+        elif self.cfg.optimizer == "flash_lion":
+            from flashoptim import FlashLion
+
+            optimizer_cls = FlashLion
+            if "betas" in adam_kwargs:
+                optimizer_kwargs["betas"] = adam_kwargs["betas"]
         else:
             raise ValueError(
                 f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
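Note on the branch above: a minimal standalone sketch of what the "flash_adamw" dispatch wires up. It assumes the `flashoptim` package is installed (as the `from flashoptim import FlashAdamW` in the diff requires) and that its constructors take a `torch.optim.AdamW`-style keyword set; that signature is an assumption for illustration, not documented flashoptim API.

```python
# Hedged sketch: using the optimizer the "flash_adamw" branch selects.
# The constructor keywords mirror torch.optim.AdamW and are assumed,
# not taken from flashoptim documentation.
import torch
from flashoptim import FlashAdamW

model = torch.nn.Linear(16, 16)
optimizer = FlashAdamW(
    model.parameters(),
    lr=1e-5,               # cfg.learning_rate
    betas=(0.9, 0.999),    # populated via adam_kwargs in the builder
    eps=1e-8,
    weight_decay=0.01,
)

# One training step, to show the torch.optim-style interface assumed here.
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```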
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 792f6f6de..40fa314f4 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum):
     came_pytorch = "came_pytorch"
     muon = "muon"
     dion = "dion"
+    flash_adamw = "flash_adamw"
+    flash_adam = "flash_adam"
+    flash_sgd = "flash_sgd"
+    flash_sgdw = "flash_sgdw"
+    flash_lion = "flash_lion"
 
 
 class RingAttnFunc(str, Enum):
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 5e6657a78..8ff61b370 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -790,6 +790,14 @@ class OptimizationValidationMixin:
         LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
         return self
 
+    @staticmethod
+    def _resolve_fsdp_version(data):
+        """Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
+        fsdp_version = data.get("fsdp_version")
+        if fsdp_version is None:
+            fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
+        return fsdp_version
+
     @model_validator(mode="before")
     @classmethod
     def check_muon_deepspeed_fsdp(cls, data):
@@ -799,15 +807,32 @@ class OptimizationValidationMixin:
                 "Muon optimizer is currently incompatible with DeepSpeed"
             )
         if data.get("fsdp") or data.get("fsdp_config"):
-            fsdp_version = data.get("fsdp_version")
-            if fsdp_version is None:
-                fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
+            fsdp_version = cls._resolve_fsdp_version(data)
             if str(fsdp_version) != "2":
                 raise ValueError(
                     "Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
                 )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_flashoptim_deepspeed_fsdp(cls, data):
+        optimizer = data.get("optimizer") or ""
+        if str(optimizer).startswith("flash_"):
+            if data.get("deepspeed"):
+                raise ValueError(
+                    f"{optimizer} optimizer is incompatible with DeepSpeed. "
+                    "Flash optimizers only support DDP and FSDP2."
+                )
+            if data.get("fsdp") or data.get("fsdp_config"):
+                fsdp_version = cls._resolve_fsdp_version(data)
+                if str(fsdp_version) != "2":
+                    raise ValueError(
+                        f"{optimizer} optimizer is only compatible with FSDP2. "
+                        "Set fsdp_version: 2 to use flash optimizers with FSDP."
+                    )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_batch_flattening_fa(cls, data):
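To make the validator's behavior concrete, here is a standalone re-implementation of the same checks that runs without axolotl or Pydantic. The function names are illustrative; the keys and precedence mirror the patch (top-level `fsdp_version` wins, and `fsdp_config.fsdp_version` defaults to 1, i.e. FSDP1, when unset).

```python
# Standalone sketch of the flash-optimizer compatibility checks above.
# Mirrors the patched validator logic; runnable without axolotl/Pydantic.
def resolve_fsdp_version(data: dict):
    # Top-level fsdp_version takes precedence; fsdp_config.fsdp_version
    # defaults to 1 (FSDP1), matching _resolve_fsdp_version in the patch.
    version = data.get("fsdp_version")
    if version is None:
        version = data.get("fsdp_config", {}).get("fsdp_version", 1)
    return version


def check_flash_optimizer(data: dict) -> dict:
    optimizer = data.get("optimizer") or ""
    if str(optimizer).startswith("flash_"):
        if data.get("deepspeed"):
            raise ValueError(f"{optimizer} is incompatible with DeepSpeed")
        if data.get("fsdp") or data.get("fsdp_config"):
            if str(resolve_fsdp_version(data)) != "2":
                raise ValueError(f"{optimizer} requires fsdp_version: 2")
    return data


check_flash_optimizer({"optimizer": "flash_adamw"})  # passes: plain DDP
check_flash_optimizer(
    {"optimizer": "flash_lion", "fsdp_config": {"fsdp_version": 2}}
)  # passes: FSDP2
try:
    check_flash_optimizer({"optimizer": "flash_sgd", "fsdp_config": {}})
except ValueError as err:
    print(err)  # raises: empty fsdp_config defaults to FSDP1
```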
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index de6c41fbe..40a536d4b 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -4,6 +4,8 @@ E2E tests for custom optimizers using Llama
 
 import unittest
 
+import pytest
+
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
@@ -282,3 +284,59 @@ class TestCustomOptimizers(unittest.TestCase):
             train(cfg=cfg, dataset_meta=dataset_meta)
 
             check_model_output_exists(temp_dir, cfg)
+
+
+@require_torch_2_7_0
+@pytest.mark.parametrize(
+    "optimizer_name,expected_class,learning_rate",
+    [
+        ("flash_adamw", "FlashAdamW", 0.00001),
+        ("flash_adam", "FlashAdam", 0.00001),
+        ("flash_sgd", "FlashSGD", 0.01),
+        ("flash_sgdw", "FlashSGDW", 0.01),
+        ("flash_lion", "FlashLion", 0.0001),
+    ],
+)
+def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate):
+    temp_dir = str(tmp_path)
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "model_type": "AutoModelForCausalLM",
+            "tokenizer_type": "AutoTokenizer",
+            "sequence_len": 1024,
+            "load_in_8bit": True,
+            "adapter": "lora",
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.05,
+            "lora_target_linear": True,
+            "val_set_size": 0.02,
+            "special_tokens": {
+                "pad_token": "<|endoftext|>",
+            },
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                },
+            ],
+            "num_epochs": 1,
+            "micro_batch_size": 8,
+            "gradient_accumulation_steps": 1,
+            "output_dir": temp_dir,
+            "learning_rate": learning_rate,
+            "optimizer": optimizer_name,
+            "max_steps": 5,
+            "lr_scheduler": "cosine",
+            "save_first_step": False,
+        }
+    )
+
+    cfg = validate_config(cfg)
+    normalize_config(cfg)
+    dataset_meta = load_datasets(cfg=cfg)
+
+    _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
+    check_model_output_exists(temp_dir, cfg)
+    assert trainer.optimizer.optimizer.__class__.__name__ == expected_class
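A note on the final assertion: `trainer.optimizer` is typically an Accelerate wrapper around the user-selected optimizer, which is presumably why the test reaches through two `.optimizer` levels and compares class names directly (per the "fix assertion" commit above). A minimal sketch of that shape, using a hypothetical stand-in wrapper; `torch.optim.AdamW` substitutes for a flash optimizer so it runs anywhere.

```python
# Hedged sketch of the wrapper shape the final assertion relies on.
# WrappedOptimizer is a hypothetical stand-in for the Accelerate wrapper;
# torch.optim.AdamW stands in for FlashAdamW.
import torch


class WrappedOptimizer:
    def __init__(self, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer  # inner, user-selected optimizer


inner = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=1e-5)
trainer_optimizer = WrappedOptimizer(inner)

# Comparing class names directly, as the e2e test does:
assert trainer_optimizer.optimizer.__class__.__name__ == "AdamW"
```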