diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 0781c6798..77597ae1a 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -717,6 +717,7 @@ class AxolotlInputConfig( xformers_attention: Optional[bool] = None sdp_attention: Optional[bool] = None s2_attention: Optional[bool] = None + flex_attention: Optional[bool] = None flash_attention: Optional[bool] = None flash_attn_cross_entropy: Optional[bool] = None flash_attn_rms_norm: Optional[bool] = None @@ -1611,6 +1612,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ) return data + @model_validator(mode="before") + @classmethod + def check_flex_torch_version(cls, data): + if (data.get("flex_attention") is not None) and ( + data.get("flex_attention") is True + ): + env_capabilities = data.get("env_capabilities", {}) + torch_version = env_capabilities.get("torch_version") + + if torch_version is None: + import torch + + torch_version = str(torch.__version__).split("+", maxsplit=1)[0] + + if version.parse(torch_version) < version.parse("2.5.1"): + raise ValueError( + "Flex attention is not supported on torch version < 2.5.1" + ) + return data + @model_validator(mode="before") @classmethod def check_torch_compile_auto(cls, data):