Activation Offloading w CUDA Streams (#2900) [skip ci]

* use cuda streams for activation offloading

* use torch native ops

* update cfg schema for streams

* fix literal constructor for set

* use context for training step so it doesn't affect evals

* disable streams

* auto gc on eval steps

* use activation_offloading config arg

* add docs for gradient checkpointing

* handle validation for gc/ao

* use cuda streams for act offloading

* add more validation for AC w/o GC

* fix docs

* move activation_offloading lower in definition so it doesn't break args/kwargs

* fix kd due to import order
This commit is contained in:
Wing Lian
2025-07-14 20:10:20 -04:00
committed by GitHub
parent aa684122f1
commit 99187cd208
14 changed files with 154 additions and 186 deletions

View File

@@ -276,6 +276,7 @@ website:
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- section: "Troubleshooting"
contents:

View File

@@ -0,0 +1,29 @@
---
title: Gradient Checkpointing and Activation Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
models by reducing the memory footprint and improving computational efficiency.
### Enabling Gradient Checkpointing
```yaml
gradient_checkpointing: true
```
### Enabling Activation Offloading
```yaml
gradient_checkpointing: true # required for activation offloading
activation_offloading: true
```
Activation offloading variants:
The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
to overlap the communications and computations when offloading.
The `activation_offloading: legacy` naively offloads activations to CPU and without additional optimizations.
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.

View File

@@ -434,7 +434,11 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing:
if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False
training_args_kwargs["activation_offloading"] = True
elif self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)

View File

@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
OptimizerMixin,
PackingMixin,
@@ -48,6 +49,7 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
ActivationOffloadingMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""

View File

@@ -3,6 +3,7 @@
# pylint: disable=unused-import
# flake8: noqa
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin

View File

