fix ops
This commit is contained in:
@@ -29,17 +29,9 @@ def hf_grad_checkpoint_offload_wrapper(
|
|||||||
|
|
||||||
aten = torch.ops.aten
|
aten = torch.ops.aten
|
||||||
compute_intensive_ops = [
|
compute_intensive_ops = [
|
||||||
aten.mm,
|
aten.mm.default,
|
||||||
aten.convolution,
|
aten.bmm.default,
|
||||||
aten.convolution_backward,
|
aten.addmm.default,
|
||||||
aten.bmm,
|
|
||||||
aten.addmm,
|
|
||||||
aten._scaled_dot_product_flash_attention,
|
|
||||||
aten._scaled_dot_product_efficient_attention,
|
|
||||||
aten._flash_attention_forward,
|
|
||||||
aten._efficient_attention_forward,
|
|
||||||
aten.upsample_bilinear2d,
|
|
||||||
aten._scaled_mm,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user