From 04624c5a8d3f4b6c798fb9de8cfacbade0861bdc Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Mon, 7 Apr 2025 15:12:45 -0400 Subject: [PATCH] bump flex patching transformers to v4.51, update torch compile kwargs to be in line with transformers v4.51 --- src/axolotl/monkeypatch/attention/flex_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index d65ee706f..c643e2fd2 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -10,9 +10,9 @@ import transformers def patch_flex_wrapper(): # TODO remove this patch when transformers#37285 is merged and in a release is_torch_2_6 = torch.__version__.startswith("2.6") - is_transformers_below_4_51 = transformers.__version__ < "4.51.0" + is_transformers_below_4_52 = transformers.__version__ < "4.52.0" - if not (is_torch_2_6 and is_transformers_below_4_51): + if not (is_torch_2_6 and is_transformers_below_4_52): return from torch.nn.attention.flex_attention import flex_attention @@ -40,7 +40,7 @@ def patch_flex_wrapper(): if not self._is_flex_compiled: self._compiled_flex_attention = torch.compile( flex_attention, - dynamic=False, + backend="inductor", mode="max-autotune-no-cudagraphs", fullgraph=True, )