compute attn capability flags in normalizer instead of properties
This commit is contained in:
@@ -1335,39 +1335,29 @@ class AxolotlInputConfig(
|
|||||||
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# --- Attention capability properties ---
|
# --- Attention capability flags (computed by normalize_attn_implementation) ---
|
||||||
|
|
||||||
@property
|
attn_supports_packing: bool | None = Field(
|
||||||
def attn_supports_packing(self) -> bool:
|
default=None,
|
||||||
"""True if attention supports varlen sample packing via position_ids.
|
json_schema_extra={
|
||||||
|
"description": "Whether the attention backend supports varlen sample packing. "
|
||||||
Known varlen backends: flash, flex, xformers, sage.
|
"Computed automatically from attn_implementation."
|
||||||
Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
|
},
|
||||||
default to True since they generally support varlen.
|
)
|
||||||
"""
|
attn_uses_flash_lib: bool | None = Field(
|
||||||
if not self.attn_implementation:
|
default=None,
|
||||||
return False
|
json_schema_extra={
|
||||||
return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
|
"description": "Whether the attention backend requires axolotl's flash_attn "
|
||||||
|
"monkeypatches. Computed automatically from attn_implementation."
|
||||||
@property
|
},
|
||||||
def attn_uses_flash_lib(self) -> bool:
|
)
|
||||||
"""True if the backend uses axolotl's flash_attn monkeypatches.
|
attn_needs_dtype_cast: bool | None = Field(
|
||||||
|
default=None,
|
||||||
Only for axolotl-managed FA setup (flash, s2). Hub kernels are
|
json_schema_extra={
|
||||||
HF-managed and don't need these patches.
|
"description": "Whether the attention backend needs embedding dtype cast to "
|
||||||
"""
|
"fp16/bf16. Computed automatically from attn_implementation."
|
||||||
return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
|
},
|
||||||
|
)
|
||||||
@property
|
|
||||||
def attn_needs_dtype_cast(self) -> bool:
|
|
||||||
"""True if attention needs embedding dtype cast to fp16/bf16.
|
|
||||||
|
|
||||||
Unknown backends (hub kernels) default to True (safe -- harmless
|
|
||||||
if unnecessary, but missing cast causes errors).
|
|
||||||
"""
|
|
||||||
if not self.attn_implementation:
|
|
||||||
return False
|
|
||||||
return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1463,6 +1453,17 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Compute capability flags from the final attn_implementation value
|
||||||
|
impl = data.get("attn_implementation")
|
||||||
|
if impl:
|
||||||
|
data["attn_supports_packing"] = impl not in _NON_PACKING_ATTN_IMPLS
|
||||||
|
data["attn_uses_flash_lib"] = impl in FLASH_ATTN_LIB_IMPLS
|
||||||
|
data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_ATTN_IMPLS
|
||||||
|
else:
|
||||||
|
data["attn_supports_packing"] = False
|
||||||
|
data["attn_uses_flash_lib"] = False
|
||||||
|
data["attn_needs_dtype_cast"] = False
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
Reference in New Issue
Block a user