diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index f32b7c12e..873965516 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -395,7 +395,16 @@ class PatchManager:
             patch_gemma4_fused_attn,
         )
 
-        patch_gemma4_fused_attn()
+        # Shared-KV side channel workaround for activation checkpointing (PR #3611).
+        fsdp_cfg = self.cfg.fsdp_config
+        needs_shared_kv_workaround = (not self.inference) and bool(
+            self.cfg.gradient_checkpointing
+            or self.cfg.activation_offloading
+            or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
+        )
+        patch_gemma4_fused_attn(
+            install_shared_kv_workaround=needs_shared_kv_workaround
+        )
 
     @staticmethod
     def _fix_nemotron_h_conversion_mapping():
diff --git a/src/axolotl/monkeypatch/models/gemma4/fused_attn.py b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
index 7cb5c6beb..2144b6c41 100644
--- a/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
+++ b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
@@ -6,15 +6,29 @@ kernels, eliminating intermediate tensor allocations from rotate_half / apply_ro
 
 Usage:
     from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
-    patch_gemma4_fused_attn()
+    # Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
+    patch_gemma4_fused_attn(install_shared_kv_workaround=True)
 """
 
-import logging
 from typing import Callable
 
 import torch
 
-logger = logging.getLogger(__name__)
+from axolotl.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+# Module-level dict used as a side channel for shared KV states, avoiding kwargs and TLS,
+# to prevent a memory leak when training with gradient checkpointing enabled (PR #3611).
+_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
+
+
+def _set_shared_kv_states(store):
+    _GEMMA4_SHARED_KV_STORE["store"] = store
+
+
+def _get_shared_kv_states():
+    return _GEMMA4_SHARED_KV_STORE["store"]
 
 
 def _make_fused_forward(original_forward):
@@ -30,7 +44,7 @@ def _make_fused_forward(original_forward):
         hidden_states: torch.Tensor,
         position_embeddings: torch.Tensor,
         attention_mask: torch.Tensor | None,
-        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
+        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
         past_key_values=None,
         **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -39,6 +53,10 @@ def _make_fused_forward(original_forward):
             eager_attention_forward,
         )
 
+        store = _get_shared_kv_states()
+        if store is not None:
+            shared_kv_states = store
+
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
         eps = self.config.rms_norm_eps
@@ -133,15 +151,44 @@ def _make_fused_forward(original_forward):
     return fused_forward
 
 
-def patch_gemma4_fused_attn():
+def _patch_decoder_layer_call():
+    """Strip `shared_kv_states` from decoder-layer kwargs and route it via the
+    module-level side channel so the checkpoint partial cannot pin it (PR #3611).
     """
-    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
+    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer
+
+    if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
+        return
+
+    original_call = Gemma4TextDecoderLayer.__call__
+
+    def patched_call(self, *args, **kwargs):
+        shared_kv = kwargs.pop("shared_kv_states", None)
+        # Overwrite unconditionally (including with None) so a previous step's
+        # dict cannot leak into a later call without shared_kv_states (PR #3611).
+        _set_shared_kv_states(shared_kv)
+        return original_call(self, *args, **kwargs)
+
+    Gemma4TextDecoderLayer.__call__ = patched_call
+    Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True
+
+
+def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
+    """
+    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
+    and optionally route `shared_kv_states` via a module-level side channel to
+    avoid a VRAM leak under activation checkpointing (PR #3611).
     """
     from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
 
     original_forward = Gemma4TextAttention.forward
     Gemma4TextAttention.forward = _make_fused_forward(original_forward)
 
+    if install_shared_kv_workaround:
+        _patch_decoder_layer_call()
+
     logger.info(
         "Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
     )
+    if install_shared_kv_workaround:
+        logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")
diff --git a/tests/monkeypatch/test_gemma4_fused_attn_patch.py b/tests/monkeypatch/test_gemma4_fused_attn_patch.py
new file mode 100644
index 000000000..75dbe472b
--- /dev/null
+++ b/tests/monkeypatch/test_gemma4_fused_attn_patch.py
@@ -0,0 +1,171 @@
+"""Unit tests for the Gemma4 fused-attention shared_kv_states routing patch."""
+
+import pytest
+
+gemma4_modeling = pytest.importorskip("transformers.models.gemma4.modeling_gemma4")
+
+
+@pytest.fixture
+def clean_decoder_layer_patch_slate():
+    """Save and restore Gemma4TextDecoderLayer.__call__ and the sentinel."""
+    from axolotl.monkeypatch.models.gemma4 import fused_attn
+
+    cls = gemma4_modeling.Gemma4TextDecoderLayer
+    original_call = cls.__call__
+    had_sentinel = getattr(cls, "_axolotl_shared_kv_patched", False)
+
+    if had_sentinel:
+        del cls._axolotl_shared_kv_patched
+
+    try:
+        yield cls, fused_attn
+    finally:
+        cls.__call__ = original_call
+        if had_sentinel:
+            cls._axolotl_shared_kv_patched = True
+        elif hasattr(cls, "_axolotl_shared_kv_patched"):
+            del cls._axolotl_shared_kv_patched
+        fused_attn._set_shared_kv_states(None)
+
+
+class TestPatchedDecoderLayerCall:
+    def test_pops_shared_kv_states_and_populates_store(
+        self, clean_decoder_layer_patch_slate
+    ):
+        cls, fused_attn = clean_decoder_layer_patch_slate
+
+        captured = {}
+
+        def spy(self, *args, **kwargs):
+            captured["args"] = args
+            captured["kwargs"] = dict(kwargs)
+            return "spy_return"
+
+        cls.__call__ = spy
+        fused_attn._patch_decoder_layer_call()
+
+        assert getattr(cls, "_axolotl_shared_kv_patched", False) is True
+        assert cls.__call__ is not spy
+
+        shared_kv = {"layer_0": ("k", "v")}
+        result = cls.__call__(
+            object(),
+            "positional_arg",
+            shared_kv_states=shared_kv,
+            other_kwarg="keep_me",
+        )
+
+        assert result == "spy_return"
+        assert captured["args"] == ("positional_arg",)
+        assert "shared_kv_states" not in captured["kwargs"]
+        assert captured["kwargs"] == {"other_kwarg": "keep_me"}
+        assert fused_attn._get_shared_kv_states() is shared_kv
+
+    def test_clears_store_when_kwarg_absent(self, clean_decoder_layer_patch_slate):
+        """Regression for commit 251021e1: a prior step's dict must not leak
+        into a later call that omits `shared_kv_states`."""
+        cls, fused_attn = clean_decoder_layer_patch_slate
+
+        def spy(self, *args, **kwargs):
+            return None
+
+        cls.__call__ = spy
+        fused_attn._patch_decoder_layer_call()
+
+        stale = {"stale_step": True}
+        fused_attn._set_shared_kv_states(stale)
+        assert fused_attn._get_shared_kv_states() is stale
+
+        cls.__call__(object())
+
+        assert fused_attn._get_shared_kv_states() is None
+
+    def test_store_visible_across_threads(self):
+        """Regression for commit e3669b2c: the store must be readable from
+        threads other than the one that set it. `threading.local()` failed
+        this invariant, crashing with "'NoneType' object is not subscriptable"
+        on MoE Gemma4 variants when autograd worker threads ran backward
+        recompute under HF-Trainer gradient_checkpointing."""
+        import threading
+
+        from axolotl.monkeypatch.models.gemma4 import fused_attn
+
+        sentinel = {"layer_0": ("k", "v")}
+        try:
+            fused_attn._set_shared_kv_states(sentinel)
+
+            seen = {}
+
+            def worker():
+                seen["value"] = fused_attn._get_shared_kv_states()
+
+            t = threading.Thread(target=worker)
+            t.start()
+            t.join()
+
+            assert seen["value"] is sentinel
+        finally:
+            fused_attn._set_shared_kv_states(None)
+
+
+@pytest.fixture
+def clean_entry_point_patch_slate():
+    """Save and restore Gemma4TextAttention.forward and Gemma4TextDecoderLayer.__call__."""
+    from axolotl.monkeypatch.models.gemma4 import fused_attn
+
+    decoder_cls = gemma4_modeling.Gemma4TextDecoderLayer
+    attn_cls = gemma4_modeling.Gemma4TextAttention
+
+    original_call = decoder_cls.__call__
+    original_forward = attn_cls.forward
+    had_sentinel = getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
+
+    if had_sentinel:
+        del decoder_cls._axolotl_shared_kv_patched
+
+    try:
+        yield decoder_cls, attn_cls, original_call, original_forward, fused_attn
+    finally:
+        decoder_cls.__call__ = original_call
+        attn_cls.forward = original_forward
+        if had_sentinel:
+            decoder_cls._axolotl_shared_kv_patched = True
+        elif hasattr(decoder_cls, "_axolotl_shared_kv_patched"):
+            del decoder_cls._axolotl_shared_kv_patched
+        fused_attn._set_shared_kv_states(None)
+
+
+class TestPatchGemma4FusedAttnEntryPoint:
+    def test_default_flag_swaps_only_attention_forward(
+        self, clean_entry_point_patch_slate
+    ):
+        (
+            decoder_cls,
+            attn_cls,
+            original_call,
+            original_forward,
+            fused_attn,
+        ) = clean_entry_point_patch_slate
+
+        fused_attn.patch_gemma4_fused_attn()
+
+        assert attn_cls.forward is not original_forward
+        assert decoder_cls.__call__ is original_call
+        assert not getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
+
+    def test_workaround_flag_installs_decoder_layer_patch(
+        self, clean_entry_point_patch_slate
+    ):
+        (
+            decoder_cls,
+            attn_cls,
+            original_call,
+            original_forward,
+            fused_attn,
+        ) = clean_entry_point_patch_slate
+
+        fused_attn.patch_gemma4_fused_attn(install_shared_kv_workaround=True)
+
+        assert attn_cls.forward is not original_forward
+        assert decoder_cls.__call__ is not original_call
+        assert getattr(decoder_cls, "_axolotl_shared_kv_patched", False) is True