Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
5e50d1e8f0 batch flattening with xformers too 2025-05-08 18:23:25 -04:00
Wing Lian
7fb01f0461 also support xformers w/o packing 2025-05-08 15:22:48 -04:00
9 changed files with 28 additions and 147 deletions

View File

@@ -5,11 +5,8 @@ from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.offload_cpu import (
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.utils.gradient_checkpointing.offload_disk import (
DiskOffloadedGradientCheckpointer,
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -29,31 +26,12 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return CPU_Offloaded_Gradient_Checkpointer.apply(
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return CPU_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)
def hf_grad_checkpoint_disk_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return DiskOffloadedGradientCheckpointer.apply(
decoder_layer,
*args,
)
return DiskOffloadedGradientCheckpointer.apply(
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)

View File

@@ -1,93 +0,0 @@
"""Disk offloaded checkpointing"""
import os
import tempfile
import uuid
import torch
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class DiskOffloadedGradientCheckpointer(torch.autograd.Function):
"""
Saves both VRAM and RAM by offloading activations to disk.
Greater hit to performance than RAM offloading, but useful for extremely memory-constrained environments.
"""
# Create a temporary directory for storing tensors
_temp_dir = tempfile.mkdtemp(prefix="disk_checkpoint_")
@staticmethod
def _get_temp_file_path():
"""Generate a unique file path for tensor storage"""
return os.path.join(
DiskOffloadedGradientCheckpointer._temp_dir, f"{uuid.uuid4()}.pt"
)
@staticmethod
@torch_cuda_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
# Generate a unique file path for this tensor
file_path = DiskOffloadedGradientCheckpointer._get_temp_file_path()
# Save tensor to disk in a non-blocking way (detached from compute)
# First move to CPU, then save
cpu_hidden_states = hidden_states.detach().cpu()
torch.save(cpu_hidden_states, file_path)
# Free CPU memory
del cpu_hidden_states
# Run forward pass
with torch.no_grad():
output = forward_function(hidden_states, *args)
# Store the path instead of the tensor
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
ctx.file_path = file_path
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch_cuda_amp_custom_bwd
def backward(ctx, dY): # pylint: disable=invalid-name
# Load the hidden states from disk
hidden_states = torch.load(ctx.file_path, weights_only=True)
# Move to CUDA and prepare for gradient computation
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
# Clean up the temporary file
try:
os.remove(ctx.file_path)
except FileNotFoundError:
pass # Ignore errors in file deletion
# Compute gradients
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# pylint: disable=duplicate-code
torch.autograd.backward(output, dY)
return (
None,
hidden_states.grad,
) + (
None,
) * len(ctx.args)
@staticmethod
def cleanup():
"""Clean up the temporary directory when done"""
import shutil
try:
shutil.rmtree(
DiskOffloadedGradientCheckpointer._temp_dir
) # pylint: disable=protected-access
except FileNotFoundError:
pass

View File

@@ -1,4 +1,4 @@
"""CPU offloaded checkpointing"""
"""Unsloth checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
@@ -26,7 +26,7 @@ else:
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""

View File

@@ -70,10 +70,7 @@ from axolotl.utils.distributed import (
is_local_main_process,
is_main_process,
)
from axolotl.utils.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
hf_grad_checkpoint_offload_wrapper,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -559,7 +556,7 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.xformers_attention and self.cfg.sample_packing:
if self.cfg.xformers_attention:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
@@ -622,10 +619,6 @@ class ModelLoader:
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "offload_disk":
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
if self.cfg.flash_attention:
self.patch_attention()
@@ -778,13 +771,6 @@ class ModelLoader:
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
elif self.cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif self.cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,

View File

@@ -78,11 +78,15 @@ def pack_group(
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sequence_lengths):
global_idx = seq_id + group_offset
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True

View File

@@ -178,9 +178,9 @@ class AxolotlInputConfig(
# torch_dtype: torch.dtype | None
gradient_checkpointing: (
Literal["unsloth", "offload", "offload_disk"] | bool | None
) = Field(default=False)
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
unfrozen_parameters: list[str] | None = None
@@ -475,8 +475,14 @@ class AxolotlInputConfig(
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto:
raise ValueError("batch_flattening requires flash attention")
if (
not data.get("flash_attention")
and not data.get("xformers_attention")
and not batch_flattening_auto
):
raise ValueError(
"batch_flattening requires flash attention or xformers"
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:

View File

@@ -41,6 +41,7 @@ class WandbConfig(BaseModel):
use_wandb: bool | None = None
wandb_name: str | None = None
wandb_run_id: str | None = None
wandb_run_group: str | None = None
wandb_mode: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None

View File

@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)

View File

@@ -106,4 +106,3 @@ class TestBatchedSamplerPacking:
original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs)
assert len(batch_idxs) == len(set(batch_idxs))