Make sure to patch make_flex_block_causal_mask in all the loaded modeling modules

This commit is contained in:
Wing Lian
2025-04-06 14:45:30 -04:00
parent 7e410ab480
commit 1a5d445413

View File

@@ -1,5 +1,6 @@
"""Flex attention monkey patch""" """Flex attention monkey patch"""
import sys
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
@@ -52,9 +53,9 @@ def patch_flex_wrapper():
def patch_flex_make_mask(): def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6") is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_51 = transformers.__version__ < "4.51.0" is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
if not (is_torch_2_6 and is_transformers_below_4_51): if not (is_torch_2_6 and is_transformers_eq_4_51):
return return
from torch.nn.attention.flex_attention import ( from torch.nn.attention.flex_attention import (
@@ -66,7 +67,7 @@ def patch_flex_make_mask():
Offset = Union[torch.Tensor, int] Offset = Union[torch.Tensor, int]
def make_flex_block_causal_mask( def patched_make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor, attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None, attention_chunk_size: Optional[int] = None,
query_length=None, query_length=None,
@@ -157,6 +158,14 @@ def patch_flex_make_mask():
_compile=True, _compile=True,
) )
for n in tuple(sys.modules):
if ".modeling_" in n and "llama4" not in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
print(n)
sys.modules[n].make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
transformers.integrations.flex_attention.make_flex_block_causal_mask = ( transformers.integrations.flex_attention.make_flex_block_causal_mask = (
make_flex_block_causal_mask patched_make_flex_block_causal_mask
) )