From 6100baea0dac09c5c02e0d070f4824c6847244db Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 11 May 2025 14:19:49 -0400
Subject: [PATCH] offload activations to disk instead of CPU RAM

---
 .../utils/gradient_checkpointing/__init__.py  | 30 +++++-
 .../{unsloth.py => offload_cpu.py}            |  4 +-
 .../gradient_checkpointing/offload_disk.py    | 93 +++++++++++++++++++
 src/axolotl/utils/models.py                   |  9 +-
 src/axolotl/utils/schemas/config.py           |  6 +-
 5 files changed, 132 insertions(+), 10 deletions(-)
 rename src/axolotl/utils/gradient_checkpointing/{unsloth.py => offload_cpu.py} (95%)
 create mode 100644 src/axolotl/utils/gradient_checkpointing/offload_disk.py

diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py
index f84f76d80..93c2885f3 100644
--- a/src/axolotl/utils/gradient_checkpointing/__init__.py
+++ b/src/axolotl/utils/gradient_checkpointing/__init__.py
@@ -5,8 +5,11 @@ from functools import partial
 
 from packaging import version
 
-from axolotl.utils.gradient_checkpointing.unsloth import (
-    Unsloth_Offloaded_Gradient_Checkpointer,
+from axolotl.utils.gradient_checkpointing.offload_cpu import (
+    CPU_Offloaded_Gradient_Checkpointer,
+)
+from axolotl.utils.gradient_checkpointing.offload_disk import (
+    DiskOffloadedGradientCheckpointer,
 )
 
 transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     if uses_gc_layers(decoder_layer):
-        return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        return CPU_Offloaded_Gradient_Checkpointer.apply(
             decoder_layer,
             *args,
         )
 
-    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+    return CPU_Offloaded_Gradient_Checkpointer.apply(
+        (
+            decoder_layer.func.__self__
+            if isinstance(decoder_layer, partial)
+            else decoder_layer.__self__
+        ),
+        *args,
+    )
+
+
+def hf_grad_checkpoint_disk_offload_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    if uses_gc_layers(decoder_layer):
+        return DiskOffloadedGradientCheckpointer.apply(
+            decoder_layer,
+            *args,
+        )
+
+    return DiskOffloadedGradientCheckpointer.apply(
         (
             decoder_layer.func.__self__
             if isinstance(decoder_layer, partial)
             else decoder_layer.__self__
         ),
         *args,
     )
diff --git a/src/axolotl/utils/gradient_checkpointing/unsloth.py b/src/axolotl/utils/gradient_checkpointing/offload_cpu.py
similarity index 95%
rename from src/axolotl/utils/gradient_checkpointing/unsloth.py
rename to src/axolotl/utils/gradient_checkpointing/offload_cpu.py
index 7a14614b1..bbb5ad40d 100644
--- a/src/axolotl/utils/gradient_checkpointing/unsloth.py
+++ b/src/axolotl/utils/gradient_checkpointing/offload_cpu.py
@@ -1,4 +1,4 @@
-"""Unsloth checkpointing"""
+"""CPU offloaded checkpointing"""
 
 # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
 #
@@ -26,7 +26,7 @@ else:
     torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
 
 
-class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     torch.autograd.Function
 ):
     """
diff --git a/src/axolotl/utils/gradient_checkpointing/offload_disk.py b/src/axolotl/utils/gradient_checkpointing/offload_disk.py
new file mode 100644
index 000000000..00d7d8bb0
--- /dev/null
+++ b/src/axolotl/utils/gradient_checkpointing/offload_disk.py
@@ -0,0 +1,93 @@
+"""Disk offloaded checkpointing"""
+
+import os
+import tempfile
+import uuid
+
+import torch
+
+torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
+torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
+
+
+class DiskOffloadedGradientCheckpointer(torch.autograd.Function):
+    """
+    Saves both VRAM and RAM by offloading activations to disk.
+    Greater hit to performance than RAM offloading, but useful for extremely memory-constrained environments.
+    """
+
+    # Create a temporary directory for storing tensors
+    _temp_dir = tempfile.mkdtemp(prefix="disk_checkpoint_")
+
+    @staticmethod
+    def _get_temp_file_path():
+        """Generate a unique file path for tensor storage"""
+        return os.path.join(
+            DiskOffloadedGradientCheckpointer._temp_dir, f"{uuid.uuid4()}.pt"
+        )
+
+    @staticmethod
+    @torch_cuda_amp_custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        # Generate a unique file path for this tensor
+        file_path = DiskOffloadedGradientCheckpointer._get_temp_file_path()
+
+        # Detach the activation, copy it to CPU, then write it to disk
+        # (both the device-to-host copy and torch.save are blocking)
+        cpu_hidden_states = hidden_states.detach().cpu()
+        torch.save(cpu_hidden_states, file_path)
+
+        # Free CPU memory
+        del cpu_hidden_states
+
+        # Run forward pass
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+
+        # Store the path instead of the tensor
+        ctx.save_for_backward(torch.tensor([0]))  # Dummy tensor
+        ctx.file_path = file_path
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch_cuda_amp_custom_bwd
+    def backward(ctx, dY):  # pylint: disable=invalid-name
+        # Load the hidden states from disk
+        hidden_states = torch.load(ctx.file_path, weights_only=True)
+
+        # Move to CUDA and prepare for gradient computation
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+
+        # Clean up the temporary file
+        try:
+            os.remove(ctx.file_path)
+        except FileNotFoundError:
+            pass  # File was already removed
+
+        # Compute gradients
+        with torch.enable_grad():
+            output = ctx.forward_function(hidden_states, *ctx.args)
+            # pylint: disable=duplicate-code
+            torch.autograd.backward(output, dY)
+
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
+
+    @staticmethod
+    def cleanup():
+        """Clean up the temporary directory when done"""
+        import shutil
+
+        try:
+            shutil.rmtree(
+                DiskOffloadedGradientCheckpointer._temp_dir
+            )  # pylint: disable=protected-access
+        except FileNotFoundError:
+            pass
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6aa4dd162..89533e121 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
     is_local_main_process,
     is_main_process,
 )
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
+from axolotl.utils.gradient_checkpointing import (
+    hf_grad_checkpoint_disk_offload_wrapper,
+    hf_grad_checkpoint_offload_wrapper,
+)
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -619,6 +622,10 @@ class ModelLoader:
 
         if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+        if self.cfg.gradient_checkpointing == "offload_disk":
+            transformers.modeling_utils.checkpoint = (
+                hf_grad_checkpoint_disk_offload_wrapper
+            )
 
         if self.cfg.flash_attention:
             self.patch_attention()
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 9db374409..03ff06fd4 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -178,9 +178,9 @@ class AxolotlInputConfig(
     # torch_dtype: torch.dtype | None
-    gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
-        default=False
-    )
+    gradient_checkpointing: (
+        Literal["unsloth", "offload", "offload_disk"] | bool | None
+    ) = Field(default=False)
     gradient_checkpointing_kwargs: dict[str, Any] | None = None
     unfrozen_parameters: list[str] | None = None
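
Usage note (not part of the diff above): with this patch applied, setting
gradient_checkpointing: offload_disk in the axolotl config makes ModelLoader
point transformers.modeling_utils.checkpoint at
hf_grad_checkpoint_disk_offload_wrapper. The following is a minimal sketch of
how the new autograd Function could be exercised in isolation; it assumes a
CUDA device (backward() reloads the saved activation onto "cuda"), and the toy
nn.Linear layer and tensor shapes are purely illustrative.

import torch

from axolotl.utils.gradient_checkpointing.offload_disk import (
    DiskOffloadedGradientCheckpointer,
)

layer = torch.nn.Linear(16, 16).cuda()
hidden = torch.randn(2, 16, device="cuda", requires_grad=True)

# Forward: the input activation is torch.save()d into the class's temp dir and
# the layer runs under torch.no_grad(), so no activation graph is kept in VRAM.
out = DiskOffloadedGradientCheckpointer.apply(layer, hidden)

# Backward: the activation is torch.load()ed from disk, the layer is re-run
# under torch.enable_grad(), and gradients flow back to `hidden` and the layer.
out.sum().backward()
print(hidden.grad.shape)

# Remove the temporary checkpoint directory once training is finished.
DiskOffloadedGradientCheckpointer.cleanup()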