diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index d65ee706f..c643e2fd2 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -10,9 +10,9 @@ import transformers
 def patch_flex_wrapper():
     # TODO remove this patch when transformers#37285 is merged and in a release
     is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
+    is_transformers_below_4_52 = transformers.__version__ < "4.52.0"
 
-    if not (is_torch_2_6 and is_transformers_below_4_51):
+    if not (is_torch_2_6 and is_transformers_below_4_52):
         return
 
     from torch.nn.attention.flex_attention import flex_attention
@@ -40,7 +40,7 @@ def patch_flex_wrapper():
         if not self._is_flex_compiled:
             self._compiled_flex_attention = torch.compile(
                 flex_attention,
-                dynamic=False,
+                backend="inductor",
                 mode="max-autotune-no-cudagraphs",
                 fullgraph=True,
             )