compute attn capability flags in normalizer instead of properties
@@ -1335,39 +1335,29 @@ class AxolotlInputConfig(
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
 
-    # --- Attention capability properties ---
+    # --- Attention capability flags (computed by normalize_attn_implementation) ---
 
-    @property
-    def attn_supports_packing(self) -> bool:
-        """True if attention supports varlen sample packing via position_ids.
-
-        Known varlen backends: flash, flex, xformers, sage.
-        Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
-        default to True since they generally support varlen.
-        """
-        if not self.attn_implementation:
-            return False
-        return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
-
-    @property
-    def attn_uses_flash_lib(self) -> bool:
-        """True if the backend uses axolotl's flash_attn monkeypatches.
-
-        Only for axolotl-managed FA setup (flash, s2). Hub kernels are
-        HF-managed and don't need these patches.
-        """
-        return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
-
-    @property
-    def attn_needs_dtype_cast(self) -> bool:
-        """True if attention needs embedding dtype cast to fp16/bf16.
-
-        Unknown backends (hub kernels) default to True (safe -- harmless
-        if unnecessary, but missing cast causes errors).
-        """
-        if not self.attn_implementation:
-            return False
-        return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
+    attn_supports_packing: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend supports varlen sample packing. "
+            "Computed automatically from attn_implementation."
+        },
+    )
+    attn_uses_flash_lib: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend requires axolotl's flash_attn "
+            "monkeypatches. Computed automatically from attn_implementation."
+        },
+    )
+    attn_needs_dtype_cast: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend needs embedding dtype cast to "
+            "fp16/bf16. Computed automatically from attn_implementation."
+        },
+    )
 
     @model_validator(mode="before")
     @classmethod
@@ -1463,6 +1453,17 @@ class AxolotlInputConfig(
                 )
                 break
 
+        # Compute capability flags from the final attn_implementation value
+        impl = data.get("attn_implementation")
+        if impl:
+            data["attn_supports_packing"] = impl not in _NON_PACKING_ATTN_IMPLS
+            data["attn_uses_flash_lib"] = impl in FLASH_ATTN_LIB_IMPLS
+            data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_ATTN_IMPLS
+        else:
+            data["attn_supports_packing"] = False
+            data["attn_uses_flash_lib"] = False
+            data["attn_needs_dtype_cast"] = False
+
         return data
 
     @model_validator(mode="before")
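For readers skimming the diff, here is a minimal standalone sketch of the pattern this commit adopts: deriving the capability flags eagerly in a mode="before" validator so they become real pydantic fields instead of lazily computed properties. The class name, validator name, and set contents below are illustrative assumptions, not axolotl's actual definitions.

# Minimal sketch of the normalizer pattern (illustrative; the set
# membership and names here are assumptions, not axolotl's real values).
from pydantic import BaseModel, Field, model_validator

_NON_PACKING_IMPLS = {"eager", "sdpa"}  # hypothetical contents

class AttnConfigSketch(BaseModel):
    attn_implementation: str | None = None
    attn_supports_packing: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Computed from attn_implementation."},
    )

    @model_validator(mode="before")
    @classmethod
    def normalize(cls, data):
        # mode="before" runs on the raw input mapping, so values written
        # into `data` here are validated like user-supplied fields.
        if isinstance(data, dict):
            impl = data.get("attn_implementation")
            data["attn_supports_packing"] = (
                bool(impl) and impl not in _NON_PACKING_IMPLS
            )
        return data

cfg = AttnConfigSketch(attn_implementation="flash_attention_2")
print(cfg.attn_supports_packing)  # True

Because the flags are written into the raw data before field validation, they are type-checked like any other field and are present on the validated model rather than being recomputed on every attribute access.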
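A related pydantic v2 behavior worth noting (an illustration of the library's semantics, not a motivation stated in the commit message): a plain @property is not a field, so the old flags never appeared in model_dump() output or the generated JSON schema, whereas the new validator-populated fields do.

# Pydantic v2 semantics the change interacts with: properties are
# invisible to model_dump(), fields are not. (Illustrative class only.)
from pydantic import BaseModel

class PropertyStyle(BaseModel):
    attn_implementation: str | None = None

    @property
    def attn_supports_packing(self) -> bool:  # computed lazily, not a field
        return self.attn_implementation is not None

dump = PropertyStyle(attn_implementation="flash_attention_2").model_dump()
assert "attn_supports_packing" not in dump  # property is absent from dumps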