make sure to patch all the loaded models
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Flex attention monkey patch"""
|
||||
|
||||
import sys
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -52,9 +53,9 @@ def patch_flex_wrapper():
|
||||
|
||||
def patch_flex_make_mask():
|
||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
||||
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
|
||||
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
|
||||
|
||||
if not (is_torch_2_6 and is_transformers_below_4_51):
|
||||
if not (is_torch_2_6 and is_transformers_eq_4_51):
|
||||
return
|
||||
|
||||
from torch.nn.attention.flex_attention import (
|
||||
@@ -66,7 +67,7 @@ def patch_flex_make_mask():
|
||||
|
||||
Offset = Union[torch.Tensor, int]
|
||||
|
||||
def make_flex_block_causal_mask(
|
||||
def patched_make_flex_block_causal_mask(
|
||||
attention_mask_2d: torch.Tensor,
|
||||
attention_chunk_size: Optional[int] = None,
|
||||
query_length=None,
|
||||
@@ -157,6 +158,14 @@ def patch_flex_make_mask():
|
||||
_compile=True,
|
||||
)
|
||||
|
||||
for n in tuple(sys.modules):
|
||||
if ".modeling_" in n and "llama4" not in n:
|
||||
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
|
||||
print(n)
|
||||
sys.modules[n].make_flex_block_causal_mask = (
|
||||
patched_make_flex_block_causal_mask
|
||||
)
|
||||
|
||||
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
|
||||
make_flex_block_causal_mask
|
||||
patched_make_flex_block_causal_mask
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user