From 1a5d4454130e75ac1bf8fb15ae02bd00b9105966 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 6 Apr 2025 14:45:30 -0400
Subject: [PATCH] make sure to patch all the loaded models

---
 src/axolotl/monkeypatch/attention/flex_attn.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index 8ca6d06b0..4098c8c1c 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -1,5 +1,6 @@
 """Flex attention monkey patch"""
 
+import sys
 from typing import Optional, Tuple, Union
 
 import torch
@@ -52,9 +53,9 @@ def patch_flex_wrapper():
 
 def patch_flex_make_mask():
     is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
+    is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
 
-    if not (is_torch_2_6 and is_transformers_below_4_51):
+    if not (is_torch_2_6 and is_transformers_eq_4_51):
         return
 
     from torch.nn.attention.flex_attention import (
@@ -66,7 +67,7 @@ def patch_flex_make_mask():
 
     Offset = Union[torch.Tensor, int]
 
-    def make_flex_block_causal_mask(
+    def patched_make_flex_block_causal_mask(
        attention_mask_2d: torch.Tensor,
        attention_chunk_size: Optional[int] = None,
        query_length=None,
@@ -157,6 +158,14 @@ def patch_flex_make_mask():
         _compile=True,
     )
 
+    for n in tuple(sys.modules):
+        if ".modeling_" in n and "llama4" not in n:
+            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
+                print(n)
+                sys.modules[n].make_flex_block_causal_mask = (
+                    patched_make_flex_block_causal_mask
+                )
+
     transformers.integrations.flex_attention.make_flex_block_causal_mask = (
-        make_flex_block_causal_mask
+        patched_make_flex_block_causal_mask
     )
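
Note on the approach (an illustrative sketch, not part of the patch itself): modeling
modules that already imported make_flex_block_causal_mask by name hold their own
reference to the original function, so rebinding the symbol only in
transformers.integrations.flex_attention leaves those modules unpatched. The loop added
by this patch rebinds the name in every loaded modeling module; the standalone helper
below mirrors that pattern, with the helper name apply_to_loaded_modules being
hypothetical.

    import sys

    def apply_to_loaded_modules(patched_fn):
        """Rebind make_flex_block_causal_mask in every already-imported modeling module."""
        # tuple() snapshots the module names, since imports during iteration
        # could otherwise mutate sys.modules mid-loop.
        for name in tuple(sys.modules):
            module = sys.modules[name]
            # Only touch transformers-style modeling modules; llama4 is excluded
            # here simply to mirror the exclusion in the diff above.
            if ".modeling_" in name and "llama4" not in name:
                if hasattr(module, "make_flex_block_causal_mask"):
                    module.make_flex_block_causal_mask = patched_fn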