* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing
Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.
HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing carries that retained memory across steps.
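Illustrative sketch of the capture pattern described above (paraphrased, not
HF's exact code): anything still in kwargs when the checkpoint partial is
formed stays alive until that layer's backward completes.

    from functools import partial
    import torch.utils.checkpoint

    def checkpointed_layer_call(layer_forward, hidden_states, **kwargs):
        # shared_kv_states rides along in kwargs, so this partial -- which the
        # checkpointer stores for recomputation during backward -- keeps the
        # dict (and every K/V tensor inside it) alive until backward runs.
        layer_partial = partial(layer_forward, **kwargs)
        return torch.utils.checkpoint.checkpoint(
            layer_partial, hidden_states, use_reentrant=False
        )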
Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.
Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.
Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py; a rough sketch follows this note):
* threading.local() store with _get/_set_shared_kv_states helpers
* _patch_decoder_layer_call(): monkeypatches
Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
and stash it in TLS before delegating to GradientCheckpointingLayer.
The partial formed downstream no longer references the dict.
* fused_forward reads TLS first, falls back to kwarg for callers that
bypass the patched __call__ (e.g. direct attention invocation).
* wired into patch_gemma4_fused_attn; idempotent via a sentinel.
TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.
Introduced by PR #3598 (commit b8358aa5).
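A minimal sketch of the routing. The helper and class names match the list
above; the bodies are illustrative, not the exact fused_attn.py source, and
the conditional write shown here is what the CodeRabbit follow-up below changes.

    import threading

    _shared_kv_tls = threading.local()  # illustrative name for the TLS store

    def _set_shared_kv_states(states):
        _shared_kv_tls.value = states

    def _get_shared_kv_states():
        return getattr(_shared_kv_tls, "value", None)

    def _patch_decoder_layer_call():
        from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer

        cls = Gemma4TextDecoderLayer
        if getattr(cls, "_axolotl_shared_kv_patched", False):
            return  # idempotent via the sentinel attribute
        original_call = cls.__call__  # GradientCheckpointingLayer.__call__

        def patched_call(self, *args, **kwargs):
            # Pop the mutable dict before the checkpoint partial is formed,
            # so the closure captured for backward never references it.
            if "shared_kv_states" in kwargs:
                _set_shared_kv_states(kwargs.pop("shared_kv_states"))
            return original_call(self, *args, **kwargs)

        cls.__call__ = patched_call
        cls._axolotl_shared_kv_patched = True

    # Consumer side, inside fused_forward: prefer the side channel, fall back
    # to the kwarg for callers that bypass the patched __call__:
    #   shared_kv_states = _get_shared_kv_states() or kwargs.get("shared_kv_states")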
* CodeRabbit comment: gemma4: clear TLS unconditionally in the patched decoder-layer __call__
Overwrite the thread-local shared_kv_states store on every invocation
(including with None) instead of only when the kwarg is present.
The previous conditional write left stale dicts in TLS on any path that
reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
kwarg — e.g. generation, eval hooks, or future HF refactors that make
the kwarg optional. fused_forward would then silently consume a prior
step's K/V dict instead of falling back to its own kwarg path.
Unconditional write makes the invariant in the surrounding comment
("TLS is overwritten on each new step's first decoder-layer call, so
the previous step's dict is released promptly") actually hold.
No behavior change for the training happy path, which always passes
the kwarg. Addresses the CodeRabbit review on PR #3611.
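Sketch of the change, using the same illustrative patched __call__ as above:

    def patched_call(self, *args, **kwargs):
        # Overwrite the store on every invocation -- even with None -- so a
        # stale dict from a previous step can never survive in the side channel.
        _set_shared_kv_states(kwargs.pop("shared_kv_states", None))
        return original_call(self, *args, **kwargs)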
* fix: swap threading.local() for a module-level store so autograd worker threads see shared_kv_states during backward recompute
The previous commits fixed the memory leak on the 31B model but caused a type error with MoE Gemma4 variants; this commit fixes that:
PR #3611's TLS variant only works when recompute runs on the same thread
that set TLS during forward. PyTorch's C++ autograd engine
(_engine_run_backward) spawns per-device worker threads to dispatch
backward, and HF-Trainer gradient_checkpointing (stock
torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
unpack_hook -> recompute_fn on those worker threads. TLS set on the main
thread during forward is invisible there, so _get_shared_kv_states()
returns None and the consumer-layer lookup crashes with
"'NoneType' object is not subscriptable" at
fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).
A plain module-level dict is visible to all threads in the process.
Lifecycle is identical: the slot is overwritten each forward, releasing
the previous step's dict and allowing its K/V tensors to be GC'd, so
the original VRAM-leak fix still holds under FSDP2 AC too.
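A sketch of the store change, keeping the same helper API (the slot name is
illustrative):

    # A plain module-level slot is visible to autograd's per-device worker
    # threads, unlike threading.local(); the lifecycle is unchanged.
    _shared_kv_store = {"value": None}

    def _set_shared_kv_states(states):
        _shared_kv_store["value"] = states

    def _get_shared_kv_states():
        return _shared_kv_store["value"]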
* scope gemma4 shared_kv_states side channel to checkpointed training
Updates PR #3611 to gate the workaround on checkpointed training, avoiding regressions in async flows.
Added unit tests for the kwargs pop, the store-clear regression, and the flag gating; condensed verbose comments.
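Roughly how the gate is wired; the keyword matches the entry-point tests
below, while the body is illustrative:

    def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
        from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention

        # The fused attention forward is always installed.
        Gemma4TextAttention.forward = fused_forward
        if install_shared_kv_workaround:
            # Only reroute shared_kv_states when checkpointed training needs it.
            _patch_decoder_layer_call()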
* add gemma4 cross-thread visibility test for shared_kv_states store
Additional regression test for MoE Gemma4 variants: asserts the module-level store is readable from threads other than the one that set it, guarding against the previously observed 'NoneType' error.
* fix logger
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
"""Unit tests for the Gemma4 fused-attention shared_kv_states routing patch."""
|
|
|
|
import pytest
|
|
|
|
gemma4_modeling = pytest.importorskip("transformers.models.gemma4.modeling_gemma4")
|
|
|
|
|
|
@pytest.fixture
|
|
def clean_decoder_layer_patch_slate():
|
|
"""Save and restore Gemma4TextDecoderLayer.__call__ and the sentinel."""
|
|
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
|
|
|
cls = gemma4_modeling.Gemma4TextDecoderLayer
|
|
original_call = cls.__call__
|
|
had_sentinel = getattr(cls, "_axolotl_shared_kv_patched", False)
|
|
|
|
if had_sentinel:
|
|
del cls._axolotl_shared_kv_patched
|
|
|
|
try:
|
|
yield cls, fused_attn
|
|
finally:
|
|
cls.__call__ = original_call
|
|
if had_sentinel:
|
|
cls._axolotl_shared_kv_patched = True
|
|
elif hasattr(cls, "_axolotl_shared_kv_patched"):
|
|
del cls._axolotl_shared_kv_patched
|
|
fused_attn._set_shared_kv_states(None)
|
|
|
|
|
|
class TestPatchedDecoderLayerCall:
|
|
def test_pops_shared_kv_states_and_populates_store(
|
|
self, clean_decoder_layer_patch_slate
|
|
):
|
|
cls, fused_attn = clean_decoder_layer_patch_slate
|
|
|
|
captured = {}
|
|
|
|
def spy(self, *args, **kwargs):
|
|
captured["args"] = args
|
|
captured["kwargs"] = dict(kwargs)
|
|
return "spy_return"
|
|
|
|
cls.__call__ = spy
|
|
fused_attn._patch_decoder_layer_call()
|
|
|
|
assert getattr(cls, "_axolotl_shared_kv_patched", False) is True
|
|
assert cls.__call__ is not spy
|
|
|
|
shared_kv = {"layer_0": ("k", "v")}
|
|
result = cls.__call__(
|
|
object(),
|
|
"positional_arg",
|
|
shared_kv_states=shared_kv,
|
|
other_kwarg="keep_me",
|
|
)
|
|
|
|
assert result == "spy_return"
|
|
assert captured["args"] == ("positional_arg",)
|
|
assert "shared_kv_states" not in captured["kwargs"]
|
|
assert captured["kwargs"] == {"other_kwarg": "keep_me"}
|
|
assert fused_attn._get_shared_kv_states() is shared_kv
|
|
|
|
def test_clears_store_when_kwarg_absent(self, clean_decoder_layer_patch_slate):
|
|
"""Regression for commit 251021e1: a prior step's dict must not leak
|
|
into a later call that omits `shared_kv_states`."""
|
|
cls, fused_attn = clean_decoder_layer_patch_slate
|
|
|
|
def spy(self, *args, **kwargs):
|
|
return None
|
|
|
|
cls.__call__ = spy
|
|
fused_attn._patch_decoder_layer_call()
|
|
|
|
stale = {"stale_step": True}
|
|
fused_attn._set_shared_kv_states(stale)
|
|
assert fused_attn._get_shared_kv_states() is stale
|
|
|
|
cls.__call__(object())
|
|
|
|
assert fused_attn._get_shared_kv_states() is None
|
|
|
|
def test_store_visible_across_threads(self):
|
|
"""Regression for commit e3669b2c: the store must be readable from
|
|
threads other than the one that set it. `threading.local()` failed
|
|
this invariant, crashing with 'NoneType' object is not subscriptable'
|
|
on MoE Gemma4 variants when autograd worker threads ran backward
|
|
recompute under HF-Trainer gradient_checkpointing."""
|
|
import threading
|
|
|
|
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
|
|
|
sentinel = {"layer_0": ("k", "v")}
|
|
try:
|
|
fused_attn._set_shared_kv_states(sentinel)
|
|
|
|
seen = {}
|
|
|
|
def worker():
|
|
seen["value"] = fused_attn._get_shared_kv_states()
|
|
|
|
t = threading.Thread(target=worker)
|
|
t.start()
|
|
t.join()
|
|
|
|
assert seen["value"] is sentinel
|
|
finally:
|
|
fused_attn._set_shared_kv_states(None)
|
|
|
|
|
|
@pytest.fixture
|
|
def clean_entry_point_patch_slate():
|
|
"""Save and restore Gemma4TextAttention.forward and Gemma4TextDecoderLayer.__call__."""
|
|
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
|
|
|
decoder_cls = gemma4_modeling.Gemma4TextDecoderLayer
|
|
attn_cls = gemma4_modeling.Gemma4TextAttention
|
|
|
|
original_call = decoder_cls.__call__
|
|
original_forward = attn_cls.forward
|
|
had_sentinel = getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
|
|
|
|
if had_sentinel:
|
|
del decoder_cls._axolotl_shared_kv_patched
|
|
|
|
try:
|
|
yield decoder_cls, attn_cls, original_call, original_forward, fused_attn
|
|
finally:
|
|
decoder_cls.__call__ = original_call
|
|
attn_cls.forward = original_forward
|
|
if had_sentinel:
|
|
decoder_cls._axolotl_shared_kv_patched = True
|
|
elif hasattr(decoder_cls, "_axolotl_shared_kv_patched"):
|
|
del decoder_cls._axolotl_shared_kv_patched
|
|
fused_attn._set_shared_kv_states(None)
|
|
|
|
|
|
class TestPatchGemma4FusedAttnEntryPoint:
|
|
def test_default_flag_swaps_only_attention_forward(
|
|
self, clean_entry_point_patch_slate
|
|
):
|
|
(
|
|
decoder_cls,
|
|
attn_cls,
|
|
original_call,
|
|
original_forward,
|
|
fused_attn,
|
|
) = clean_entry_point_patch_slate
|
|
|
|
fused_attn.patch_gemma4_fused_attn()
|
|
|
|
assert attn_cls.forward is not original_forward
|
|
assert decoder_cls.__call__ is original_call
|
|
assert not getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
|
|
|
|
def test_workaround_flag_installs_decoder_layer_patch(
|
|
self, clean_entry_point_patch_slate
|
|
):
|
|
(
|
|
decoder_cls,
|
|
attn_cls,
|
|
original_call,
|
|
original_forward,
|
|
fused_attn,
|
|
) = clean_entry_point_patch_slate
|
|
|
|
fused_attn.patch_gemma4_fused_attn(install_shared_kv_workaround=True)
|
|
|
|
assert attn_cls.forward is not original_forward
|
|
assert decoder_cls.__call__ is not original_call
|
|
assert getattr(decoder_cls, "_axolotl_shared_kv_patched", False) is True
|