Make sure make_flex_block_causal_mask is patched in every already-loaded modeling module (except llama4)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""Flex attention monkey patch"""
|
"""Flex attention monkey patch"""
|
||||||
|
|
||||||
|
import sys
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -52,9 +53,9 @@ def patch_flex_wrapper():
|
|||||||
|
|
||||||
def patch_flex_make_mask():
|
def patch_flex_make_mask():
|
||||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
is_torch_2_6 = torch.__version__.startswith("2.6")
|
||||||
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
|
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
|
||||||
|
|
||||||
if not (is_torch_2_6 and is_transformers_below_4_51):
|
if not (is_torch_2_6 and is_transformers_eq_4_51):
|
||||||
return
|
return
|
||||||
|
|
||||||
from torch.nn.attention.flex_attention import (
|
from torch.nn.attention.flex_attention import (
|
||||||
@@ -66,7 +67,7 @@ def patch_flex_make_mask():
|
|||||||
|
|
||||||
Offset = Union[torch.Tensor, int]
|
Offset = Union[torch.Tensor, int]
|
||||||
|
|
||||||
def make_flex_block_causal_mask(
|
def patched_make_flex_block_causal_mask(
|
||||||
attention_mask_2d: torch.Tensor,
|
attention_mask_2d: torch.Tensor,
|
||||||
attention_chunk_size: Optional[int] = None,
|
attention_chunk_size: Optional[int] = None,
|
||||||
query_length=None,
|
query_length=None,
|
||||||
@@ -157,6 +158,14 @@ def patch_flex_make_mask():
|
|||||||
_compile=True,
|
_compile=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for n in tuple(sys.modules):
|
||||||
|
if ".modeling_" in n and "llama4" not in n:
|
||||||
|
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
|
||||||
|
print(n)
|
||||||
|
sys.modules[n].make_flex_block_causal_mask = (
|
||||||
|
patched_make_flex_block_causal_mask
|
||||||
|
)
|
||||||
|
|
||||||
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
|
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
|
||||||
make_flex_block_causal_mask
|
patched_make_flex_block_causal_mask
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user