fix: [gemma4] fix VRAM leak in hybrid FA2+SDPA (hybrid attention) path under activation checkpointing (#3611)

* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing

Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.

HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.
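
For illustration, a minimal standalone sketch of this capture mechanism (not
HF's actual code; the decoder_layer stub and tensor shapes are invented):

    from functools import partial

    import torch


    def decoder_layer(hidden_states, shared_kv_states=None):
        # Stand-in for a decoder layer; only the kwarg capture matters here.
        return hidden_states


    shared_kv_states = {0: (torch.randn(1, 8, 128, 64), torch.randn(1, 8, 128, 64))}

    # Analogue of GradientCheckpointingLayer.__call__ forming
    # partial(super().__call__, **kwargs) before handing it to the checkpoint impl.
    ckpt_fn = partial(decoder_layer, shared_kv_states=shared_kv_states)

    # Even after the caller drops its own reference, the K/V tensors stay
    # reachable through the partial until the checkpoint machinery releases it
    # after backward.
    del shared_kv_states
    print(ckpt_fn.keywords["shared_kv_states"][0].shape)  # torch.Size([1, 8, 128, 64])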

Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.

Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.

Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
  * threading.local() store with _get/_set_shared_kv_states helpers
  * _patch_decoder_layer_call(): monkeypatches
    Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
    and stash it in TLS before delegating to GradientCheckpointingLayer.
    The partial formed downstream no longer references the dict.
  * fused_forward reads TLS first, falls back to kwarg for callers that
    bypass the patched __call__ (e.g. direct attention invocation).
  * wired into patch_gemma4_fused_attn; idempotent via a sentinel.

TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.

The leak was introduced by PR #3598 (commit b8358aa5).

* CodeRabbit comment: gemma4: clear TLS unconditionally in the patched decoder-layer __call__

  Overwrite the thread-local shared_kv_states store on every invocation
  (including with None) instead of only when the kwarg is present.

  The previous conditional write left stale dicts in TLS on any path that
  reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
  kwarg — e.g. generation, eval hooks, or future HF refactors that make
  the kwarg optional. fused_forward would then silently consume a prior
  step's K/V dict instead of falling back to its own kwarg path.

  Unconditional write makes the invariant in the surrounding comment
  ("TLS is overwritten on each new step's first decoder-layer call, so
  the previous step's dict is released promptly") actually hold.
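
  A condensed sketch of the difference (a hypothetical simplification, not the
  literal patch; _store and the two function names are invented):

    _store = {"shared_kv_states": None}


    def patched_call_conditional(**kwargs):
        # Old behaviour: a call without the kwarg leaves a stale dict in the store.
        if "shared_kv_states" in kwargs:
            _store["shared_kv_states"] = kwargs.pop("shared_kv_states")


    def patched_call_unconditional(**kwargs):
        # New behaviour: every call overwrites the slot, None included.
        _store["shared_kv_states"] = kwargs.pop("shared_kv_states", None)


    _store["shared_kv_states"] = {"stale": True}
    patched_call_conditional()               # stale dict survives
    assert _store["shared_kv_states"] == {"stale": True}
    patched_call_unconditional()             # slot is cleared
    assert _store["shared_kv_states"] is None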

  No behavior change for the training happy path, which always passes
  the kwarg. Addresses the CodeRabbit review on PR #3611.

* fix: swap threading.local() for a module-level store so autograd worker threads see shared_kv_states during backward recompute

The previous commits fixed the memory leak on the 31B model but caused a TypeError with MoE Gemma4 variants; this commit fixes that:

PR #3611's TLS variant only works when recompute runs on the same thread
  that set TLS during forward. PyTorch's C++ autograd engine
  (_engine_run_backward) spawns per-device worker threads to dispatch
  backward, and HF-Trainer gradient_checkpointing (stock
  torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
  unpack_hook -> recompute_fn on those worker threads. TLS set on the main
  thread during forward is invisible there, so _get_shared_kv_states()
  returns None and the consumer-layer lookup crashes with
  "'NoneType' object is not subscriptable" at
  fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).

  A plain module-level dict is visible to all threads in the process.
  Lifecycle is identical: the slot is overwritten each forward, releasing
  the previous step's dict and allowing its K/V tensors to be GC'd, so
  the original VRAM-leak fix still holds under FSDP2 AC too.
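
  A minimal standalone repro of the thread-visibility difference (illustrative
  only; the names are invented):

    import threading

    tls = threading.local()
    tls.store = {"k": "v"}          # what the old TLS variant set during forward
    shared = {"store": {"k": "v"}}  # what the module-level store holds instead

    seen = {}


    def worker():
        # Analogue of backward recompute running on an autograd worker thread.
        seen["tls"] = getattr(tls, "store", None)
        seen["module_level"] = shared["store"]


    t = threading.Thread(target=worker)
    t.start()
    t.join()

    print(seen["tls"])           # None -> the old code crashed on seen["tls"][idx]
    print(seen["module_level"])  # {'k': 'v'}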

* scope gemma4 shared_kv_states side channel to checkpointed training

Update PR #3611 with a gate so the workaround is only installed for checkpointed training, avoiding regressions across async flows.

Added unit tests for the kwargs pop, the store-clear regression, and flag gating. Condensed verbose comments.

* add gemma4 cross-thread visibility test for shared_kv_states store

Additional regression test for MoE Gemma4 variants: asserts the module-level store is readable from threads other than the one that set it, in response to the previously observed 'NoneType' error.

* fix logger (use axolotl's get_logger instead of logging.getLogger)

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Author: thad0ctor
Date: 2026-04-21 14:49:58 -07:00
Committed by: GitHub
Parent: 9de5b76336
Commit: e562e149ce
3 changed files with 234 additions and 7 deletions


@@ -395,7 +395,16 @@ class PatchManager:
                 patch_gemma4_fused_attn,
             )
 
-            patch_gemma4_fused_attn()
+            # Shared-KV side channel when activation checkpointing (PR #3611).
+            fsdp_cfg = self.cfg.fsdp_config
+            needs_shared_kv_workaround = (not self.inference) and bool(
+                self.cfg.gradient_checkpointing
+                or self.cfg.activation_offloading
+                or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
+            )
+            patch_gemma4_fused_attn(
+                install_shared_kv_workaround=needs_shared_kv_workaround
+            )
 
     @staticmethod
     def _fix_nemotron_h_conversion_mapping():


@@ -6,15 +6,29 @@ kernels, eliminating intermediate tensor allocations from rotate_half / apply_ro
 Usage:
     from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
     patch_gemma4_fused_attn()
+    # Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
+    patch_gemma4_fused_attn(install_shared_kv_workaround=True)
 """
 
-import logging
-
 from typing import Callable
 
 import torch
 
-logger = logging.getLogger(__name__)
+from axolotl.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+# Module-level dict used as a side channel for shared KV states avoiding kwarg and TLS
+# to prevent memory leak on gradient checkpoint enabled training (PR #3611)
+_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
+
+
+def _set_shared_kv_states(store):
+    _GEMMA4_SHARED_KV_STORE["store"] = store
+
+
+def _get_shared_kv_states():
+    return _GEMMA4_SHARED_KV_STORE["store"]
 
 
 def _make_fused_forward(original_forward):
@@ -30,7 +44,7 @@ def _make_fused_forward(original_forward):
         hidden_states: torch.Tensor,
         position_embeddings: torch.Tensor,
         attention_mask: torch.Tensor | None,
-        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
+        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
         past_key_values=None,
         **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -39,6 +53,10 @@ def _make_fused_forward(original_forward):
             eager_attention_forward,
         )
 
+        store = _get_shared_kv_states()
+        if store is not None:
+            shared_kv_states = store
+
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
         eps = self.config.rms_norm_eps
@@ -133,15 +151,44 @@ def _make_fused_forward(original_forward):
     return fused_forward
 
 
-def patch_gemma4_fused_attn():
-    """
-    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
+def _patch_decoder_layer_call():
+    """Strip `shared_kv_states` from decoder-layer kwargs and route via the
+    module-level side channel so the checkpoint partial cannot pin it (PR #3611).
+    """
+    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer
+
+    if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
+        return
+
+    original_call = Gemma4TextDecoderLayer.__call__
+
+    def patched_call(self, *args, **kwargs):
+        shared_kv = kwargs.pop("shared_kv_states", None)
+        # Overwrite unconditionally (including with None) so a previous step's
+        # dict cannot leak into a later call without shared_kv_states (PR #3611).
+        _set_shared_kv_states(shared_kv)
+        return original_call(self, *args, **kwargs)
+
+    Gemma4TextDecoderLayer.__call__ = patched_call
+    Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True
+
+
+def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
+    """
+    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
+    and optionally route `shared_kv_states` via a module-level side channel to
+    avoid a VRAM leak under activation checkpointing (PR #3611).
     """
     from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
 
     original_forward = Gemma4TextAttention.forward
     Gemma4TextAttention.forward = _make_fused_forward(original_forward)
 
+    if install_shared_kv_workaround:
+        _patch_decoder_layer_call()
+
     logger.info(
         "Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
     )
+    if install_shared_kv_workaround:
+        logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")


@@ -0,0 +1,171 @@
"""Unit tests for the Gemma4 fused-attention shared_kv_states routing patch."""

import pytest

gemma4_modeling = pytest.importorskip("transformers.models.gemma4.modeling_gemma4")


@pytest.fixture
def clean_decoder_layer_patch_slate():
    """Save and restore Gemma4TextDecoderLayer.__call__ and the sentinel."""
    from axolotl.monkeypatch.models.gemma4 import fused_attn

    cls = gemma4_modeling.Gemma4TextDecoderLayer
    original_call = cls.__call__
    had_sentinel = getattr(cls, "_axolotl_shared_kv_patched", False)
    if had_sentinel:
        del cls._axolotl_shared_kv_patched
    try:
        yield cls, fused_attn
    finally:
        cls.__call__ = original_call
        if had_sentinel:
            cls._axolotl_shared_kv_patched = True
        elif hasattr(cls, "_axolotl_shared_kv_patched"):
            del cls._axolotl_shared_kv_patched
        fused_attn._set_shared_kv_states(None)


class TestPatchedDecoderLayerCall:
    def test_pops_shared_kv_states_and_populates_store(
        self, clean_decoder_layer_patch_slate
    ):
        cls, fused_attn = clean_decoder_layer_patch_slate
        captured = {}

        def spy(self, *args, **kwargs):
            captured["args"] = args
            captured["kwargs"] = dict(kwargs)
            return "spy_return"

        cls.__call__ = spy
        fused_attn._patch_decoder_layer_call()

        assert getattr(cls, "_axolotl_shared_kv_patched", False) is True
        assert cls.__call__ is not spy

        shared_kv = {"layer_0": ("k", "v")}
        result = cls.__call__(
            object(),
            "positional_arg",
            shared_kv_states=shared_kv,
            other_kwarg="keep_me",
        )

        assert result == "spy_return"
        assert captured["args"] == ("positional_arg",)
        assert "shared_kv_states" not in captured["kwargs"]
        assert captured["kwargs"] == {"other_kwarg": "keep_me"}
        assert fused_attn._get_shared_kv_states() is shared_kv

    def test_clears_store_when_kwarg_absent(self, clean_decoder_layer_patch_slate):
        """Regression for commit 251021e1: a prior step's dict must not leak
        into a later call that omits `shared_kv_states`."""
        cls, fused_attn = clean_decoder_layer_patch_slate

        def spy(self, *args, **kwargs):
            return None

        cls.__call__ = spy
        fused_attn._patch_decoder_layer_call()

        stale = {"stale_step": True}
        fused_attn._set_shared_kv_states(stale)
        assert fused_attn._get_shared_kv_states() is stale

        cls.__call__(object())
        assert fused_attn._get_shared_kv_states() is None

    def test_store_visible_across_threads(self):
        """Regression for commit e3669b2c: the store must be readable from
        threads other than the one that set it. `threading.local()` failed
        this invariant, crashing with "'NoneType' object is not subscriptable"
        on MoE Gemma4 variants when autograd worker threads ran backward
        recompute under HF-Trainer gradient_checkpointing."""
        import threading

        from axolotl.monkeypatch.models.gemma4 import fused_attn

        sentinel = {"layer_0": ("k", "v")}
        try:
            fused_attn._set_shared_kv_states(sentinel)
            seen = {}

            def worker():
                seen["value"] = fused_attn._get_shared_kv_states()

            t = threading.Thread(target=worker)
            t.start()
            t.join()
            assert seen["value"] is sentinel
        finally:
            fused_attn._set_shared_kv_states(None)


@pytest.fixture
def clean_entry_point_patch_slate():
    """Save and restore Gemma4TextAttention.forward and Gemma4TextDecoderLayer.__call__."""
    from axolotl.monkeypatch.models.gemma4 import fused_attn

    decoder_cls = gemma4_modeling.Gemma4TextDecoderLayer
    attn_cls = gemma4_modeling.Gemma4TextAttention
    original_call = decoder_cls.__call__
    original_forward = attn_cls.forward
    had_sentinel = getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
    if had_sentinel:
        del decoder_cls._axolotl_shared_kv_patched
    try:
        yield decoder_cls, attn_cls, original_call, original_forward, fused_attn
    finally:
        decoder_cls.__call__ = original_call
        attn_cls.forward = original_forward
        if had_sentinel:
            decoder_cls._axolotl_shared_kv_patched = True
        elif hasattr(decoder_cls, "_axolotl_shared_kv_patched"):
            del decoder_cls._axolotl_shared_kv_patched
        fused_attn._set_shared_kv_states(None)


class TestPatchGemma4FusedAttnEntryPoint:
    def test_default_flag_swaps_only_attention_forward(
        self, clean_entry_point_patch_slate
    ):
        (
            decoder_cls,
            attn_cls,
            original_call,
            original_forward,
            fused_attn,
        ) = clean_entry_point_patch_slate

        fused_attn.patch_gemma4_fused_attn()

        assert attn_cls.forward is not original_forward
        assert decoder_cls.__call__ is original_call
        assert not getattr(decoder_cls, "_axolotl_shared_kv_patched", False)

    def test_workaround_flag_installs_decoder_layer_patch(
        self, clean_entry_point_patch_slate
    ):
        (
            decoder_cls,
            attn_cls,
            original_call,
            original_forward,
            fused_attn,
        ) = clean_entry_point_patch_slate

        fused_attn.patch_gemma4_fused_attn(install_shared_kv_workaround=True)

        assert attn_cls.forward is not original_forward
        assert decoder_cls.__call__ is not original_call
        assert getattr(decoder_cls, "_axolotl_shared_kv_patched", False) is True