fix: [gemma4] fix VRAM leak in hybrid FA2+SDPA (hybrid attentiuon) path under activation check… (#3611)
* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing
Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.
HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.
Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.
Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.
Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
* threading.local() store with _get/_set_shared_kv_states helpers
* _patch_decoder_layer_call(): monkeypatches
Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
and stash it in TLS before delegating to GradientCheckpointingLayer.
The partial formed downstream no longer references the dict.
* fused_forward reads TLS first, falls back to kwarg for callers that
bypass the patched __call__ (e.g. direct attention invocation).
* wired into patch_gemma4_fused_attn; idempotent via a sentinel.
TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.
Introduced by PR #3598 (commit b8358aa5).
* Coderabbit comment: gemma4: clear TLS unconditionally in decoder-layer patched __call__
Overwrite the thread-local shared_kv_states store on every invocation
(including with None) instead of only when the kwarg is present.
The previous conditional write left stale dicts in TLS on any path that
reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
kwarg — e.g. generation, eval hooks, or future HF refactors that make
the kwarg optional. fused_forward would then silently consume a prior
step's K/V dict instead of falling back to its own kwarg path.
Unconditional write makes the invariant in the surrounding comment
("TLS is overwritten on each new step's first decoder-layer call, so
the previous step's dict is released promptly") actually hold.
No behavior change for the training happy path, which always passes
the kwarg. Addresses CodeRabbit review on PR #3611
* fix: swap threading.local() for module-level store so autograd worker threads see shared_kv_states during backward recompute
Previous commits fixed memory leak on 31B but caused type error with MOE Gemma4 variants - this fixes that:
PR 3611's TLS variant only works when recompute runs on the same thread
that set TLS during forward. PyTorch's C++ autograd engine
(_engine_run_backward) spawns per-device worker threads to dispatch
backward, and HF-Trainer gradient_checkpointing (stock
torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
unpack_hook -> recompute_fn on those worker threads. TLS set on the main
thread during forward is invisible there, so _get_shared_kv_states()
returns None and the consumer-layer lookup crashes with
"'NoneType' object is not subscriptable" at
fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).
A plain module-level dict is visible to all threads in the process.
Lifecycle is identical: the slot is overwritten each forward, releasing
the previous step's dict and allowing its K/V tensors to be GC'd, so
the original VRAM-leak fix still holds under FSDP2 AC too.
* scope gemma4 shared_kv_states side channel to checkpointed training
Update PR #3611 with gate for checkpointed training to avoid regressions across async flows.
Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments
* add gemma4 cross-thread visibility test for shared_kv_states store
Additional regression test for MoE gemma4 variants - asserts the module-level store is readable from threads other than the one that set it in response to previously observed 'NoneType' error
* fix logger
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -395,7 +395,16 @@ class PatchManager:
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
patch_gemma4_fused_attn()
|
||||
# Shared-KV side channel when activation checkpointing (PR #3611).
|
||||
fsdp_cfg = self.cfg.fsdp_config
|
||||
needs_shared_kv_workaround = (not self.inference) and bool(
|
||||
self.cfg.gradient_checkpointing
|
||||
or self.cfg.activation_offloading
|
||||
or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
|
||||
)
|
||||
patch_gemma4_fused_attn(
|
||||
install_shared_kv_workaround=needs_shared_kv_workaround
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fix_nemotron_h_conversion_mapping():
|
||||
|
||||
@@ -6,15 +6,29 @@ kernels, eliminating intermediate tensor allocations from rotate_half / apply_ro
|
||||
|
||||
Usage:
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
||||
patch_gemma4_fused_attn()
|
||||
# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
|
||||
patch_gemma4_fused_attn(install_shared_kv_workaround=True)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Module-level dict used as a side channel for shared KV states avoiding kwarg and TLS
|
||||
# to prevent memory leak on gradient checkpoint enabled training (PR #3611)
|
||||
_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
|
||||
|
||||
|
||||
def _set_shared_kv_states(store):
|
||||
_GEMMA4_SHARED_KV_STORE["store"] = store
|
||||
|
||||
|
||||
def _get_shared_kv_states():
|
||||
return _GEMMA4_SHARED_KV_STORE["store"]
|
||||
|
||||
|
||||
def _make_fused_forward(original_forward):
|
||||
@@ -30,7 +44,7 @@ def _make_fused_forward(original_forward):
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
|
||||
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
@@ -39,6 +53,10 @@ def _make_fused_forward(original_forward):
|
||||
eager_attention_forward,
|
||||
)
|
||||
|
||||
store = _get_shared_kv_states()
|
||||
if store is not None:
|
||||
shared_kv_states = store
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
eps = self.config.rms_norm_eps
|
||||
@@ -133,15 +151,44 @@ def _make_fused_forward(original_forward):
|
||||
return fused_forward
|
||||
|
||||
|
||||
def patch_gemma4_fused_attn():
|
||||
def _patch_decoder_layer_call():
|
||||
"""Strip `shared_kv_states` from decoder-layer kwargs and route via the
|
||||
module-level side channel so the checkpoint partial cannot pin it (PR #3611).
|
||||
"""
|
||||
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer
|
||||
|
||||
if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
|
||||
return
|
||||
|
||||
original_call = Gemma4TextDecoderLayer.__call__
|
||||
|
||||
def patched_call(self, *args, **kwargs):
|
||||
shared_kv = kwargs.pop("shared_kv_states", None)
|
||||
# Overwrite unconditionally (including with None) so a previous step's
|
||||
# dict cannot leak into a later call without shared_kv_states (PR #3611).
|
||||
_set_shared_kv_states(shared_kv)
|
||||
return original_call(self, *args, **kwargs)
|
||||
|
||||
Gemma4TextDecoderLayer.__call__ = patched_call
|
||||
Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True
|
||||
|
||||
|
||||
def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
|
||||
"""
|
||||
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
|
||||
and optionally route `shared_kv_states` via a module-level side channel to
|
||||
avoid a VRAM leak under activation checkpointing (PR #3611).
|
||||
"""
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
||||
|
||||
original_forward = Gemma4TextAttention.forward
|
||||
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
|
||||
|
||||
if install_shared_kv_workaround:
|
||||
_patch_decoder_layer_call()
|
||||
|
||||
logger.info(
|
||||
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
|
||||
)
|
||||
if install_shared_kv_workaround:
|
||||
logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")
|
||||
|
||||
171
tests/monkeypatch/test_gemma4_fused_attn_patch.py
Normal file
171
tests/monkeypatch/test_gemma4_fused_attn_patch.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Unit tests for the Gemma4 fused-attention shared_kv_states routing patch."""
|
||||
|
||||
import pytest
|
||||
|
||||
gemma4_modeling = pytest.importorskip("transformers.models.gemma4.modeling_gemma4")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clean_decoder_layer_patch_slate():
|
||||
"""Save and restore Gemma4TextDecoderLayer.__call__ and the sentinel."""
|
||||
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
||||
|
||||
cls = gemma4_modeling.Gemma4TextDecoderLayer
|
||||
original_call = cls.__call__
|
||||
had_sentinel = getattr(cls, "_axolotl_shared_kv_patched", False)
|
||||
|
||||
if had_sentinel:
|
||||
del cls._axolotl_shared_kv_patched
|
||||
|
||||
try:
|
||||
yield cls, fused_attn
|
||||
finally:
|
||||
cls.__call__ = original_call
|
||||
if had_sentinel:
|
||||
cls._axolotl_shared_kv_patched = True
|
||||
elif hasattr(cls, "_axolotl_shared_kv_patched"):
|
||||
del cls._axolotl_shared_kv_patched
|
||||
fused_attn._set_shared_kv_states(None)
|
||||
|
||||
|
||||
class TestPatchedDecoderLayerCall:
|
||||
def test_pops_shared_kv_states_and_populates_store(
|
||||
self, clean_decoder_layer_patch_slate
|
||||
):
|
||||
cls, fused_attn = clean_decoder_layer_patch_slate
|
||||
|
||||
captured = {}
|
||||
|
||||
def spy(self, *args, **kwargs):
|
||||
captured["args"] = args
|
||||
captured["kwargs"] = dict(kwargs)
|
||||
return "spy_return"
|
||||
|
||||
cls.__call__ = spy
|
||||
fused_attn._patch_decoder_layer_call()
|
||||
|
||||
assert getattr(cls, "_axolotl_shared_kv_patched", False) is True
|
||||
assert cls.__call__ is not spy
|
||||
|
||||
shared_kv = {"layer_0": ("k", "v")}
|
||||
result = cls.__call__(
|
||||
object(),
|
||||
"positional_arg",
|
||||
shared_kv_states=shared_kv,
|
||||
other_kwarg="keep_me",
|
||||
)
|
||||
|
||||
assert result == "spy_return"
|
||||
assert captured["args"] == ("positional_arg",)
|
||||
assert "shared_kv_states" not in captured["kwargs"]
|
||||
assert captured["kwargs"] == {"other_kwarg": "keep_me"}
|
||||
assert fused_attn._get_shared_kv_states() is shared_kv
|
||||
|
||||
def test_clears_store_when_kwarg_absent(self, clean_decoder_layer_patch_slate):
|
||||
"""Regression for commit 251021e1: a prior step's dict must not leak
|
||||
into a later call that omits `shared_kv_states`."""
|
||||
cls, fused_attn = clean_decoder_layer_patch_slate
|
||||
|
||||
def spy(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
cls.__call__ = spy
|
||||
fused_attn._patch_decoder_layer_call()
|
||||
|
||||
stale = {"stale_step": True}
|
||||
fused_attn._set_shared_kv_states(stale)
|
||||
assert fused_attn._get_shared_kv_states() is stale
|
||||
|
||||
cls.__call__(object())
|
||||
|
||||
assert fused_attn._get_shared_kv_states() is None
|
||||
|
||||
def test_store_visible_across_threads(self):
|
||||
"""Regression for commit e3669b2c: the store must be readable from
|
||||
threads other than the one that set it. `threading.local()` failed
|
||||
this invariant, crashing with 'NoneType' object is not subscriptable'
|
||||
on MoE Gemma4 variants when autograd worker threads ran backward
|
||||
recompute under HF-Trainer gradient_checkpointing."""
|
||||
import threading
|
||||
|
||||
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
||||
|
||||
sentinel = {"layer_0": ("k", "v")}
|
||||
try:
|
||||
fused_attn._set_shared_kv_states(sentinel)
|
||||
|
||||
seen = {}
|
||||
|
||||
def worker():
|
||||
seen["value"] = fused_attn._get_shared_kv_states()
|
||||
|
||||
t = threading.Thread(target=worker)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
assert seen["value"] is sentinel
|
||||
finally:
|
||||
fused_attn._set_shared_kv_states(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clean_entry_point_patch_slate():
|
||||
"""Save and restore Gemma4TextAttention.forward and Gemma4TextDecoderLayer.__call__."""
|
||||
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
||||
|
||||
decoder_cls = gemma4_modeling.Gemma4TextDecoderLayer
|
||||
attn_cls = gemma4_modeling.Gemma4TextAttention
|
||||
|
||||
original_call = decoder_cls.__call__
|
||||
original_forward = attn_cls.forward
|
||||
had_sentinel = getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
|
||||
|
||||
if had_sentinel:
|
||||
del decoder_cls._axolotl_shared_kv_patched
|
||||
|
||||
try:
|
||||
yield decoder_cls, attn_cls, original_call, original_forward, fused_attn
|
||||
finally:
|
||||
decoder_cls.__call__ = original_call
|
||||
attn_cls.forward = original_forward
|
||||
if had_sentinel:
|
||||
decoder_cls._axolotl_shared_kv_patched = True
|
||||
elif hasattr(decoder_cls, "_axolotl_shared_kv_patched"):
|
||||
del decoder_cls._axolotl_shared_kv_patched
|
||||
fused_attn._set_shared_kv_states(None)
|
||||
|
||||
|
||||
class TestPatchGemma4FusedAttnEntryPoint:
|
||||
def test_default_flag_swaps_only_attention_forward(
|
||||
self, clean_entry_point_patch_slate
|
||||
):
|
||||
(
|
||||
decoder_cls,
|
||||
attn_cls,
|
||||
original_call,
|
||||
original_forward,
|
||||
fused_attn,
|
||||
) = clean_entry_point_patch_slate
|
||||
|
||||
fused_attn.patch_gemma4_fused_attn()
|
||||
|
||||
assert attn_cls.forward is not original_forward
|
||||
assert decoder_cls.__call__ is original_call
|
||||
assert not getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
|
||||
|
||||
def test_workaround_flag_installs_decoder_layer_patch(
|
||||
self, clean_entry_point_patch_slate
|
||||
):
|
||||
(
|
||||
decoder_cls,
|
||||
attn_cls,
|
||||
original_call,
|
||||
original_forward,
|
||||
fused_attn,
|
||||
) = clean_entry_point_patch_slate
|
||||
|
||||
fused_attn.patch_gemma4_fused_attn(install_shared_kv_workaround=True)
|
||||
|
||||
assert attn_cls.forward is not original_forward
|
||||
assert decoder_cls.__call__ is not original_call
|
||||
assert getattr(decoder_cls, "_axolotl_shared_kv_patched", False) is True
|
||||
Reference in New Issue
Block a user