From b0cd54bcb97b9b4d8c81600e3e294ec6dc72fdb3 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 6 May 2025 00:39:21 -0400
Subject: [PATCH] WIP for activation offloading using streams and custom
 policy fn for checkpointing

---
 src/axolotl/core/trainers/base.py             | 12 +++++
 .../utils/gradient_checkpointing/__init__.py  | 44 +++++++++++++++++++
 2 files changed, 56 insertions(+)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 3864903a5..cd8b7d819 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -610,3 +610,15 @@ class AxolotlTrainer(
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
+
+    def compute_loss_context_manager(self):
+        from contextlib import ExitStack
+
+        from torchtune.training import OffloadActivations
+
+        stack = ExitStack()
+
+        stack.enter_context(super().compute_loss_context_manager())
+        stack.enter_context(OffloadActivations())
+
+        return stack
diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py
index 0da5c83a2..ee1e6cd9d 100644
--- a/src/axolotl/utils/gradient_checkpointing/__init__.py
+++ b/src/axolotl/utils/gradient_checkpointing/__init__.py
@@ -2,6 +2,13 @@
 
 from functools import partial
 
+import torch
+from torch.utils.checkpoint import (
+    CheckpointPolicy,
+    checkpoint,
+    create_selective_checkpoint_contexts,
+)
+
 from axolotl.utils.gradient_checkpointing.unsloth import (
     Unsloth_Offloaded_Gradient_Checkpointer,
 )
@@ -18,3 +25,40 @@ def hf_grad_checkpoint_offload_wrapper(
         ),
         *args,
     )
+
+
+aten = torch.ops.aten
+compute_intensive_ops = [
+    aten.mm,
+    aten.convolution,
+    aten.convolution_backward,
+    aten.bmm,
+    aten.addmm,
+    aten._scaled_dot_product_flash_attention,
+    aten._scaled_dot_product_efficient_attention,
+    aten._flash_attention_forward,
+    aten._efficient_attention_forward,
+    aten.upsample_bilinear2d,
+    aten._scaled_mm,
+]
+
+
+def policy_fn(ctx, op, *args, **kwargs):
+    if op in compute_intensive_ops:
+        return CheckpointPolicy.MUST_SAVE
+    else:
+        return CheckpointPolicy.PREFER_RECOMPUTE
+
+
+context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
+
+
+def checkpoint_w_policy(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    return checkpoint(
+        decoder_layer,
+        *args,
+        use_reentrant=use_reentrant,
+        context_fn=context_fn,
+    )