fix ops
This commit is contained in:
@@ -29,17 +29,9 @@ def hf_grad_checkpoint_offload_wrapper(
|
||||
|
||||
aten = torch.ops.aten
|
||||
compute_intensive_ops = [
|
||||
aten.mm,
|
||||
aten.convolution,
|
||||
aten.convolution_backward,
|
||||
aten.bmm,
|
||||
aten.addmm,
|
||||
aten._scaled_dot_product_flash_attention,
|
||||
aten._scaled_dot_product_efficient_attention,
|
||||
aten._flash_attention_forward,
|
||||
aten._efficient_attention_forward,
|
||||
aten.upsample_bilinear2d,
|
||||
aten._scaled_mm,
|
||||
aten.mm.default,
|
||||
aten.bmm.default,
|
||||
aten.addmm.default,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user