diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index 58e0c8e89..babecfc7a 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -40,6 +40,7 @@ def patch_flex_wrapper(): if not self._is_flex_compiled: self._compiled_flex_attention = torch.compile( flex_attention, + dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True, )