diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 523fd76fe..b3e97e3b2 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -710,7 +710,13 @@ class ModelLoader: """ sample packing uses custom FA2 patch """ - if self.cfg.flash_attention: + + if self.cfg.flex_attention: + self.model_kwargs["attn_implementation"] = "flex_attention" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flex_attention" + ) + elif self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass self.model_kwargs["attn_implementation"] = "flash_attention_2"