diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 1feb8aae8..10899d1f9 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1402,6 +1402,28 @@ class AxolotlInputConfig(
         )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_grad_accum_4_46_2(cls, data):
+        """Reject FSDP + gradient accumulation on transformers 4.46.2.
+
+        Raises ``ValueError`` when ``fsdp`` is set, ``gradient_accumulation_steps``
+        is greater than 1 and the installed transformers version is exactly 4.46.2.
+        Otherwise returns ``data`` unchanged.
+        """
+        # An unset (missing or None) value counts as 1, so the comparison below
+        # cannot raise TypeError on configs that omit the key.
+        grad_accum_steps = data.get("gradient_accumulation_steps") or 1
+        if data.get("fsdp") and grad_accum_steps > 1:
+            if version("transformers") == "4.46.2":
+                raise ValueError(
+                    "FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
+                    "Please use a lower value or switch to an older version of transformers."
+                )
+        # A mode="before" validator must return the input, otherwise pydantic
+        # goes on to validate None instead of the config.
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""