Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -26,32 +26,6 @@ class TestFSDPValidation:
|
||||
assert cfg.fsdp_version == 2
|
||||
assert cfg.fsdp_config.fsdp_version is None
|
||||
|
||||
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
},
|
||||
save_safetensors=True,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
# test w/o prefix too
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"state_dict_type": "SHARDED_STATE_DICT",
|
||||
},
|
||||
save_safetensors=True,
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
|
||||
Reference in New Issue
Block a user