diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 6362bd2d7..24d668d79 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -222,8 +222,8 @@ class AxolotlInputConfig( }, ) - attention: AttentionBackend = Field( - default=AttentionBackend.flash, + attention: AttentionBackend | None = Field( + default=None, json_schema_extra={"description": "attention backend to use"}, ) xformers_attention: bool | None = None @@ -443,70 +443,77 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def normalize_attn(cls, data): # pylint: disable=too-many-return-statements - # cases where both are set and already match - if data.get("attention") == AttentionBackend.eager and data.get( - "eager_attention" - ): - return data - if data.get("attention") == AttentionBackend.flash and data.get( - "flash_attention" - ): - return data - if data.get("attention") == AttentionBackend.s2 and data.get("s2_attention"): - return data - if data.get("attention") == AttentionBackend.sdpa and data.get("sdp_attention"): - return data - if data.get("attention") == AttentionBackend.xformers and data.get( - "xformers_attention" - ): - return data - + attention = data.get("attention") # cases where attention is set and the specific *_attention is not set - if not ( - data.get("flash_attention") - or data.get("eager_attention") - or data.get("s2_attention") - or data.get("sdp_attention") - or data.get("xformers_attention") + if ( + not ( + data.get("flash_attention") + or data.get("flex_attention") + or data.get("eager_attention") + or data.get("s2_attention") + or data.get("sdp_attention") + or data.get("xformers_attention") + ) + and attention ): - if data.get("attention") == AttentionBackend.eager: + if attention == AttentionBackend.eager: data["eager_attention"] = True - elif data.get("attention") == AttentionBackend.flash: + elif attention == AttentionBackend.flash: data["flash_attention"] = True - elif data.get("attention") == AttentionBackend.s2: + elif attention == AttentionBackend.flex: + data["flex_attention"] = True + elif attention == AttentionBackend.s2: data["s2_attention"] = True - elif data.get("attention") == AttentionBackend.sdpa: + elif attention == AttentionBackend.sdpa: data["sdp_attention"] = True - elif data.get("attention") == AttentionBackend.xformers: + elif attention == AttentionBackend.xformers: data["xformers_attention"] = True return data - # attention should always be set since that's a requirement, defaults to flash - if ( - data.get("eager_attention") - and not data.get("attention") == AttentionBackend.eager - ): - raise ValueError("attention mismatch with eager_attention already set") - if ( - data.get("flash_attention") - and not data.get("attention") == AttentionBackend.flash - ): - raise ValueError("attention mismatch with flash_attention already set") - if ( - data.get("s2_attention") - and not data.get("attention") == AttentionBackend.s2 - ): - raise ValueError("attention mismatch with s2_attention already set") - if ( - data.get("sdp_attention") - and not data.get("attention") == AttentionBackend.sdpa - ): - raise ValueError("attention mismatch with sdp_attention already set") - if ( - data.get("xformers_attention") - and not data.get("attention") == AttentionBackend.xformers - ): - raise ValueError("attention mismatch with xformers_attention already set") + if not attention: + LOG.warning( + "*_attention will be deprecated soon. One of `attention: eager | flash | flex | s2 | sdp | xformers` is recommended" + ) + + # attention should always be set since that's a requirement, defaults to None + if data.get("eager_attention"): + if not attention: + data["attention"] = AttentionBackend.eager + return data + if attention != AttentionBackend.eager: + raise ValueError("attention mismatch with eager_attention already set") + if data.get("flash_attention"): + if not attention: + data["attention"] = AttentionBackend.flash + return data + if attention != AttentionBackend.flash: + raise ValueError("attention mismatch with flash_attention already set") + if data.get("flex_attention"): + if not attention: + data["attention"] = AttentionBackend.flex + return data + if attention != AttentionBackend.flex: + raise ValueError("attention mismatch with flex_attention already set") + if data.get("s2_attention"): + if not attention: + data["attention"] = AttentionBackend.s2 + return data + if attention != AttentionBackend.s2: + raise ValueError("attention mismatch with s2_attention already set") + if data.get("sdp_attention"): + if not attention: + data["attention"] = AttentionBackend.sdpa + return data + if attention != AttentionBackend.sdpa: + raise ValueError("attention mismatch with sdp_attention already set") + if data.get("xformers_attention"): + if not attention: + data["attention"] = AttentionBackend.xformers + return data + if attention != AttentionBackend.xformers: + raise ValueError( + "attention mismatch with xformers_attention already set" + ) return data