Compare commits

1 commit

...offload-ac

| Author | SHA1 | Date |
|---|---|---|
| | 6100baea0d | |
```diff
@@ -5,8 +5,11 @@ from functools import partial

 from packaging import version

-from axolotl.utils.gradient_checkpointing.unsloth import (
-    Unsloth_Offloaded_Gradient_Checkpointer,
+from axolotl.utils.gradient_checkpointing.offload_cpu import (
+    CPU_Offloaded_Gradient_Checkpointer,
 )
+from axolotl.utils.gradient_checkpointing.offload_disk import (
+    DiskOffloadedGradientCheckpointer,
+)

 transformers_version = version.parse(importlib.metadata.version("transformers"))
```
```diff
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     if uses_gc_layers(decoder_layer):
-        return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        return CPU_Offloaded_Gradient_Checkpointer.apply(
             decoder_layer,
             *args,
         )

-    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+    return CPU_Offloaded_Gradient_Checkpointer.apply(
         (
             decoder_layer.func.__self__
             if isinstance(decoder_layer, partial)
             else decoder_layer.__self__
         ),
         *args,
     )
+
+
+def hf_grad_checkpoint_disk_offload_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    if uses_gc_layers(decoder_layer):
+        return DiskOffloadedGradientCheckpointer.apply(
+            decoder_layer,
+            *args,
+        )
+
+    return DiskOffloadedGradientCheckpointer.apply(
+        (
+            decoder_layer.func.__self__
+            if isinstance(decoder_layer, partial)
+            else decoder_layer.__self__
+        ),
+        *args,
+    )
```
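The `decoder_layer.func.__self__` branch in both wrappers handles the case where the checkpoint target arrives as a `functools.partial` over a module's bound `__call__` rather than as the bound method itself. A minimal, hypothetical sketch of that unwrapping (the `Linear` module is a stand-in for a real decoder layer):

```python
from functools import partial

import torch

module = torch.nn.Linear(4, 4)

# The target may be the bound method itself or a partial wrapping it.
wrapped = partial(module.__call__)

# Recover the owning module in either case, mirroring the wrappers above.
target = wrapped.func.__self__ if isinstance(wrapped, partial) else wrapped.__self__
assert target is module
```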
```diff
@@ -1,4 +1,4 @@
-"""Unsloth checkpointing"""
+"""CPU offloaded checkpointing"""

 # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
 #
@@ -26,7 +26,7 @@ else:
     torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")


-class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     torch.autograd.Function
 ):
     """
```
src/axolotl/utils/gradient_checkpointing/offload_disk.py (new file, 93 lines)

```diff
@@ -0,0 +1,93 @@
+"""Disk offloaded checkpointing"""
+
+import os
+import tempfile
+import uuid
+
+import torch
+
+torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
+torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
+
+
+class DiskOffloadedGradientCheckpointer(torch.autograd.Function):
+    """
+    Saves both VRAM and RAM by offloading activations to disk.
+    Greater hit to performance than RAM offloading, but useful for extremely memory-constrained environments.
+    """
+
+    # Create a temporary directory for storing tensors
+    _temp_dir = tempfile.mkdtemp(prefix="disk_checkpoint_")
+
+    @staticmethod
+    def _get_temp_file_path():
+        """Generate a unique file path for tensor storage"""
+        return os.path.join(
+            DiskOffloadedGradientCheckpointer._temp_dir, f"{uuid.uuid4()}.pt"
+        )
+
+    @staticmethod
+    @torch_cuda_amp_custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        # Generate a unique file path for this tensor
+        file_path = DiskOffloadedGradientCheckpointer._get_temp_file_path()
+
+        # Save the tensor to disk, detached from the compute graph:
+        # first move to CPU, then serialize
+        cpu_hidden_states = hidden_states.detach().cpu()
+        torch.save(cpu_hidden_states, file_path)
+
+        # Free CPU memory
+        del cpu_hidden_states
+
+        # Run forward pass
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+
+        # Store the path instead of the tensor
+        ctx.save_for_backward(torch.tensor([0]))  # Dummy tensor
+        ctx.file_path = file_path
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch_cuda_amp_custom_bwd
+    def backward(ctx, dY):  # pylint: disable=invalid-name
+        # Load the hidden states from disk
+        hidden_states = torch.load(ctx.file_path, weights_only=True)
+
+        # Move to CUDA and prepare for gradient computation
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+
+        # Clean up the temporary file
+        try:
+            os.remove(ctx.file_path)
+        except FileNotFoundError:
+            pass  # Ignore errors in file deletion
+
+        # Compute gradients by replaying the forward pass
+        with torch.enable_grad():
+            output = ctx.forward_function(hidden_states, *ctx.args)
+            # pylint: disable=duplicate-code
+            torch.autograd.backward(output, dY)
+
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
+
+    @staticmethod
+    def cleanup():
+        """Clean up the temporary directory when done"""
+        import shutil
+
+        try:
+            shutil.rmtree(
+                DiskOffloadedGradientCheckpointer._temp_dir
+            )  # pylint: disable=protected-access
+        except FileNotFoundError:
+            pass
```
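Taken in isolation, the new autograd Function can be exercised directly. A minimal, hypothetical smoke test (the `Linear` layer is a stand-in for a real decoder layer; a CUDA device is assumed, since `backward()` reloads activations onto `"cuda"`):

```python
import torch

from axolotl.utils.gradient_checkpointing.offload_disk import (
    DiskOffloadedGradientCheckpointer,
)

layer = torch.nn.Linear(16, 16).cuda()
hidden_states = torch.randn(2, 16, device="cuda", requires_grad=True)

# forward() spills hidden_states to a temp file and runs the layer under
# no_grad; backward() reloads them from disk and replays the forward pass.
output = DiskOffloadedGradientCheckpointer.apply(layer, hidden_states)
output.sum().backward()
assert hidden_states.grad is not None

# Remove the spill directory once training is finished.
DiskOffloadedGradientCheckpointer.cleanup()
```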
```diff
@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
     is_local_main_process,
     is_main_process,
 )
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
+from axolotl.utils.gradient_checkpointing import (
+    hf_grad_checkpoint_disk_offload_wrapper,
+    hf_grad_checkpoint_offload_wrapper,
+)
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant

```
```diff
@@ -619,6 +622,10 @@ class ModelLoader:

         if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+        if self.cfg.gradient_checkpointing == "offload_disk":
+            transformers.modeling_utils.checkpoint = (
+                hf_grad_checkpoint_disk_offload_wrapper
+            )

         if self.cfg.flash_attention:
             self.patch_attention()
```
```diff
@@ -178,9 +178,9 @@ class AxolotlInputConfig(

     # torch_dtype: torch.dtype | None

-    gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
-        default=False
-    )
+    gradient_checkpointing: (
+        Literal["unsloth", "offload", "offload_disk"] | bool | None
+    ) = Field(default=False)
     gradient_checkpointing_kwargs: dict[str, Any] | None = None

     unfrozen_parameters: list[str] | None = None
```
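With the widened `Literal`, the new mode can be selected from user config; in an axolotl YAML file this would read `gradient_checkpointing: offload_disk`. A minimal, hypothetical sketch of the schema change in isolation (assumes pydantic v2; `GcConfigSketch` is a stand-in for the relevant slice of `AxolotlInputConfig`):

```python
from typing import Any, Literal

from pydantic import BaseModel, Field


class GcConfigSketch(BaseModel):
    """Stand-in for the gradient-checkpointing slice of AxolotlInputConfig."""

    gradient_checkpointing: (
        Literal["unsloth", "offload", "offload_disk"] | bool | None
    ) = Field(default=False)
    gradient_checkpointing_kwargs: dict[str, Any] | None = None


# "offload_disk" now validates; before this change it raised a ValidationError.
cfg = GcConfigSketch(gradient_checkpointing="offload_disk")
assert cfg.gradient_checkpointing == "offload_disk"
assert GcConfigSketch().gradient_checkpointing is False
```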