diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 2affbee44..83be0288b 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -233,7 +233,7 @@ class AxolotlInputConfig( fsdp: list[str] | None = None fsdp_config: dict[str, Any] | None = None fsdp_final_state_dict_type: Literal[ - "FULL_STATE_dict", "LOCAL_STATE_dict", "SHARDED_STATE_dict" + "FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT" ] | None = None val_set_size: float | None = Field(default=0.0) @@ -929,10 +929,10 @@ class AxolotlInputConfig( data.get("fsdp") and data.get("save_safetensors") and data.get("fsdp_config") - and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_dict" + and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" ): raise ValueError( - "FSDP SHARDED_STATE_dict not compatible with save_safetensors" + "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" ) return data