@@ -0,0 +1,37 @@
"""
Trainer mixin for activation checkpointing w offloading
"""
import contextlib
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
from trl.models.activation_offloading import get_act_offloading_ctx_manager
class ActivationOffloadingMixin(Trainer):
"""
Trainer mixin class for activation checkpointing w offloading
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.args.activation_offloading:
self.activation_offload_context = get_act_offloading_ctx_manager(
self.model, use_streams=True
)
else:
self.activation_offload_context = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self.activation_offload_context:
return super().training_step(*args, **kwargs)
def ac_wrap_hf_model(model: nn.Module, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)

View File

@@ -217,6 +217,11 @@ class AxolotlTrainingMixins:
},
)
activation_offloading: bool | None = field(
default=None,
metadata={"help": "Use activation offloading with CUDA streams for training."},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -198,12 +198,22 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
self._configure_embedding_dtypes()
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
ac_wrap_hf_model,
)
# ^^ importing this at the module level breaks plugins
ac_wrap_hf_model(self.model)
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""
embeddings_len = (

View File

@@ -7,7 +7,6 @@ import importlib.util
from functools import cached_property
import addict
import torch
import transformers
from transformers import PretrainedConfig, PreTrainedModel
@@ -168,28 +167,19 @@ class PatchManager:
def _apply_gradient_checkpointing_patches(self):
"""Apply patches for gradient checkpointing."""
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
if (
self.cfg.gradient_checkpointing
and self.cfg.activation_offloading == "legacy"
):
from axolotl.monkeypatch.gradient_checkpointing import (
CheckpointFunctionWithCPUOffload,
hf_grad_checkpoint_offload_wrapper,
)
if (
self.cfg.gradient_checkpointing_kwargs
and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
):
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_offload_wrapper
)
else:
transformers.modeling_utils.checkpoint.CheckpointFunction = (
CheckpointFunctionWithCPUOffload
)
torch.utils.checkpoint.CheckpointFunction = (
CheckpointFunctionWithCPUOffload
)
if self.cfg.gradient_checkpointing == "offload_disk":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
elif (
self.cfg.gradient_checkpointing
and self.cfg.activation_offloading == "offload_disk"
):
from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
)

View File

@@ -6,7 +6,6 @@ from functools import partial
from packaging import version
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401
CheckpointFunctionWithCPUOffload,
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (

View File

@@ -14,18 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import torch
from packaging import version
from torch.utils.checkpoint import (
_get_autocast_kwargs,
_get_device_module,
_infer_device_type,
check_backward_validity,
detach_variable,
get_device_states,
set_device_states,
)
@@ -76,153 +69,3 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
) + (
None,
) * len(ctx.args)
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
"""
This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
"""
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
ctx.device_type = _infer_device_type(*args)
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
ctx.device_type
)
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
ctx.had_device_in_fwd = False
device_module = _get_device_module(ctx.device_type)
if getattr(device_module, "_initialized", False):
ctx.had_device_in_fwd = True
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
# x = None
for i, arg in enumerate(args):
if torch.is_tensor(arg):
# cpu-offload
# we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
# upstream could accept a list of arg indices to offload
if i == 0:
# print(f"{arg.shape=}")
ctx.x_device = arg.device
ctx.x_requires_grad = arg.requires_grad
t = arg.detach().cpu()
else:
t = arg
tensor_inputs.append(t)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
with torch.no_grad():
outputs = run_function(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if (
not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access
):
raise RuntimeError(
"When use_reentrant=True, torch.utils.checkpoint is incompatible"
" with .grad() or passing an `inputs` parameter to .backward()."
" To resolve this error, you can either set use_reentrant=False,"
" or call .backward() without passing the `inputs` argument."
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices):
if i == 0:
t = (
tensors[i]
.to(ctx.x_device)
.detach()
.requires_grad_(ctx.x_requires_grad)
)
else:
t = tensors[i]
inputs[idx] = t
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
rng_devices = []
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
rng_devices = ctx.fwd_devices
with torch.random.fork_rng(
devices=rng_devices,
enabled=ctx.preserve_rng_state,
device_type=ctx.device_type,
):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_device_in_fwd:
if has_device_type:
# newer pytorch (as early as 2.7)
set_device_states(
ctx.fwd_devices,
ctx.fwd_device_states,
device_type=ctx.device_type,
)
else:
# older pytorch (at least 2.4)
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))
device_autocast_ctx = (
torch.amp.autocast(
device_type=ctx.device_type, **ctx.device_autocast_kwargs
)
if torch.amp.is_autocast_available(ctx.device_type)
else contextlib.nullcontext()
)
with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# run backward() with only tensor that requires grad
outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)): # pylint: disable=consider-using-enumerate
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True, this checkpoint() is not necessary"
)
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
return (None, None) + grads

View File

@@ -841,21 +841,35 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""
def __init__(self, gc_steps=None):
self.gc_steps = gc_steps
def __init__(self, gc_steps: int | None = -1):
self.gc_steps: int = gc_steps or -1
self.next_gc_on_begin_step: int = -1
def _gc(self):
torch.cuda.empty_cache()
gc.collect()
def on_step_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if self.next_gc_on_begin_step == state.global_step:
self._gc()
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()
if control.should_evaluate:
# automatically GC before evals so the eval memory spike from the CEL doesn't OOM the trainer
self._gc()
# also GC on the start of the next step after the eval
self.next_gc_on_begin_step = state.global_step + 1
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
self._gc()
def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
torch.cuda.empty_cache()
gc.collect()
self._gc()
def colab_inference_post_train_callback(trainer: Trainer):

View File

@@ -320,7 +320,12 @@ class AxolotlInputConfig(
},
)
gc_steps: int | None = None
gc_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled)."
},
)
bf16: Literal["auto"] | bool | None = Field(
default="auto",
@@ -360,6 +365,12 @@ class AxolotlInputConfig(
"description": "Additional kwargs to pass to the trainer for gradient checkpointing"
},
)
activation_offloading: Literal["legacy", "disk"] | bool | None = Field(
default=False,
json_schema_extra={
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
},
)
unfrozen_parameters: list[str] | None = None

View File

@@ -1017,6 +1017,28 @@ class ModelCompatibilityValidationMixin:
self.gradient_checkpointing = "offload"
return self
@model_validator(mode="after")
def check_gradient_checkpointing_w_offload(self):
if self.gradient_checkpointing == "offload":
LOG.warning(
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`"
)
self.gradient_checkpointing = True
self.activation_offloading = True
if self.gradient_checkpointing == "offload_disk":
LOG.warning(
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
)
self.gradient_checkpointing = True
self.activation_offloading = "disk"
return self
@model_validator(mode="after")
def check_activation_offloading_wo_gc(self):
if self.activation_offloading and not self.gradient_checkpointing:
raise ValueError("activation_offloading requires gradient_checkpointing")
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True: