From c578c8f256cc706ca02193afd73b0004349f4b09 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 1 Apr 2025 09:29:54 -0400
Subject: [PATCH] Validation for Muon optimizer with DS/FSDP

---
 src/axolotl/utils/schemas/config.py | 18 +++++++++++
 tests/test_validation_dataset.py    | 47 +++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+)

diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 3e072b6b4..1ec9e296d 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1135,6 +1135,24 @@ class AxolotlInputConfig(
 
         return value
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_muon_deepspeed_fsdp(cls, data):
+        """
+        Reject the Muon optimizer when DeepSpeed or FSDP is configured.
+
+        Raises a ValueError if `optimizer` is "muon" and any of `deepspeed`,
+        `fsdp` or `fsdp_config` is set to a truthy value; otherwise returns
+        `data` unchanged.
+        """
+        if data.get("optimizer") == "muon" and (
+            data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
+        ):
+            raise ValueError(
+                "Muon optimizer is currently incompatible with DeepSpeed and FSDP"
+            )
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py
index 47d10ee99..ba142f3bf 100644
--- a/tests/test_validation_dataset.py
+++ b/tests/test_validation_dataset.py
@@ -321,3 +321,50 @@ class TestValidationCheckDatasetConfig(BaseValidation):
         )
 
         validate_config(cfg)
+
+
+class TestOptimizerValidation(BaseValidation):
+    """
+    Test muon optimizer validation
+    """
+
+    def test_muon_deepspeed(self, minimal_cfg):
+        """muon combined with a deepspeed config must fail validation"""
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    }
+                ],
+                "optimizer": "muon",
+                "deepspeed": "deepspeed_configs/zero3.json",
+            }
+        )
+
+        with pytest.raises(ValueError, match=r"is currently incompatible with"):
+            validate_config(cfg)
+
+    def test_muon_fsdp(self, minimal_cfg):
+        """muon combined with fsdp / fsdp_config must fail validation"""
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    }
+                ],
+                "optimizer": "muon",
+                "fsdp": ["full_shard"],
+                "fsdp_config": {
+                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                },
+            }
+        )
+
+        with pytest.raises(ValueError, match=r"is currently incompatible with"):
+            validate_config(cfg)