From 7610a02881c7f361c0a661bb257da315bf4c0fc3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 6 May 2025 01:00:02 -0400 Subject: [PATCH] fix ops --- .../utils/gradient_checkpointing/__init__.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py index ee1e6cd9d..a36e56a68 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/utils/gradient_checkpointing/__init__.py @@ -29,17 +29,9 @@ def hf_grad_checkpoint_offload_wrapper( aten = torch.ops.aten compute_intensive_ops = [ - aten.mm, - aten.convolution, - aten.convolution_backward, - aten.bmm, - aten.addmm, - aten._scaled_dot_product_flash_attention, - aten._scaled_dot_product_efficient_attention, - aten._flash_attention_forward, - aten._efficient_attention_forward, - aten.upsample_bilinear2d, - aten._scaled_mm, + aten.mm.default, + aten.bmm.default, + aten.addmm.default, ]