make sure to patch all the loaded models

This commit is contained in:
Wing Lian
2025-04-06 14:45:30 -04:00
parent 7e410ab480
commit 1a5d445413


@@ -1,5 +1,6 @@
 """Flex attention monkey patch"""
+import sys
 from typing import Optional, Tuple, Union
 import torch
@@ -52,9 +53,9 @@ def patch_flex_wrapper():
 def patch_flex_make_mask():
     is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
+    is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
-    if not (is_torch_2_6 and is_transformers_below_4_51):
+    if not (is_torch_2_6 and is_transformers_eq_4_51):
         return
     from torch.nn.attention.flex_attention import (
@@ -66,7 +67,7 @@ def patch_flex_make_mask():
     Offset = Union[torch.Tensor, int]
-    def make_flex_block_causal_mask(
+    def patched_make_flex_block_causal_mask(
         attention_mask_2d: torch.Tensor,
         attention_chunk_size: Optional[int] = None,
         query_length=None,
@@ -157,6 +158,14 @@ def patch_flex_make_mask():
         _compile=True,
     )
+    for n in tuple(sys.modules):
+        if ".modeling_" in n and "llama4" not in n:
+            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
+                print(n)
+                sys.modules[n].make_flex_block_causal_mask = (
+                    patched_make_flex_block_causal_mask
+                )
     transformers.integrations.flex_attention.make_flex_block_causal_mask = (
-        make_flex_block_causal_mask
+        patched_make_flex_block_causal_mask
     )
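
The loop over sys.modules is the key change: a modeling module that ran a from-import of make_flex_block_causal_mask holds its own reference, so replacing the attribute only on transformers.integrations.flex_attention leaves those already-loaded modules calling the old function. Below is a minimal, self-contained sketch of that stale-reference behavior using made-up module names (source, consumer.modeling_demo); it is an illustration of the pattern, not code from the repository.

import sys
import types

# Stand-in for transformers.integrations.flex_attention
source = types.ModuleType("source")
source.make_flex_block_causal_mask = lambda mask: "original"
sys.modules["source"] = source

# Stand-in for a loaded modeling_* module that did
# `from source import make_flex_block_causal_mask` at import time.
consumer = types.ModuleType("consumer.modeling_demo")
consumer.make_flex_block_causal_mask = source.make_flex_block_causal_mask
sys.modules["consumer.modeling_demo"] = consumer

def patched_make_flex_block_causal_mask(mask):
    return "patched"

# Patching only the source module does not update the consumer's binding.
source.make_flex_block_causal_mask = patched_make_flex_block_causal_mask
print(consumer.make_flex_block_causal_mask(None))  # -> "original"

# As in the commit: walk the already-loaded modules and rebind the name
# everywhere it was imported.
for name in tuple(sys.modules):
    mod = sys.modules[name]
    if ".modeling_" in name and hasattr(mod, "make_flex_block_causal_mask"):
        mod.make_flex_block_causal_mask = patched_make_flex_block_causal_mask

print(consumer.make_flex_block_causal_mask(None))  # -> "patched"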