diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml
index 7fe4dd433..ca8e8e043 100644
--- a/examples/devstral/devstral-small-qlora.yml
+++ b/examples/devstral/devstral-small-qlora.yml
@@ -52,6 +52,7 @@ gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+flex_attention: true
+scaling_softmax: true
 
 loss_watchdog_threshold: 5.0
 loss_watchdog_patience: 3
diff --git a/examples/ministral3/ministral3-3b-qlora.yaml b/examples/ministral3/ministral3-3b-qlora.yaml
index a31545ab2..b369c9d41 100644
--- a/examples/ministral3/ministral3-3b-qlora.yaml
+++ b/examples/ministral3/ministral3-3b-qlora.yaml
@@ -59,6 +59,7 @@ gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+flex_attention: true
+scaling_softmax: true
 
 warmup_ratio: 0.1
 evals_per_epoch: 1
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 64f363bb1..b7a53c4d5 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -138,6 +138,7 @@ class PatchManager:
         self._apply_llama_flash_attn_patches(model)
         self._apply_unsloth_patches(model)
         self._apply_lora_kernel_patch(model)
+        self._apply_scaling_softmax_patch(model)
 
     def _apply_flash_attention_patches(self):
         """Apply patches related to Flash Attention."""
@@ -560,3 +561,16 @@ class PatchManager:
             )
 
             patch_apertus_xielu_activation()
+
+    def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
+        """Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
+        if self.cfg.scaling_softmax:
+            from axolotl.monkeypatch.scaled_softmax_attn import (
+                patch_scaled_softmax_attention,
+            )
+
+            patch_scaled_softmax_attention(
+                scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,
+                bias=self.cfg.scaling_softmax_bias or 0.0,
+                model=model,
+            )
diff --git a/src/axolotl/monkeypatch/scaled_softmax_attn.py b/src/axolotl/monkeypatch/scaled_softmax_attn.py
new file mode 100644
index 000000000..bb2ebb8b8
--- /dev/null
+++ b/src/axolotl/monkeypatch/scaled_softmax_attn.py
@@ -0,0 +1,141 @@
+"""
+Scaled Softmax (SSMax) attention patch using FlexAttention.
+SSMax: softmax(scores * (s * log(n) + b)) where n is the position index
+Ref: https://arxiv.org/abs/2501.19399
+"""
+
+import torch
+from transformers import PreTrainedModel
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+try:
+    from torch.nn.attention.flex_attention import BlockMask
+    from transformers.integrations.flex_attention import (
+        compile_friendly_flex_attention,
+        repeat_kv,
+    )
+
+    FLEX_ATTENTION_AVAILABLE = True
+except ImportError:
+    FLEX_ATTENTION_AVAILABLE = False
+    BlockMask = None
+
+_ssmax_config = {}
+
+
+def patch_scaled_softmax_attention(
+    scaling_factor_init: float = 0.43, bias: float = 0.0, model: PreTrainedModel = None
+):
+    """Patch attention to apply SSMax via FlexAttention score_mod."""
+    global _ssmax_config
+
+    if not FLEX_ATTENTION_AVAILABLE:
+        raise RuntimeError("SSMax requires FlexAttention.")
+
+    _ssmax_config["ssmax_s"] = scaling_factor_init
+    _ssmax_config["ssmax_b"] = bias
+
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    if "flex_attention" in ALL_ATTENTION_FUNCTIONS:
+        _ssmax_config["original_flex_fn"] = ALL_ATTENTION_FUNCTIONS["flex_attention"]
+        ALL_ATTENTION_FUNCTIONS["flex_attention"] = ssmax_flex_attention_forward
+        LOG.info(
+            f"Patched flex_attention with SSMax (s={scaling_factor_init}, b={bias})"
+        )
+    else:
+        LOG.warning("flex_attention not found. Ensure flex_attention: true is set.")
+
+
+def ssmax_flex_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask,
+    scaling: float | None = None,
+    softcap: float | None = None,
+    **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """FlexAttention forward with SSMax: score * (s * log(n) + b)."""
+
+    if kwargs.get("dropout", 0.0) > 0:
+        raise ValueError("flex_attention does not support dropout")
+
+    ssmax_s = _ssmax_config.get("ssmax_s", 0.43)
+    ssmax_b = _ssmax_config.get("ssmax_b", 0.0)
+
+    position_ids = kwargs.get("position_ids", None)
+    position_ids_flat = position_ids.view(-1) if position_ids is not None else None
+
+    block_mask = attention_mask if isinstance(attention_mask, BlockMask) else None
+    score_mask = None if block_mask else attention_mask
+
+    if score_mask is not None:
+        score_mask = score_mask[:, :, :, : key.shape[-2]]
+
+    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+        """
+        Apply SSMax scaling: score * (s * log(n) + b)
+        where n is the relative position within each packed sequence.
+        """
+        if position_ids_flat is not None:
+            relative_pos = position_ids_flat[q_idx]
+            n = (relative_pos + 1).float()
+        else:
+            n = (q_idx + 1).float()
+
+        n = torch.clamp(n, min=2.0)
+
+        ssmax_scale = ssmax_s * torch.log(n) + ssmax_b
+        score = score * ssmax_scale
+
+        if softcap is not None:
+            score = softcap * torch.tanh(score / softcap)
+
+        if score_mask is not None:
+            score = score + score_mask[batch_idx][0][q_idx][kv_idx]
+
+        return score
+
+    enable_gqa = True
+    if (query.shape[1] & (query.shape[1] - 1)) != 0:
+        key = repeat_kv(key, query.shape[1] // key.shape[1])
+        value = repeat_kv(value, query.shape[1] // value.shape[1])
+        enable_gqa = False
+
+    return_lse = query.device.type != "cpu"
+    flex_output = compile_friendly_flex_attention(
+        query,
+        key,
+        value,
+        score_mod=score_mod,
+        block_mask=block_mask,
+        enable_gqa=enable_gqa,
+        scale=scaling,
+        kernel_options=kwargs.get("kernel_options"),
+        return_lse=return_lse,
+        training=module.training,
+    )
+
+    if return_lse:
+        attention_output, lse = flex_output
+        lse = lse.to(value.dtype)
+    else:
+        attention_output, lse = flex_output, None
+
+    return attention_output.transpose(1, 2).contiguous(), lse
+
+
+def unpatch_scaled_softmax_attention():
+    """Restore the original FlexAttention function."""
+    global _ssmax_config
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    if "original_flex_fn" in _ssmax_config:
+        ALL_ATTENTION_FUNCTIONS["flex_attention"] = _ssmax_config["original_flex_fn"]
+        _ssmax_config.clear()
+        LOG.info("Unpatched flex_attention, restored original")
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 4ef1aff3a..da21df7aa 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -619,6 +619,25 @@ class AxolotlInputConfig(
         },
     )
 
+    scaling_softmax: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use Scaled Softmax (SSMax) attention. Ref: https://arxiv.org/abs/2501.19399"
+        },
+    )
+    scaling_softmax_factor: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Scaling factor for SSMax attention. Default is 0.43"
+        },
+    )
+    scaling_softmax_bias: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Bias for SSMax attention. Default is 0.0. Note: The paper recommends bias=0 for better length generalization."
+        },
+    )
+
     unsloth_cross_entropy_loss: bool | None = None
     unsloth_lora_mlp: bool | None = None
     unsloth_lora_qkv: bool | None = None
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index cb834a3bf..bf054d353 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -201,6 +201,16 @@ class AttentionValidationMixin:
             )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_scaling_softmax_requires_flex(cls, data):
+        if data.get("scaling_softmax") and not data.get("flex_attention"):
+            raise ValueError(
+                "scaling_softmax requires flex_attention: true\n"
+                "Add 'flex_attention: true' to your config file.\n"
+            )
+        return data
+
 
 class TrainingValidationMixin:
     """Validation methods related to training configuration."""
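For reference, a minimal config fragment exercising the new options. Option names are taken from the schema additions above; flex_attention is required by the new validator, and the last two lines are optional since they only restate the defaults that PatchManager applies (0.43 and 0.0):

    flex_attention: true            # SSMax patches the flex_attention path, so this must be enabled
    scaling_softmax: true
    scaling_softmax_factor: 0.43    # s in softmax(scores * (s * log(n) + b))
    scaling_softmax_bias: 0.0       # b; the paper recommends 0 for length generalization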