chore: refactor normalize_attn to use mapping and loop

This commit is contained in:
NanoCode012
2025-05-07 17:07:08 +07:00
parent d0c4930dd5
commit ef883b6960

View File

@@ -444,76 +444,48 @@ class AxolotlInputConfig(
@classmethod
def normalize_attn(cls, data):
    """Keep ``attention`` and the legacy ``*_attention`` flags consistent.

    Three situations are handled:

    1. ``attention`` is set and no legacy flag is set: the flag that
       corresponds to the selected backend is switched on.
    2. ``attention`` is not set but at least one legacy flag is: a
       deprecation warning is logged and ``attention`` is derived from the
       first set flag, in the order of the mapping below.
    3. Both are set: every set flag must correspond to ``attention``.

    Args:
        data: mapping of config values; it is read with ``.get`` and
            updated in place.

    Returns:
        The same ``data`` object, possibly updated.

    Raises:
        ValueError: if a legacy flag is set that does not correspond to the
            value of ``attention``.
    """
    # Backend enum member -> name of the legacy boolean flag.
    flag_by_backend = {
        AttentionBackend.eager: "eager_attention",
        AttentionBackend.flash: "flash_attention",
        AttentionBackend.flex: "flex_attention",
        AttentionBackend.s2: "s2_attention",
        AttentionBackend.sdpa: "sdp_attention",
        AttentionBackend.xformers: "xformers_attention",
    }

    attention = data.get("attention")
    # Legacy flags that are currently truthy, kept in mapping order.
    enabled_flags = [
        flag_name for flag_name in flag_by_backend.values() if data.get(flag_name)
    ]

    if attention:
        # NOTE(review): this lookup assumes `attention` hashes/compares equal
        # to an AttentionBackend member (e.g. a str-mixin enum when the raw
        # config holds a plain string) - confirm against the enum definition.
        expected_flag = flag_by_backend.get(attention)

        if not enabled_flags:
            # CASE 1: only `attention` is given - mirror it onto its flag.
            if expected_flag:
                data[expected_flag] = True
            return data

        # CASE 3: both are given - reject the first flag that disagrees.
        for flag_name in enabled_flags:
            if flag_name != expected_flag:
                raise ValueError(f"attention mismatch with {flag_name} already set")
        return data

    if enabled_flags:
        # CASE 2: only legacy flags are given. Warn solely in this case so
        # configs that use neither style do not get a misleading
        # deprecation message.
        LOG.warning(
            "*_attention will be deprecated soon. One of `attention: eager | flash | flex | s2 | sdp | xformers` is recommended"
        )
        # The first set flag (in mapping order) decides the backend.
        for backend, flag_name in flag_by_backend.items():
            if data.get(flag_name):
                data["attention"] = backend
                break

    return data