Compare commits
1 Commits
quantize-p
...
offload-ac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6100baea0d |
@@ -5,8 +5,11 @@ from functools import partial
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
from axolotl.utils.gradient_checkpointing.offload_cpu import (
|
||||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
CPU_Offloaded_Gradient_Checkpointer,
|
||||||
|
)
|
||||||
|
from axolotl.utils.gradient_checkpointing.offload_disk import (
|
||||||
|
DiskOffloadedGradientCheckpointer,
|
||||||
)
|
)
|
||||||
|
|
||||||
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
||||||
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
|
|||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
if uses_gc_layers(decoder_layer):
|
if uses_gc_layers(decoder_layer):
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||||
decoder_layer,
|
decoder_layer,
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||||
|
(
|
||||||
|
decoder_layer.func.__self__
|
||||||
|
if isinstance(decoder_layer, partial)
|
||||||
|
else decoder_layer.__self__
|
||||||
|
),
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hf_grad_checkpoint_disk_offload_wrapper(
|
||||||
|
decoder_layer, *args, use_reentrant=None
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
if uses_gc_layers(decoder_layer):
|
||||||
|
return DiskOffloadedGradientCheckpointer.apply(
|
||||||
|
decoder_layer,
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
return DiskOffloadedGradientCheckpointer.apply(
|
||||||
(
|
(
|
||||||
decoder_layer.func.__self__
|
decoder_layer.func.__self__
|
||||||
if isinstance(decoder_layer, partial)
|
if isinstance(decoder_layer, partial)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Unsloth checkpointing"""
|
"""CPU offloaded checkpointing"""
|
||||||
|
|
||||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
#
|
#
|
||||||
@@ -26,7 +26,7 @@ else:
|
|||||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||||
torch.autograd.Function
|
torch.autograd.Function
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
93
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
93
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Disk offloaded checkpointing"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||||
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
|
class DiskOffloadedGradientCheckpointer(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Saves both VRAM and RAM by offloading activations to disk.
|
||||||
|
Greater hit to performance than RAM offloading, but useful for extremely memory-constrained environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a temporary directory for storing tensors
|
||||||
|
_temp_dir = tempfile.mkdtemp(prefix="disk_checkpoint_")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_temp_file_path():
|
||||||
|
"""Generate a unique file path for tensor storage"""
|
||||||
|
return os.path.join(
|
||||||
|
DiskOffloadedGradientCheckpointer._temp_dir, f"{uuid.uuid4()}.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_cuda_amp_custom_fwd
|
||||||
|
def forward(ctx, forward_function, hidden_states, *args):
|
||||||
|
# Generate a unique file path for this tensor
|
||||||
|
file_path = DiskOffloadedGradientCheckpointer._get_temp_file_path()
|
||||||
|
|
||||||
|
# Save tensor to disk in a non-blocking way (detached from compute)
|
||||||
|
# First move to CPU, then save
|
||||||
|
cpu_hidden_states = hidden_states.detach().cpu()
|
||||||
|
torch.save(cpu_hidden_states, file_path)
|
||||||
|
|
||||||
|
# Free CPU memory
|
||||||
|
del cpu_hidden_states
|
||||||
|
|
||||||
|
# Run forward pass
|
||||||
|
with torch.no_grad():
|
||||||
|
output = forward_function(hidden_states, *args)
|
||||||
|
|
||||||
|
# Store the path instead of the tensor
|
||||||
|
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
|
||||||
|
ctx.file_path = file_path
|
||||||
|
ctx.forward_function = forward_function
|
||||||
|
ctx.args = args
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_cuda_amp_custom_bwd
|
||||||
|
def backward(ctx, dY): # pylint: disable=invalid-name
|
||||||
|
# Load the hidden states from disk
|
||||||
|
hidden_states = torch.load(ctx.file_path, weights_only=True)
|
||||||
|
|
||||||
|
# Move to CUDA and prepare for gradient computation
|
||||||
|
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||||
|
hidden_states.requires_grad = True
|
||||||
|
|
||||||
|
# Clean up the temporary file
|
||||||
|
try:
|
||||||
|
os.remove(ctx.file_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass # Ignore errors in file deletion
|
||||||
|
|
||||||
|
# Compute gradients
|
||||||
|
with torch.enable_grad():
|
||||||
|
output = ctx.forward_function(hidden_states, *ctx.args)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
torch.autograd.backward(output, dY)
|
||||||
|
|
||||||
|
return (
|
||||||
|
None,
|
||||||
|
hidden_states.grad,
|
||||||
|
) + (
|
||||||
|
None,
|
||||||
|
) * len(ctx.args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cleanup():
|
||||||
|
"""Clean up the temporary directory when done"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
try:
|
||||||
|
shutil.rmtree(
|
||||||
|
DiskOffloadedGradientCheckpointer._temp_dir
|
||||||
|
) # pylint: disable=protected-access
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
|
|||||||
is_local_main_process,
|
is_local_main_process,
|
||||||
is_main_process,
|
is_main_process,
|
||||||
)
|
)
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
from axolotl.utils.gradient_checkpointing import (
|
||||||
|
hf_grad_checkpoint_disk_offload_wrapper,
|
||||||
|
hf_grad_checkpoint_offload_wrapper,
|
||||||
|
)
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
@@ -619,6 +622,10 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||||
|
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||||
|
transformers.modeling_utils.checkpoint = (
|
||||||
|
hf_grad_checkpoint_disk_offload_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|||||||
@@ -178,9 +178,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: torch.dtype | None
|
# torch_dtype: torch.dtype | None
|
||||||
|
|
||||||
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
|
gradient_checkpointing: (
|
||||||
default=False
|
Literal["unsloth", "offload", "offload_disk"] | bool | None
|
||||||
)
|
) = Field(default=False)
|
||||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
unfrozen_parameters: list[str] | None = None
|
unfrozen_parameters: list[str] | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user