diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index babecfc7a..18d195f17 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -10,9 +10,9 @@ import transformers
 def patch_flex_wrapper():
     # TODO remove this patch when transformers#37285 is merged and in a release
     is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_below_4_52 = transformers.__version__ < "4.52.0"
+    is_transformers_below_4_51_1 = transformers.__version__ < "4.51.1"
 
-    if not (is_torch_2_6 and is_transformers_below_4_52):
+    if not (is_torch_2_6 and is_transformers_below_4_51_1):
         return
 
     from torch.nn.attention.flex_attention import flex_attention