fixes from PR feedback

This commit is contained in:
Wing Lian
2025-04-27 20:11:11 -04:00
committed by NanoCode012
parent ba47adc24b
commit 6ee7cb30fa

View File

@@ -222,8 +222,8 @@ class AxolotlInputConfig(
}, },
) )
attention: AttentionBackend = Field( attention: AttentionBackend | None = Field(
default=AttentionBackend.flash, default=None,
json_schema_extra={"description": "attention backend to use"}, json_schema_extra={"description": "attention backend to use"},
) )
xformers_attention: bool | None = None xformers_attention: bool | None = None
@@ -443,70 +443,77 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def normalize_attn(cls, data): # pylint: disable=too-many-return-statements def normalize_attn(cls, data): # pylint: disable=too-many-return-statements
# cases where both are set and already match attention = data.get("attention")
if data.get("attention") == AttentionBackend.eager and data.get(
"eager_attention"
):
return data
if data.get("attention") == AttentionBackend.flash and data.get(
"flash_attention"
):
return data
if data.get("attention") == AttentionBackend.s2 and data.get("s2_attention"):
return data
if data.get("attention") == AttentionBackend.sdpa and data.get("sdp_attention"):
return data
if data.get("attention") == AttentionBackend.xformers and data.get(
"xformers_attention"
):
return data
# cases where attention is set and the specific *_attention is not set # cases where attention is set and the specific *_attention is not set
if not ( if (
data.get("flash_attention") not (
or data.get("eager_attention") data.get("flash_attention")
or data.get("s2_attention") or data.get("flex_attention")
or data.get("sdp_attention") or data.get("eager_attention")
or data.get("xformers_attention") or data.get("s2_attention")
or data.get("sdp_attention")
or data.get("xformers_attention")
)
and attention
): ):
if data.get("attention") == AttentionBackend.eager: if attention == AttentionBackend.eager:
data["eager_attention"] = True data["eager_attention"] = True
elif data.get("attention") == AttentionBackend.flash: elif attention == AttentionBackend.flash:
data["flash_attention"] = True data["flash_attention"] = True
elif data.get("attention") == AttentionBackend.s2: elif attention == AttentionBackend.flex:
data["flex_attention"] = True
elif attention == AttentionBackend.s2:
data["s2_attention"] = True data["s2_attention"] = True
elif data.get("attention") == AttentionBackend.sdpa: elif attention == AttentionBackend.sdpa:
data["sdp_attention"] = True data["sdp_attention"] = True
elif data.get("attention") == AttentionBackend.xformers: elif attention == AttentionBackend.xformers:
data["xformers_attention"] = True data["xformers_attention"] = True
return data return data
# attention should always be set since that's a requirement, defaults to flash if not attention:
if ( LOG.warning(
data.get("eager_attention") "*_attention will be deprecated soon. One of `attention: eager | flash | flex | s2 | sdp | xformers` is recommended"
and not data.get("attention") == AttentionBackend.eager )
):
raise ValueError("attention mismatch with eager_attention already set") # attention should always be set since that's a requirement, defaults to None
if ( if data.get("eager_attention"):
data.get("flash_attention") if not attention:
and not data.get("attention") == AttentionBackend.flash data["attention"] = AttentionBackend.eager
): return data
raise ValueError("attention mismatch with flash_attention already set") if attention != AttentionBackend.eager:
if ( raise ValueError("attention mismatch with eager_attention already set")
data.get("s2_attention") if data.get("flash_attention"):
and not data.get("attention") == AttentionBackend.s2 if not attention:
): data["attention"] = AttentionBackend.flash
raise ValueError("attention mismatch with s2_attention already set") return data
if ( if attention != AttentionBackend.flash:
data.get("sdp_attention") raise ValueError("attention mismatch with flash_attention already set")
and not data.get("attention") == AttentionBackend.sdpa if data.get("flex_attention"):
): if not attention:
raise ValueError("attention mismatch with sdp_attention already set") data["attention"] = AttentionBackend.flex
if ( return data
data.get("xformers_attention") if attention != AttentionBackend.flex:
and not data.get("attention") == AttentionBackend.xformers raise ValueError("attention mismatch with flex_attention already set")
): if data.get("s2_attention"):
raise ValueError("attention mismatch with xformers_attention already set") if not attention:
data["attention"] = AttentionBackend.s2
return data
if attention != AttentionBackend.s2:
raise ValueError("attention mismatch with s2_attention already set")
if data.get("sdp_attention"):
if not attention:
data["attention"] = AttentionBackend.sdpa
return data
if attention != AttentionBackend.sdpa:
raise ValueError("attention mismatch with sdp_attention already set")
if data.get("xformers_attention"):
if not attention:
data["attention"] = AttentionBackend.xformers
return data
if attention != AttentionBackend.xformers:
raise ValueError(
"attention mismatch with xformers_attention already set"
)
return data return data