Compare commits
2 Commits
xformers-w
...
offload-ac
| Author | SHA1 | Date |
|---|---|---|
| — | 6100baea0d | — |
| — | 47e0e71bc8 | — |
@@ -5,8 +5,11 @@ from functools import partial
|
||||
|
||||
from packaging import version
|
||||
|
||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||
from axolotl.utils.gradient_checkpointing.offload_cpu import (
|
||||
CPU_Offloaded_Gradient_Checkpointer,
|
||||
)
|
||||
from axolotl.utils.gradient_checkpointing.offload_disk import (
|
||||
DiskOffloadedGradientCheckpointer,
|
||||
)
|
||||
|
||||
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
||||
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
if uses_gc_layers(decoder_layer):
|
||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||
decoder_layer,
|
||||
*args,
|
||||
)
|
||||
|
||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||
(
|
||||
decoder_layer.func.__self__
|
||||
if isinstance(decoder_layer, partial)
|
||||
else decoder_layer.__self__
|
||||
),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def hf_grad_checkpoint_disk_offload_wrapper(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
if uses_gc_layers(decoder_layer):
|
||||
return DiskOffloadedGradientCheckpointer.apply(
|
||||
decoder_layer,
|
||||
*args,
|
||||
)
|
||||
|
||||
return DiskOffloadedGradientCheckpointer.apply(
|
||||
(
|
||||
decoder_layer.func.__self__
|
||||
if isinstance(decoder_layer, partial)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Unsloth checkpointing"""
|
||||
"""CPU offloaded checkpointing"""
|
||||
|
||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
@@ -26,7 +26,7 @@ else:
|
||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
|
||||
|
||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
torch.autograd.Function
|
||||
):
|
||||
"""
|
||||
93
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
93
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Disk offloaded checkpointing"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
|
||||
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
|
||||
|
||||
class DiskOffloadedGradientCheckpointer(torch.autograd.Function):
|
||||
"""
|
||||
Saves both VRAM and RAM by offloading activations to disk.
|
||||
Greater hit to performance than RAM offloading, but useful for extremely memory-constrained environments.
|
||||
"""
|
||||
|
||||
# Create a temporary directory for storing tensors
|
||||
_temp_dir = tempfile.mkdtemp(prefix="disk_checkpoint_")
|
||||
|
||||
@staticmethod
|
||||
def _get_temp_file_path():
|
||||
"""Generate a unique file path for tensor storage"""
|
||||
return os.path.join(
|
||||
DiskOffloadedGradientCheckpointer._temp_dir, f"{uuid.uuid4()}.pt"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@torch_cuda_amp_custom_fwd
|
||||
def forward(ctx, forward_function, hidden_states, *args):
|
||||
# Generate a unique file path for this tensor
|
||||
file_path = DiskOffloadedGradientCheckpointer._get_temp_file_path()
|
||||
|
||||
# Save tensor to disk in a non-blocking way (detached from compute)
|
||||
# First move to CPU, then save
|
||||
cpu_hidden_states = hidden_states.detach().cpu()
|
||||
torch.save(cpu_hidden_states, file_path)
|
||||
|
||||
# Free CPU memory
|
||||
del cpu_hidden_states
|
||||
|
||||
# Run forward pass
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
|
||||
# Store the path instead of the tensor
|
||||
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
|
||||
ctx.file_path = file_path
|
||||
ctx.forward_function = forward_function
|
||||
ctx.args = args
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch_cuda_amp_custom_bwd
|
||||
def backward(ctx, dY): # pylint: disable=invalid-name
|
||||
# Load the hidden states from disk
|
||||
hidden_states = torch.load(ctx.file_path, weights_only=True)
|
||||
|
||||
# Move to CUDA and prepare for gradient computation
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
hidden_states.requires_grad = True
|
||||
|
||||
# Clean up the temporary file
|
||||
try:
|
||||
os.remove(ctx.file_path)
|
||||
except FileNotFoundError:
|
||||
pass # Ignore errors in file deletion
|
||||
|
||||
# Compute gradients
|
||||
with torch.enable_grad():
|
||||
output = ctx.forward_function(hidden_states, *ctx.args)
|
||||
# pylint: disable=duplicate-code
|
||||
torch.autograd.backward(output, dY)
|
||||
|
||||
return (
|
||||
None,
|
||||
hidden_states.grad,
|
||||
) + (
|
||||
None,
|
||||
) * len(ctx.args)
|
||||
|
||||
@staticmethod
|
||||
def cleanup():
|
||||
"""Clean up the temporary directory when done"""
|
||||
import shutil
|
||||
|
||||
try:
|
||||
shutil.rmtree(
|
||||
DiskOffloadedGradientCheckpointer._temp_dir
|
||||
) # pylint: disable=protected-access
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
|
||||
is_local_main_process,
|
||||
is_main_process,
|
||||
)
|
||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||
from axolotl.utils.gradient_checkpointing import (
|
||||
hf_grad_checkpoint_disk_offload_wrapper,
|
||||
hf_grad_checkpoint_offload_wrapper,
|
||||
)
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
|
||||
@@ -619,6 +622,10 @@ class ModelLoader:
|
||||
|
||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||
transformers.modeling_utils.checkpoint = (
|
||||
hf_grad_checkpoint_disk_offload_wrapper
|
||||
)
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
self.patch_attention()
|
||||
|
||||
@@ -78,15 +78,11 @@ def pack_group(
|
||||
Returns:
|
||||
List of bins, where each bin contains indices of sequences assigned to it
|
||||
"""
|
||||
# Get sorting indices and sort lengths in descending order
|
||||
indices = np.argsort(sequence_lengths)[::-1]
|
||||
sorted_lengths = sequence_lengths[indices]
|
||||
|
||||
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
|
||||
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
|
||||
|
||||
for seq_id, size in enumerate(sorted_lengths):
|
||||
global_idx = indices[seq_id] + group_offset
|
||||
for seq_id, size in enumerate(sequence_lengths):
|
||||
global_idx = seq_id + group_offset
|
||||
|
||||
# Try to place sequence in existing bins
|
||||
add_new_bin = True
|
||||
|
||||
@@ -178,9 +178,9 @@ class AxolotlInputConfig(
|
||||
|
||||
# torch_dtype: torch.dtype | None
|
||||
|
||||
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
|
||||
default=False
|
||||
)
|
||||
gradient_checkpointing: (
|
||||
Literal["unsloth", "offload", "offload_disk"] | bool | None
|
||||
) = Field(default=False)
|
||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||
|
||||
unfrozen_parameters: list[str] | None = None
|
||||
|
||||
@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -106,3 +106,4 @@ class TestBatchedSamplerPacking:
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
assert original_idxs == set(batch_idxs)
|
||||
assert len(batch_idxs) == len(set(batch_idxs))
|
||||
|
||||
Reference in New Issue
Block a user