diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b3e97e3b2..d86c40ade 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -405,7 +405,7 @@ class ModelLoader: if ( self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and self.cfg.flash_attention + and (self.cfg.flash_attention or self.cfg.flex_attention) and self.cfg.sample_packing ): if "auto_map" in self.model_config: