diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b3e97e3b2..1cc374514 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -387,13 +387,12 @@ class ModelLoader: self.patch_attention() if self.cfg.model_config_type == "llama": - from axolotl.monkeypatch.trainer_grad_accum import ( - patch_flash_attention_forward, + from axolotl.monkeypatch.trainer_grad_accum import ( # patch_flash_attention_forward, patch_forward_for_ga, patch_training_step_for_ga, ) - patch_flash_attention_forward() + # patch_flash_attention_forward() patch_forward_for_ga() patch_training_step_for_ga()