diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index c643e2fd2..58e0c8e89 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -40,7 +40,6 @@ def patch_flex_wrapper():
         if not self._is_flex_compiled:
             self._compiled_flex_attention = torch.compile(
                 flex_attention,
-                backend="inductor",
                 mode="max-autotune-no-cudagraphs",
                 fullgraph=True,
             )