Compare commits

...

1 Commit

Author      SHA1        Message                                           Date
Wing Lian   6100baea0d  offload activations to disk instead of CPU RAM   2025-05-11 14:19:49 -04:00
5 changed files with 132 additions and 10 deletions

View File

@@ -5,8 +5,11 @@ from functools import partial
 from packaging import version
-from axolotl.utils.gradient_checkpointing.unsloth import (
-    Unsloth_Offloaded_Gradient_Checkpointer,
+from axolotl.utils.gradient_checkpointing.offload_cpu import (
+    CPU_Offloaded_Gradient_Checkpointer,
 )
+from axolotl.utils.gradient_checkpointing.offload_disk import (
+    DiskOffloadedGradientCheckpointer,
+)
 transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     if uses_gc_layers(decoder_layer):
-        return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        return CPU_Offloaded_Gradient_Checkpointer.apply(
             decoder_layer,
             *args,
         )
-    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+    return CPU_Offloaded_Gradient_Checkpointer.apply(
         (
             decoder_layer.func.__self__
             if isinstance(decoder_layer, partial)
             else decoder_layer.__self__
         ),
         *args,
     )
+
+
+def hf_grad_checkpoint_disk_offload_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    if uses_gc_layers(decoder_layer):
+        return DiskOffloadedGradientCheckpointer.apply(
+            decoder_layer,
+            *args,
+        )
+    return DiskOffloadedGradientCheckpointer.apply(
+        (
+            decoder_layer.func.__self__
+            if isinstance(decoder_layer, partial)
+            else decoder_layer.__self__
+        ),
+        *args,
+    )

View File

@@ -1,4 +1,4 @@
-"""Unsloth checkpointing"""
+"""CPU offloaded checkpointing"""
 # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
 #
@@ -26,7 +26,7 @@ else:
     torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
-class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     torch.autograd.Function
 ):
     """

View File

@@ -0,0 +1,93 @@
+"""Disk offloaded checkpointing"""
+
+import os
+import tempfile
+import uuid
+
+import torch
+
+torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
+torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
+
+
+class DiskOffloadedGradientCheckpointer(torch.autograd.Function):
+    """
+    Saves both VRAM and RAM by offloading activations to disk.
+    Greater hit to performance than RAM offloading, but useful for extremely memory-constrained environments.
+    """
+
+    # Create a temporary directory for storing tensors
+    _temp_dir = tempfile.mkdtemp(prefix="disk_checkpoint_")
+
+    @staticmethod
+    def _get_temp_file_path():
+        """Generate a unique file path for tensor storage"""
+        return os.path.join(
+            DiskOffloadedGradientCheckpointer._temp_dir, f"{uuid.uuid4()}.pt"
+        )
+
+    @staticmethod
+    @torch_cuda_amp_custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        # Generate a unique file path for this tensor
+        file_path = DiskOffloadedGradientCheckpointer._get_temp_file_path()
+
+        # Save tensor to disk in a non-blocking way (detached from compute)
+        # First move to CPU, then save
+        cpu_hidden_states = hidden_states.detach().cpu()
+        torch.save(cpu_hidden_states, file_path)
+
+        # Free CPU memory
+        del cpu_hidden_states
+
+        # Run forward pass
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+
+        # Store the path instead of the tensor
+        ctx.save_for_backward(torch.tensor([0]))  # Dummy tensor
+        ctx.file_path = file_path
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch_cuda_amp_custom_bwd
+    def backward(ctx, dY):  # pylint: disable=invalid-name
+        # Load the hidden states from disk
+        hidden_states = torch.load(ctx.file_path, weights_only=True)
+
+        # Move to CUDA and prepare for gradient computation
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+
+        # Clean up the temporary file
+        try:
+            os.remove(ctx.file_path)
+        except FileNotFoundError:
+            pass  # Ignore errors in file deletion
+
+        # Compute gradients
+        with torch.enable_grad():
+            output = ctx.forward_function(hidden_states, *ctx.args)
+
+        # pylint: disable=duplicate-code
+        torch.autograd.backward(output, dY)
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
+
+    @staticmethod
+    def cleanup():
+        """Clean up the temporary directory when done"""
+        import shutil
+
+        try:
+            shutil.rmtree(
+                DiskOffloadedGradientCheckpointer._temp_dir
+            )  # pylint: disable=protected-access
+        except FileNotFoundError:
+            pass
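
For context, a minimal usage sketch (not part of this commit) of how the new autograd function is driven. It assumes a CUDA device, since backward() reloads the saved activations onto "cuda", and uses a plain nn.Linear as a stand-in for a decoder layer:

    import torch
    import torch.nn as nn

    from axolotl.utils.gradient_checkpointing.offload_disk import (
        DiskOffloadedGradientCheckpointer,
    )

    layer = nn.Linear(16, 16).cuda()
    hidden = torch.randn(2, 16, device="cuda", requires_grad=True)

    # Forward: the activations are detached, moved to CPU, and torch.save()d to
    # a unique file in the temporary directory; the layer runs under no_grad.
    out = DiskOffloadedGradientCheckpointer.apply(layer, hidden)

    # Backward: the activations are torch.load()ed from disk, the layer is
    # recomputed under enable_grad, and gradients flow back to `hidden`.
    out.sum().backward()
    assert hidden.grad is not None

    # Remove the temporary directory once training is finished.
    DiskOffloadedGradientCheckpointer.cleanup()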

View File

@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
     is_local_main_process,
     is_main_process,
 )
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
+from axolotl.utils.gradient_checkpointing import (
+    hf_grad_checkpoint_disk_offload_wrapper,
+    hf_grad_checkpoint_offload_wrapper,
+)
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -619,6 +622,10 @@ class ModelLoader:
         if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+        if self.cfg.gradient_checkpointing == "offload_disk":
+            transformers.modeling_utils.checkpoint = (
+                hf_grad_checkpoint_disk_offload_wrapper
+            )

         if self.cfg.flash_attention:
             self.patch_attention()

View File

@@ -178,9 +178,9 @@ class AxolotlInputConfig(
     # torch_dtype: torch.dtype | None
-    gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
-        default=False
-    )
+    gradient_checkpointing: (
+        Literal["unsloth", "offload", "offload_disk"] | bool | None
+    ) = Field(default=False)
     gradient_checkpointing_kwargs: dict[str, Any] | None = None
     unfrozen_parameters: list[str] | None = None
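
On the config side, the widened Literal means disk offloading can be requested by name. A hypothetical excerpt from an axolotl YAML config (illustrative only; other training options omitted):

    # Write checkpointed activations to a temporary directory on disk
    # instead of keeping them in CPU RAM.
    gradient_checkpointing: offload_disk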