feat: scaled softmax support (#3338)

* scaled softmax

* comment

* lint

* remove eager

* validation for flash

* lint

* val improve + neat

* fix correct softmax scale val(learned)

* learned scale val for ssmax

* lint

* fix model_type removal

* sdpa_atten

* test fix + lint

* test fix

* sdpa val removal

* flex fix

* main flash

* lint

* flex attn

* lint comment

* fix score_mod

* Update src/axolotl/utils/schemas/validation.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
VED
2026-01-13 13:03:11 +05:30
committed by GitHub
parent 3e0bbd33ec
commit 258ce8d4fa
6 changed files with 186 additions and 0 deletions

View File

@@ -52,6 +52,7 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
scaling_softmax: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

View File

@@ -59,6 +59,7 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
scaling_softmax: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -138,6 +138,7 @@ class PatchManager:
        self._apply_llama_flash_attn_patches(model)
        self._apply_unsloth_patches(model)
        self._apply_lora_kernel_patch(model)
        self._apply_scaling_softmax_patch(model)

    def _apply_flash_attention_patches(self):
        """Apply patches related to Flash Attention."""
@@ -560,3 +561,16 @@ class PatchManager:
            )
            patch_apertus_xielu_activation()

    def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
        """Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
        if self.cfg.scaling_softmax:
            from axolotl.monkeypatch.scaled_softmax_attn import (
                patch_scaled_softmax_attention,
            )

            patch_scaled_softmax_attention(
                scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,
                bias=self.cfg.scaling_softmax_bias or 0.0,
                model=model,
            )

View File

@@ -0,0 +1,141 @@
"""
Scaled Softmax (SSMax) attention patch using FlexAttention.
SSMax: softmax(scores * s * log(n) + b) where n is the position index
Ref: https://arxiv.org/abs/2501.19399
"""
import torch
from transformers import PreTrainedModel
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from torch.nn.attention.flex_attention import BlockMask
from transformers.integrations.flex_attention import (
compile_friendly_flex_attention,
repeat_kv,
)
FLEX_ATTENTION_AVAILABLE = True
except ImportError:
FLEX_ATTENTION_AVAILABLE = False
BlockMask = None
_ssmax_config = {}
def patch_scaled_softmax_attention(
scaling_factor_init: float = 0.43, bias: float = 0.0, model: PreTrainedModel = None
):
"""Patch attention to apply SSMax via FlexAttention score_mod."""
global _ssmax_config
if not FLEX_ATTENTION_AVAILABLE:
raise RuntimeError("SSMax requires FlexAttention.")
_ssmax_config["ssmax_s"] = scaling_factor_init
_ssmax_config["ssmax_b"] = bias
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
if "flex_attention" in ALL_ATTENTION_FUNCTIONS:
_ssmax_config["original_flex_fn"] = ALL_ATTENTION_FUNCTIONS["flex_attention"]
ALL_ATTENTION_FUNCTIONS["flex_attention"] = ssmax_flex_attention_forward
LOG.info(
f"Patched flex_attention with SSMax (s={scaling_factor_init}, b={bias})"
)
else:
LOG.warning("flex_attention not found. Ensure flex_attention: true is set.")
def ssmax_flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask,
scaling: float | None = None,
softcap: float | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""FlexAttention forward with SSMax: score * (s * log(n) + b)."""
if kwargs.get("dropout", 0.0) > 0:
raise ValueError("flex_attention does not support dropout")
ssmax_s = _ssmax_config.get("ssmax_s", 0.43)
ssmax_b = _ssmax_config.get("ssmax_b", 0.0)
position_ids = kwargs.get("position_ids", None)
position_ids_flat = position_ids.view(-1) if position_ids is not None else None
block_mask = attention_mask if isinstance(attention_mask, BlockMask) else None
score_mask = None if block_mask else attention_mask
if score_mask is not None:
score_mask = score_mask[:, :, :, : key.shape[-2]]
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
"""
Apply SSMax scaling: score * (s * log(n) + b)
where n is the relative position within each packed sequence.
"""
if position_ids_flat is not None:
relative_pos = position_ids_flat[q_idx]
n = (relative_pos + 1).float()
else:
n = (q_idx + 1).float()
n = torch.clamp(n, min=2.0)
ssmax_scale = ssmax_s * torch.log(n) + ssmax_b
score = score * ssmax_scale
if softcap is not None:
score = softcap * torch.tanh(score / softcap)
if score_mask is not None:
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
return score
enable_gqa = True
if (query.shape[1] & (query.shape[1] - 1)) != 0:
key = repeat_kv(key, query.shape[1] // key.shape[1])
value = repeat_kv(value, query.shape[1] // value.shape[1])
enable_gqa = False
return_lse = query.device.type != "cpu"
flex_output = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kwargs.get("kernel_options"),
return_lse=return_lse,
training=module.training,
)
if return_lse:
attention_output, lse = flex_output
lse = lse.to(value.dtype)
else:
attention_output, lse = flex_output, None
return attention_output.transpose(1, 2).contiguous(), lse
def unpatch_scaled_softmax_attention():
"""Restore the original FlexAttention function."""
global _ssmax_config
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
if "original_flex_fn" in _ssmax_config:
ALL_ATTENTION_FUNCTIONS["flex_attention"] = _ssmax_config["original_flex_fn"]
_ssmax_config.clear()
LOG.info("Unpatched flex_attention, restored original")

View File

@@ -619,6 +619,25 @@ class AxolotlInputConfig(
        },
    )
    scaling_softmax: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use Scaled Softmax (SSMax) attention. Ref: https://arxiv.org/abs/2501.19399"
        },
    )
    scaling_softmax_factor: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Scaling factor for SSMax attention. Default is 0.43"
        },
    )
    scaling_softmax_bias: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Bias for SSMax attention. Default is 0.0. Note: The paper recommends bias=0 for better length generalization."
        },
    )
    unsloth_cross_entropy_loss: bool | None = None
    unsloth_lora_mlp: bool | None = None
    unsloth_lora_qkv: bool | None = None

View File

@@ -201,6 +201,16 @@ class AttentionValidationMixin:
        )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_scaling_softmax_requires_flex(cls, data):
        if data.get("scaling_softmax") and not data.get("flex_attention"):
            raise ValueError(
                "scaling_softmax requires flex_attention: true\n"
                "Add 'flex_attention: true' to your config file.\n"
            )
        return data


class TrainingValidationMixin:
    """Validation methods related to training configuration."""