diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py index ee1e6cd9d..a36e56a68 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/utils/gradient_checkpointing/__init__.py @@ -29,17 +29,9 @@ def hf_grad_checkpoint_offload_wrapper( aten = torch.ops.aten compute_intensive_ops = [ - aten.mm, - aten.convolution, - aten.convolution_backward, - aten.bmm, - aten.addmm, - aten._scaled_dot_product_flash_attention, - aten._scaled_dot_product_efficient_attention, - aten._flash_attention_forward, - aten._efficient_attention_forward, - aten.upsample_bilinear2d, - aten._scaled_mm, + aten.mm.default, + aten.bmm.default, + aten.addmm.default, ]