diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index a149566b3..5752a0584 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC): adam_kwargs["eps"] = (eps1, eps2) optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "flash_adamw": + from flashoptim import FlashAdamW + + optimizer_cls = FlashAdamW + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "flash_adam": + from flashoptim import FlashAdam + + optimizer_cls = FlashAdam + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "flash_sgd": + from flashoptim import FlashSGD + + optimizer_cls = FlashSGD + elif self.cfg.optimizer == "flash_sgdw": + from flashoptim import FlashSGDW + + optimizer_cls = FlashSGDW + elif self.cfg.optimizer == "flash_lion": + from flashoptim import FlashLion + + optimizer_cls = FlashLion + if "betas" in adam_kwargs: + optimizer_kwargs["betas"] = adam_kwargs["betas"] else: raise ValueError( f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue." diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 792f6f6de..40fa314f4 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum): came_pytorch = "came_pytorch" muon = "muon" dion = "dion" + flash_adamw = "flash_adamw" + flash_adam = "flash_adam" + flash_sgd = "flash_sgd" + flash_sgdw = "flash_sgdw" + flash_lion = "flash_lion" class RingAttnFunc(str, Enum): diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 5e6657a78..8ff61b370 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -790,6 +790,14 @@ class OptimizationValidationMixin: LOG.warning("adamw hyperparameters found, but no adamw optimizer set") return self + @staticmethod + def _resolve_fsdp_version(data): + """Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version.""" + fsdp_version = data.get("fsdp_version") + if fsdp_version is None: + fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1) + return fsdp_version + @model_validator(mode="before") @classmethod def check_muon_deepspeed_fsdp(cls, data): @@ -799,15 +807,32 @@ class OptimizationValidationMixin: "Muon optimizer is currently incompatible with DeepSpeed" ) if data.get("fsdp") or data.get("fsdp_config"): - fsdp_version = data.get("fsdp_version") - if fsdp_version is None: - fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1) + fsdp_version = cls._resolve_fsdp_version(data) if str(fsdp_version) != "2": raise ValueError( "Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP." ) return data + @model_validator(mode="before") + @classmethod + def check_flashoptim_deepspeed_fsdp(cls, data): + optimizer = data.get("optimizer") or "" + if str(optimizer).startswith("flash_"): + if data.get("deepspeed"): + raise ValueError( + f"{optimizer} optimizer is incompatible with DeepSpeed. " + "Flash optimizers only support DDP and FSDP2." + ) + if data.get("fsdp") or data.get("fsdp_config"): + fsdp_version = cls._resolve_fsdp_version(data) + if str(fsdp_version) != "2": + raise ValueError( + f"{optimizer} optimizer is only compatible with FSDP2. " + "Set fsdp_version: 2 to use flash optimizers with FSDP." + ) + return data + @model_validator(mode="before") @classmethod def check_batch_flattening_fa(cls, data): diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index de6c41fbe..40a536d4b 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -4,6 +4,8 @@ E2E tests for custom optimizers using Llama import unittest +import pytest + from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -282,3 +284,59 @@ class TestCustomOptimizers(unittest.TestCase): train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) + + +@require_torch_2_7_0 +@pytest.mark.parametrize( + "optimizer_name,expected_class,learning_rate", + [ + ("flash_adamw", "FlashAdamW", 0.00001), + ("flash_adam", "FlashAdam", 0.00001), + ("flash_sgd", "FlashSGD", 0.01), + ("flash_sgdw", "FlashSGDW", 0.01), + ("flash_lion", "FlashLion", 0.0001), + ], +) +def test_flash_optimizers(tmp_path, optimizer_name, expected_class, learning_rate): + temp_dir = str(tmp_path) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": learning_rate, + "optimizer": optimizer_name, + "max_steps": 5, + "lr_scheduler": "cosine", + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + assert trainer.optimizer.optimizer.__class__.__name__ == expected_class