compute attn capability flags in normalizer instead of properties

commit 35d43fe141
parent ff5d6393c8
Author: Wing Lian
Date: 2026-04-12 23:46:44 -04:00


@@ -1335,39 +1335,29 @@ class AxolotlInputConfig(
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
 
-    # --- Attention capability properties ---
-
-    @property
-    def attn_supports_packing(self) -> bool:
-        """True if attention supports varlen sample packing via position_ids.
-
-        Known varlen backends: flash, flex, xformers, sage.
-        Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
-        default to True since they generally support varlen.
-        """
-        if not self.attn_implementation:
-            return False
-        return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
-
-    @property
-    def attn_uses_flash_lib(self) -> bool:
-        """True if the backend uses axolotl's flash_attn monkeypatches.
-
-        Only for axolotl-managed FA setup (flash, s2). Hub kernels are
-        HF-managed and don't need these patches.
-        """
-        return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
-
-    @property
-    def attn_needs_dtype_cast(self) -> bool:
-        """True if attention needs embedding dtype cast to fp16/bf16.
-
-        Unknown backends (hub kernels) default to True (safe -- harmless
-        if unnecessary, but missing cast causes errors).
-        """
-        if not self.attn_implementation:
-            return False
-        return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
+    # --- Attention capability flags (computed by normalize_attn_implementation) ---
+
+    attn_supports_packing: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend supports varlen sample packing. "
+            "Computed automatically from attn_implementation."
+        },
+    )
+    attn_uses_flash_lib: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend requires axolotl's flash_attn "
+            "monkeypatches. Computed automatically from attn_implementation."
+        },
+    )
+    attn_needs_dtype_cast: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend needs embedding dtype cast to "
+            "fp16/bf16. Computed automatically from attn_implementation."
+        },
+    )
 
     @model_validator(mode="before")
     @classmethod
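For orientation, a minimal, self-contained sketch (not from this commit) of the pattern the change adopts: capability flags stored as plain optional fields and populated once by a mode="before" validator, instead of being exposed as live @property accessors. The class name, validator name, and set contents below are illustrative stand-ins, not axolotl's real _NON_PACKING_ATTN_IMPLS / FLASH_ATTN_LIB_IMPLS / _NO_DTYPE_CAST_ATTN_IMPLS values.

from pydantic import BaseModel, Field, model_validator

# Illustrative stand-ins for axolotl's capability sets (assumed values).
_NON_PACKING_IMPLS = {"eager", "sdpa"}           # assumed: no varlen packing
_NO_DTYPE_CAST_IMPLS = {"eager"}                 # assumed: no fp16/bf16 cast needed
_FLASH_LIB_IMPLS = {"flash_attention_2", "s2"}   # assumed: axolotl-managed FA

class AttnConfig(BaseModel):
    attn_implementation: str | None = None

    # Plain optional fields, filled in during validation rather than
    # recomputed on every attribute access.
    attn_supports_packing: bool | None = Field(default=None)
    attn_uses_flash_lib: bool | None = Field(default=None)
    attn_needs_dtype_cast: bool | None = Field(default=None)

    @model_validator(mode="before")
    @classmethod
    def _compute_capability_flags(cls, data: dict) -> dict:
        # Mirrors the diff's assumption that the raw input is a dict.
        impl = data.get("attn_implementation")
        if impl:
            # Unknown strings (e.g. hub kernels) are treated permissively:
            # packing is allowed and the dtype cast is applied.
            data["attn_supports_packing"] = impl not in _NON_PACKING_IMPLS
            data["attn_uses_flash_lib"] = impl in _FLASH_LIB_IMPLS
            data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_IMPLS
        else:
            # No attention backend configured: every capability is off.
            data["attn_supports_packing"] = False
            data["attn_uses_flash_lib"] = False
            data["attn_needs_dtype_cast"] = False
        return data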
@@ -1463,6 +1453,17 @@ class AxolotlInputConfig(
                 )
                 break
 
+        # Compute capability flags from the final attn_implementation value
+        impl = data.get("attn_implementation")
+        if impl:
+            data["attn_supports_packing"] = impl not in _NON_PACKING_ATTN_IMPLS
+            data["attn_uses_flash_lib"] = impl in FLASH_ATTN_LIB_IMPLS
+            data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_ATTN_IMPLS
+        else:
+            data["attn_supports_packing"] = False
+            data["attn_uses_flash_lib"] = False
+            data["attn_needs_dtype_cast"] = False
+
         return data
 
     @model_validator(mode="before")
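A note on the design choice, continuing the AttnConfig sketch above (illustrative names, not axolotl's): because the validator assigns the flags unconditionally whenever attn_implementation is set, a stale or hand-edited flag in a serialized config is recomputed on re-validation rather than trusted, and the permissive default for unknown backends is preserved.

# Unknown backends such as hub kernels fall through to the permissive
# defaults: packing assumed supported, dtype cast applied.
cfg = AttnConfig(attn_implementation="kernels-community/flash-attn3")
assert cfg.attn_supports_packing is True
assert cfg.attn_needs_dtype_cast is True
assert cfg.attn_uses_flash_lib is False

# The flags are ordinary fields now, so they round-trip through model_dump(),
# and re-validation recomputes them from attn_implementation -- a hand-edited
# False is overwritten by the normalizer.
dumped = cfg.model_dump()
restored = AttnConfig.model_validate({**dumped, "attn_supports_packing": False})
assert restored.attn_supports_packing is True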