WIP for activation offloading using streams and custom policy fn for checkpointing

Wing Lian
2025-05-06 00:39:21 -04:00
parent 54960d4de0
commit b0cd54bcb9
2 changed files with 56 additions and 0 deletions


@@ -610,3 +610,15 @@ class AxolotlTrainer(
        output_dir = os.path.join(run_dir, checkpoint_folder)
        os.makedirs(output_dir, exist_ok=True)
        return super()._save_checkpoint(model, trial, **kwargs)

    def compute_loss_context_manager(self):
        from contextlib import ExitStack
        from torchtune.training import OffloadActivations

        # Stack torchtune's activation offloading on top of the trainer's
        # default loss context so saved activations are moved to CPU during
        # the forward pass and brought back for the backward pass.
        stack = ExitStack()
        stack.enter_context(super().compute_loss_context_manager())
        stack.enter_context(OffloadActivations())
        return stack
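The commit message calls out stream-based offloading, but compute_loss_context_manager above constructs OffloadActivations() with its defaults. A minimal sketch of how the offload context could be parameterized, assuming the installed torchtune release exposes use_streams and use_pin_memory keyword arguments (names and defaults should be checked against the actual version):

    from torchtune.training import OffloadActivations

    # Sketch only: keyword names are assumptions about the torchtune API in use.
    offload_ctx = OffloadActivations(
        use_pin_memory=True,  # stage offloaded activations in pinned CPU memory
        use_streams=True,     # overlap D2H/H2D copies with compute on a side stream
    )

Such a context would be entered through the same ExitStack as the trainer's default loss context shown above.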


@@ -2,6 +2,13 @@
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

from axolotl.utils.gradient_checkpointing.unsloth import (
    Unsloth_Offloaded_Gradient_Checkpointer,
)
@@ -18,3 +25,40 @@ def hf_grad_checkpoint_offload_wrapper(
        ),
        *args,
    )
aten = torch.ops.aten

# Ops whose outputs are expensive to recompute (matmuls, convolutions,
# attention kernels); the policy below saves these and lets everything
# else be recomputed during backward.
compute_intensive_ops = [
    aten.mm,
    aten.convolution,
    aten.convolution_backward,
    aten.bmm,
    aten.addmm,
    aten._scaled_dot_product_flash_attention,
    aten._scaled_dot_product_efficient_attention,
    aten._flash_attention_forward,
    aten._efficient_attention_forward,
    aten.upsample_bilinear2d,
    aten._scaled_mm,
]
def policy_fn(ctx, op, *args, **kwargs):
    # `op` is an OpOverload (e.g. aten.mm.default), so match its overload
    # packet against the packet-level save list above.
    if op.overloadpacket in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    else:
        return CheckpointPolicy.PREFER_RECOMPUTE


context_fn = partial(create_selective_checkpoint_contexts, policy_fn)


def checkpoint_w_policy(
    decoder_layer, *args, use_reentrant=None
):  # pylint: disable=unused-argument
    # torch only supports context_fn with non-reentrant checkpointing, so the
    # caller-supplied use_reentrant is ignored and False is always passed.
    return checkpoint(
        decoder_layer,
        *args,
        use_reentrant=False,
        context_fn=context_fn,
    )
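A short sketch of how checkpoint_w_policy behaves in isolation, using a made-up toy MLP in place of a decoder layer: with the selective-checkpoint context, only outputs of ops in compute_intensive_ops (here aten.addmm from the linear layers) are kept, while cheaper ops such as the activation are recomputed during backward.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for a transformer decoder layer.
    toy_layer = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    inputs = torch.randn(8, 64, requires_grad=True)

    # The wrapper forces use_reentrant=False internally because context_fn
    # is only supported by non-reentrant checkpointing.
    out = checkpoint_w_policy(toy_layer, inputs)
    out.sum().backward()

How this wrapper gets wired into the HF gradient-checkpointing path is not part of this diff.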