diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index b9278068c..260cb2169 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1335,39 +1335,29 @@ class AxolotlInputConfig(
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
 
-    # --- Attention capability properties ---
+    # --- Attention capability flags (computed by normalize_attn_implementation) ---
 
-    @property
-    def attn_supports_packing(self) -> bool:
-        """True if attention supports varlen sample packing via position_ids.
-
-        Known varlen backends: flash, flex, xformers, sage.
-        Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
-        default to True since they generally support varlen.
-        """
-        if not self.attn_implementation:
-            return False
-        return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
-
-    @property
-    def attn_uses_flash_lib(self) -> bool:
-        """True if the backend uses axolotl's flash_attn monkeypatches.
-
-        Only for axolotl-managed FA setup (flash, s2). Hub kernels are
-        HF-managed and don't need these patches.
-        """
-        return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
-
-    @property
-    def attn_needs_dtype_cast(self) -> bool:
-        """True if attention needs embedding dtype cast to fp16/bf16.
-
-        Unknown backends (hub kernels) default to True (safe -- harmless
-        if unnecessary, but missing cast causes errors).
-        """
-        if not self.attn_implementation:
-            return False
-        return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
+    attn_supports_packing: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend supports varlen sample packing. "
+            "Computed automatically from attn_implementation."
+        },
+    )
+    attn_uses_flash_lib: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend requires axolotl's flash_attn "
+            "monkeypatches. Computed automatically from attn_implementation."
+        },
+    )
+    attn_needs_dtype_cast: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether the attention backend needs embedding dtype cast to "
+            "fp16/bf16. Computed automatically from attn_implementation."
+        },
+    )
 
     @model_validator(mode="before")
     @classmethod
@@ -1463,6 +1453,17 @@ class AxolotlInputConfig(
                 )
                 break
 
+        # Compute capability flags from the final attn_implementation value
+        impl = data.get("attn_implementation")
+        if impl:
+            data["attn_supports_packing"] = impl not in _NON_PACKING_ATTN_IMPLS
+            data["attn_uses_flash_lib"] = impl in FLASH_ATTN_LIB_IMPLS
+            data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_ATTN_IMPLS
+        else:
+            data["attn_supports_packing"] = False
+            data["attn_uses_flash_lib"] = False
+            data["attn_needs_dtype_cast"] = False
+
         return data
 
     @model_validator(mode="before")