Compare commits
1 Commits
model-load
...
muon-valid
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c578c8f256 |
@@ -1135,6 +1135,17 @@ class AxolotlInputConfig(
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
    """Reject configs that combine the Muon optimizer with DeepSpeed or FSDP.

    Registered as a ``mode="before"`` model validator, so ``data`` is the
    incoming config mapping and is read with ``.get``.

    Returns:
        ``data`` unchanged when the combination is allowed.

    Raises:
        ValueError: if ``optimizer`` is ``"muon"`` and any of ``deepspeed``,
            ``fsdp`` or ``fsdp_config`` is set to a truthy value.
    """
    # Nothing to check unless Muon was explicitly selected.
    if data.get("optimizer") != "muon":
        return data

    # Any truthy distributed-backend setting conflicts with Muon.
    for backend_key in ("deepspeed", "fsdp", "fsdp_config"):
        if data.get(backend_key):
            raise ValueError(
                "Muon optimizer is currently incompatible with DeepSpeed and FSDP"
            )

    return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
@@ -321,3 +321,48 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class TestOptimizerValidation(BaseValidation):
    """
    Test muon optimizer validation
    """

    # Dataset entry shared by every test in this class; merged into the
    # minimal config alongside the optimizer settings under test.
    _DATASETS = [
        {
            "path": "mhenrichsen/alpaca_2k_test",
            "type": "alpaca",
        }
    ]

    def test_muon_deepspeed(self, minimal_cfg):
        """Muon together with a DeepSpeed config must fail validation."""
        cfg = DictDefault(
            minimal_cfg
            | {
                "datasets": self._DATASETS,
                "optimizer": "muon",
                "deepspeed": "deepspeed_configs/zero3.json",
            }
        )

        # `match` is applied with re.search, so no surrounding wildcards are
        # needed. The previous pattern ended in `with*`, which quantified only
        # the final "h" and therefore did not pin the full phrase.
        with pytest.raises(ValueError, match=r"is currently incompatible with"):
            validate_config(cfg)

    def test_muon_fsdp(self, minimal_cfg):
        """Muon together with FSDP settings must fail validation."""
        cfg = DictDefault(
            minimal_cfg
            | {
                "datasets": self._DATASETS,
                "optimizer": "muon",
                "fsdp": ["full_shard"],
                "fsdp_config": {
                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                },
            }
        )

        # Same pattern as the DeepSpeed test: plain substring search.
        with pytest.raises(ValueError, match=r"is currently incompatible with"):
            validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user