Validation for Muon optimizer with DS/FSDP
This commit is contained in:
@@ -321,3 +321,48 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class TestOptimizerValidation(BaseValidation):
|
||||
"""
|
||||
Test muon optimizer validation
|
||||
"""
|
||||
|
||||
def test_muon_deepspeed(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"optimizer": "muon",
|
||||
"deepspeed": "deepspeed_configs/zero3.json",
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_muon_fsdp(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"optimizer": "muon",
|
||||
"fsdp": ["full_shard"],
|
||||
"fsdp_config": {
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user