diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 95dae7e20..57669936e 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -556,6 +556,12 @@ class ModelLoader: self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name def apply_patches(self) -> None: + if self.cfg.xformers_attention and self.cfg.sample_packing: + from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 + + patch_xformers_attn_over_fa2() + self.cfg.flash_attention = True + if self.cfg.chunked_cross_entropy: from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn