Compare commits
1 Commits
fix/diffus
...
muon-valid
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c578c8f256 |
@@ -1135,6 +1135,17 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_muon_deepspeed_fsdp(cls, data):
|
||||||
|
if data.get("optimizer") == "muon" and (
|
||||||
|
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|||||||
@@ -321,3 +321,48 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptimizerValidation(BaseValidation):
|
||||||
|
"""
|
||||||
|
Test muon optimizer validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_muon_deepspeed(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"optimizer": "muon",
|
||||||
|
"deepspeed": "deepspeed_configs/zero3.json",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_muon_fsdp(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"optimizer": "muon",
|
||||||
|
"fsdp": ["full_shard"],
|
||||||
|
"fsdp_config": {
|
||||||
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user