fixes from PR feedback
This commit is contained in:
@@ -222,8 +222,8 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
attention: AttentionBackend = Field(
|
attention: AttentionBackend | None = Field(
|
||||||
default=AttentionBackend.flash,
|
default=None,
|
||||||
json_schema_extra={"description": "attention backend to use"},
|
json_schema_extra={"description": "attention backend to use"},
|
||||||
)
|
)
|
||||||
xformers_attention: bool | None = None
|
xformers_attention: bool | None = None
|
||||||
@@ -443,70 +443,77 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def normalize_attn(cls, data): # pylint: disable=too-many-return-statements
|
def normalize_attn(cls, data): # pylint: disable=too-many-return-statements
|
||||||
# cases where both are set and already match
|
attention = data.get("attention")
|
||||||
if data.get("attention") == AttentionBackend.eager and data.get(
|
|
||||||
"eager_attention"
|
|
||||||
):
|
|
||||||
return data
|
|
||||||
if data.get("attention") == AttentionBackend.flash and data.get(
|
|
||||||
"flash_attention"
|
|
||||||
):
|
|
||||||
return data
|
|
||||||
if data.get("attention") == AttentionBackend.s2 and data.get("s2_attention"):
|
|
||||||
return data
|
|
||||||
if data.get("attention") == AttentionBackend.sdpa and data.get("sdp_attention"):
|
|
||||||
return data
|
|
||||||
if data.get("attention") == AttentionBackend.xformers and data.get(
|
|
||||||
"xformers_attention"
|
|
||||||
):
|
|
||||||
return data
|
|
||||||
|
|
||||||
# cases where attention is set and the specific *_attention is not set
|
# cases where attention is set and the specific *_attention is not set
|
||||||
if not (
|
if (
|
||||||
data.get("flash_attention")
|
not (
|
||||||
or data.get("eager_attention")
|
data.get("flash_attention")
|
||||||
or data.get("s2_attention")
|
or data.get("flex_attention")
|
||||||
or data.get("sdp_attention")
|
or data.get("eager_attention")
|
||||||
or data.get("xformers_attention")
|
or data.get("s2_attention")
|
||||||
|
or data.get("sdp_attention")
|
||||||
|
or data.get("xformers_attention")
|
||||||
|
)
|
||||||
|
and attention
|
||||||
):
|
):
|
||||||
if data.get("attention") == AttentionBackend.eager:
|
if attention == AttentionBackend.eager:
|
||||||
data["eager_attention"] = True
|
data["eager_attention"] = True
|
||||||
elif data.get("attention") == AttentionBackend.flash:
|
elif attention == AttentionBackend.flash:
|
||||||
data["flash_attention"] = True
|
data["flash_attention"] = True
|
||||||
elif data.get("attention") == AttentionBackend.s2:
|
elif attention == AttentionBackend.flex:
|
||||||
|
data["flex_attention"] = True
|
||||||
|
elif attention == AttentionBackend.s2:
|
||||||
data["s2_attention"] = True
|
data["s2_attention"] = True
|
||||||
elif data.get("attention") == AttentionBackend.sdpa:
|
elif attention == AttentionBackend.sdpa:
|
||||||
data["sdp_attention"] = True
|
data["sdp_attention"] = True
|
||||||
elif data.get("attention") == AttentionBackend.xformers:
|
elif attention == AttentionBackend.xformers:
|
||||||
data["xformers_attention"] = True
|
data["xformers_attention"] = True
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# attention should always be set since that's a requirement, defaults to flash
|
if not attention:
|
||||||
if (
|
LOG.warning(
|
||||||
data.get("eager_attention")
|
"*_attention will be deprecated soon. One of `attention: eager | flash | flex | s2 | sdp | xformers` is recommended"
|
||||||
and not data.get("attention") == AttentionBackend.eager
|
)
|
||||||
):
|
|
||||||
raise ValueError("attention mismatch with eager_attention already set")
|
# attention should always be set since that's a requirement, defaults to None
|
||||||
if (
|
if data.get("eager_attention"):
|
||||||
data.get("flash_attention")
|
if not attention:
|
||||||
and not data.get("attention") == AttentionBackend.flash
|
data["attention"] = AttentionBackend.eager
|
||||||
):
|
return data
|
||||||
raise ValueError("attention mismatch with flash_attention already set")
|
if attention != AttentionBackend.eager:
|
||||||
if (
|
raise ValueError("attention mismatch with eager_attention already set")
|
||||||
data.get("s2_attention")
|
if data.get("flash_attention"):
|
||||||
and not data.get("attention") == AttentionBackend.s2
|
if not attention:
|
||||||
):
|
data["attention"] = AttentionBackend.flash
|
||||||
raise ValueError("attention mismatch with s2_attention already set")
|
return data
|
||||||
if (
|
if attention != AttentionBackend.flash:
|
||||||
data.get("sdp_attention")
|
raise ValueError("attention mismatch with flash_attention already set")
|
||||||
and not data.get("attention") == AttentionBackend.sdpa
|
if data.get("flex_attention"):
|
||||||
):
|
if not attention:
|
||||||
raise ValueError("attention mismatch with sdp_attention already set")
|
data["attention"] = AttentionBackend.flex
|
||||||
if (
|
return data
|
||||||
data.get("xformers_attention")
|
if attention != AttentionBackend.flex:
|
||||||
and not data.get("attention") == AttentionBackend.xformers
|
raise ValueError("attention mismatch with flex_attention already set")
|
||||||
):
|
if data.get("s2_attention"):
|
||||||
raise ValueError("attention mismatch with xformers_attention already set")
|
if not attention:
|
||||||
|
data["attention"] = AttentionBackend.s2
|
||||||
|
return data
|
||||||
|
if attention != AttentionBackend.s2:
|
||||||
|
raise ValueError("attention mismatch with s2_attention already set")
|
||||||
|
if data.get("sdp_attention"):
|
||||||
|
if not attention:
|
||||||
|
data["attention"] = AttentionBackend.sdpa
|
||||||
|
return data
|
||||||
|
if attention != AttentionBackend.sdpa:
|
||||||
|
raise ValueError("attention mismatch with sdp_attention already set")
|
||||||
|
if data.get("xformers_attention"):
|
||||||
|
if not attention:
|
||||||
|
data["attention"] = AttentionBackend.xformers
|
||||||
|
return data
|
||||||
|
if attention != AttentionBackend.xformers:
|
||||||
|
raise ValueError(
|
||||||
|
"attention mismatch with xformers_attention already set"
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user