chore: refactor normalize_attn to use mapping and loop
This commit is contained in:
@@ -444,76 +444,48 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def normalize_attn(cls, data): # pylint: disable=too-many-return-statements
|
def normalize_attn(cls, data): # pylint: disable=too-many-return-statements
|
||||||
attention = data.get("attention")
|
attention = data.get("attention")
|
||||||
# cases where attention is set and the specific *_attention is not set
|
|
||||||
if (
|
# Define mapping between enum values and flag names
|
||||||
not (
|
backend_mapping = {
|
||||||
data.get("flash_attention")
|
AttentionBackend.eager: "eager_attention",
|
||||||
or data.get("flex_attention")
|
AttentionBackend.flash: "flash_attention",
|
||||||
or data.get("eager_attention")
|
AttentionBackend.flex: "flex_attention",
|
||||||
or data.get("s2_attention")
|
AttentionBackend.s2: "s2_attention",
|
||||||
or data.get("sdp_attention")
|
AttentionBackend.sdpa: "sdp_attention",
|
||||||
or data.get("xformers_attention")
|
AttentionBackend.xformers: "xformers_attention",
|
||||||
)
|
}
|
||||||
and attention
|
|
||||||
):
|
# Check if any attention flag is set
|
||||||
if attention == AttentionBackend.eager:
|
any_flag_set = any(
|
||||||
data["eager_attention"] = True
|
data.get(flag_name) for flag_name in backend_mapping.values()
|
||||||
elif attention == AttentionBackend.flash:
|
)
|
||||||
data["flash_attention"] = True
|
|
||||||
elif attention == AttentionBackend.flex:
|
# CASE 1: attention is set but no flags are set - set the corresponding flag
|
||||||
data["flex_attention"] = True
|
if attention and not any_flag_set:
|
||||||
elif attention == AttentionBackend.s2:
|
flag_name = backend_mapping.get(attention)
|
||||||
data["s2_attention"] = True
|
if flag_name:
|
||||||
elif attention == AttentionBackend.sdpa:
|
data[flag_name] = True
|
||||||
data["sdp_attention"] = True
|
|
||||||
elif attention == AttentionBackend.xformers:
|
|
||||||
data["xformers_attention"] = True
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
# CASE 2: attention is not set - warn about the deprecated *_attention flags, then derive attention from the first flag that is set
|
||||||
if not attention:
|
if not attention:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"*_attention will be deprecated soon. One of `attention: eager | flash | flex | s2 | sdp | xformers` is recommended"
|
"*_attention will be deprecated soon. One of `attention: eager | flash | flex | s2 | sdp | xformers` is recommended"
|
||||||
)
|
)
|
||||||
|
|
||||||
# attention should always be set since that's a requirement, defaults to None
|
# Find the first True flag and set attention accordingly
|
||||||
if data.get("eager_attention"):
|
for backend, flag_name in backend_mapping.items():
|
||||||
if not attention:
|
if data.get(flag_name):
|
||||||
data["attention"] = AttentionBackend.eager
|
data["attention"] = backend
|
||||||
return data
|
return data
|
||||||
if attention != AttentionBackend.eager:
|
|
||||||
raise ValueError("attention mismatch with eager_attention already set")
|
# CASE 3: both attention and flags are set - check for consistency
|
||||||
if data.get("flash_attention"):
|
if attention:
|
||||||
if not attention:
|
expected_flag = backend_mapping.get(attention)
|
||||||
data["attention"] = AttentionBackend.flash
|
for backend, flag_name in backend_mapping.items():
|
||||||
return data
|
# If a flag is set that doesn't match the attention value
|
||||||
if attention != AttentionBackend.flash:
|
if data.get(flag_name) and flag_name != expected_flag:
|
||||||
raise ValueError("attention mismatch with flash_attention already set")
|
raise ValueError(f"attention mismatch with {flag_name} already set")
|
||||||
if data.get("flex_attention"):
|
|
||||||
if not attention:
|
|
||||||
data["attention"] = AttentionBackend.flex
|
|
||||||
return data
|
|
||||||
if attention != AttentionBackend.flex:
|
|
||||||
raise ValueError("attention mismatch with flex_attention already set")
|
|
||||||
if data.get("s2_attention"):
|
|
||||||
if not attention:
|
|
||||||
data["attention"] = AttentionBackend.s2
|
|
||||||
return data
|
|
||||||
if attention != AttentionBackend.s2:
|
|
||||||
raise ValueError("attention mismatch with s2_attention already set")
|
|
||||||
if data.get("sdp_attention"):
|
|
||||||
if not attention:
|
|
||||||
data["attention"] = AttentionBackend.sdpa
|
|
||||||
return data
|
|
||||||
if attention != AttentionBackend.sdpa:
|
|
||||||
raise ValueError("attention mismatch with sdp_attention already set")
|
|
||||||
if data.get("xformers_attention"):
|
|
||||||
if not attention:
|
|
||||||
data["attention"] = AttentionBackend.xformers
|
|
||||||
return data
|
|
||||||
if attention != AttentionBackend.xformers:
|
|
||||||
raise ValueError(
|
|
||||||
"attention mismatch with xformers_attention already set"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user