@classmethod
def normalize_attn(cls, data):  # pylint: disable=too-many-return-statements
    """Reconcile the `attention` setting with the legacy `*_attention` flags.

    Three situations are handled:

    1. `attention` is set and no legacy flag is truthy: the legacy flag that
       corresponds to `attention` (if any) is set to ``True``.
    2. `attention` is not set but a legacy flag is truthy: a deprecation
       warning is logged and `attention` is derived from the first truthy
       flag, in the order the backends are listed in the mapping below.
    3. Both are set: every truthy legacy flag must be the one corresponding
       to `attention`, otherwise the configuration is rejected.

    Args:
        data: The raw config mapping; read with ``.get`` and updated in place.

    Returns:
        The same ``data`` object, possibly with `attention` or one
        `*_attention` key filled in.

    Raises:
        ValueError: If a truthy legacy flag does not correspond to the
            configured `attention` value.
    """
    attention = data.get("attention")

    # Mapping between enum values and the legacy flag names.
    backend_mapping = {
        AttentionBackend.eager: "eager_attention",
        AttentionBackend.flash: "flash_attention",
        AttentionBackend.flex: "flex_attention",
        AttentionBackend.s2: "s2_attention",
        AttentionBackend.sdpa: "sdp_attention",
        AttentionBackend.xformers: "xformers_attention",
    }

    if attention:
        expected_flag = backend_mapping.get(attention)
        # Look the flags up once; both branches below need the result.
        set_flags = [
            flag_name
            for flag_name in backend_mapping.values()
            if data.get(flag_name)
        ]

        # CASE 1: attention is set but no flags are set - set the matching flag.
        if not set_flags:
            # expected_flag is None when `attention` is not a key of the
            # mapping; in that case nothing is written and data passes through.
            if expected_flag:
                data[expected_flag] = True
            return data

        # CASE 3: both attention and flags are set - check for consistency.
        for flag_name in set_flags:
            if flag_name != expected_flag:
                raise ValueError(f"attention mismatch with {flag_name} already set")
        return data

    # CASE 2: no attention set - derive it from the first truthy legacy flag.
    for backend, flag_name in backend_mapping.items():
        if data.get(flag_name):
            # Warn only when a legacy flag is really in use; a config with
            # neither `attention` nor any `*_attention` flag has nothing to
            # migrate.
            LOG.warning(
                "*_attention will be deprecated soon. "
                "One of `attention: eager | flash | flex | s2 | sdp | xformers` is recommended"
            )
            data["attention"] = backend
            return data

    return data