Compare commits

..

2 Commits

Author SHA1 Message Date
NanoCode012
22684ec98f feat: add draft wizard cli 2025-05-14 15:38:14 +07:00
NanoCode012
6db60ac520 fix: add missing config to schema 2025-05-14 15:38:14 +07:00
11 changed files with 456 additions and 587 deletions

View File

@@ -139,8 +139,7 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.offload_cpu
- utils.gradient_checkpointing.offload_disk
- utils.gradient_checkpointing.unsloth
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:

View File

@@ -539,7 +539,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing

View File

@@ -342,6 +342,13 @@ def delinearize_llama4(model: str, output: str) -> None:
do_delinearize_llama4(model, output)
@cli.command()
def wizard():
from axolotl.cli.wizard import do_wizard
do_wizard()
cli.add_command(lm_eval)

429
src/axolotl/cli/wizard.py Normal file
View File

@@ -0,0 +1,429 @@
"""Wizard for creating yaml configs."""
import click
import torch
import yaml
from packaging import version
from transformers.training_args import OptimizerNames
from axolotl.cli.art import print_axolotl_text_art
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
def do_wizard():
print_axolotl_text_art()
# Ask where to save the config
cfg = DictDefault({})
config_path = click.prompt(
"Where do you want to save the config?", type=str, default="config.yaml"
)
# Ask base model
base_model = click.prompt("What base model do you want to use?", type=str)
cfg["base_model"] = base_model.strip()
# Ask whether want to enable Vision model
# TODO: check if model has vision layers instead of asking user
train_vision_model = click.confirm(
"If this model has vision layers, do you want to train them?", default=False
)
if train_vision_model:
cfg["processor_type"] = "AutoProcessor"
cfg["skip_prepare_dataset"] = True
cfg["remove_unused_columns"] = False
cfg["sample_packing"] = False
# Ask whether they want to set any advanced model features (custom tokenizer, custom config, etc)
advanced_model_features = click.confirm(
"Do you want to set any advanced model features? (custom tokenizer, custom config, remote code etc)",
default=False,
)
if advanced_model_features:
# Ask whether they want to use a custom config
base_model_config = click.prompt(
"What model config do you want to use? (leave blank for default)",
type=str,
default="",
)
if base_model_config:
cfg["base_model_config"] = base_model_config
# Ask whether they want to use a specific revision of the model
revision_of_model = click.prompt(
"What revision of the model do you want to use? (leave blank for default)",
type=str,
default="",
)
if revision_of_model:
cfg["revision_of_model"] = revision_of_model
# Ask whether they want to use a custom tokenizer
tokenizer_config = click.prompt(
"What tokenizer do you want to use? (leave blank for default)",
type=str,
default="",
)
if tokenizer_config:
cfg["tokenizer_config"] = tokenizer_config
# Ask whether they want to use remote code
trust_remote_code = click.confirm(
"Do you want to use remote code?", default=False
)
if trust_remote_code:
cfg["trust_remote_code"] = trust_remote_code
# Whether to resize token embeddings
resize_token_embeddings_to_32x = click.confirm(
"Do you want to resize token embeddings to 32x?", default=False
)
if resize_token_embeddings_to_32x:
cfg["resize_token_embeddings_to_32x"] = resize_token_embeddings_to_32x
# Whether to shrink embeddings to len(tokenizer)
shrink_embeddings = click.confirm(
"Do you want to shrink embeddings to len(tokenizer)?", default=False
)
if shrink_embeddings:
cfg["shrink_embeddings"] = shrink_embeddings
# Whether to skip upcast embeddings
embeddings_skip_upcast = click.confirm(
"Do you want to skip upcast embeddings?", default=False
)
if embeddings_skip_upcast:
cfg["embeddings_skip_upcast"] = embeddings_skip_upcast
# Whether to random init weights
random_init_weights = click.confirm(
"Do you want to random init weights?", default=False
)
if random_init_weights:
cfg["random_init_weights"] = random_init_weights
# Get model type
config = load_model_config(cfg)
model_type = config.model_type
# Ask sequence length
sequence_length = click.prompt("What sequence length do you want to use?", type=int)
cfg["sequence_length"] = sequence_length
# Whether to turn on sample packing
if cfg["sample_packing"] is None:
cfg["sample_packing"] = click.confirm(
"Do you want to turn on sample packing? This will speed up training by packing multiple samples into a single batch.",
default=True,
)
if cfg["sample_packing"]:
cfg["pad_to_sequence_len"] = True
# Whether to turn off eval sample packing
no_eval_sample_packing = click.confirm(
"Do you want to turn off eval sample packing? This will slow down evaluation but is recommended if you are using a small validation set.",
default=False,
)
if no_eval_sample_packing:
cfg["eval_sample_packing"] = False
# Hardware check
try:
is_ampere_or_newer = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere_or_newer = False
except AssertionError: # this is raised if no cuda is available
is_ampere_or_newer = False
# Get num gpus
try:
num_gpus = torch.cuda.device_count()
except RuntimeError:
num_gpus = 0
# Get torch version
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
is_torch_2_6_or_newer = version.parse(torch_version) >= version.parse("2.6.0")
# Whether to turn on attention
opt = ["xformers", "sdp"]
if is_ampere_or_newer:
opt.append("flash")
if is_torch_2_6_or_newer:
opt.append("flex")
if cfg["sample_packing"]:
if "flash" in opt:
default_opt = "flash"
elif "flex" in opt:
default_opt = "flex"
else:
default_opt = opt[0]
attention = click.prompt(
"Which attention backend do you want to use? Sample packing requires an attention backend to be set.",
type=click.Choice(opt),
default=default_opt,
)
else:
# non-sample packing supports no attention and S2
opt.extend(["none", "s2"])
attention = click.prompt(
"Which attention backend do you want to use?",
type=click.Choice(opt),
default="none",
)
if attention == "none":
attention = None
# TODO: if xformers, check if FA is installed
# TODO: flex doc mentioned requiring seq len to be divisible by 128. Unclear if limitation still exists
# TODO: requires #2489
cfg["attention"] = attention
# Whether to turn on gradient checkpointing
# TODO: need to wait for offload_disk PR to be merged
gradient_checkpointing = click.prompt(
"Which gradient checkpointing strategy do you want to use?",
type=click.Choice(["none", "true", "offload", "offload_disk"]),
default="true",
)
if gradient_checkpointing == "none":
gradient_checkpointing = False
elif gradient_checkpointing == "true":
gradient_checkpointing = True
# Ask whether to set use_reentrant
# TODO: get correct defaults based on SFT/RL mode and single/multigpu
# use_reentrant = click.confirm(
# "Do you want to set use_reentrant?",
# default=True,
# )
# if use_reentrant:
# cfg["use_reentrant"] = use_reentrant
# Optimizer
cfg["optimizer"] = click.prompt(
"Which optimizer do you want to use?",
type=click.Choice((OptimizerNames | CustomSupportedOptimizers)),
default=OptimizerNames.ADAMW_TORCH_FUSED,
)
cfg["lr_scheduler"] = click.prompt(
"Which learning rate scheduler do you want to use?",
type=click.Choice(
[
"cosine",
"one_cycle",
"rex",
"log_sweep",
"linear",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
"inverse_sqrt",
"reduce_lr_on_plateau",
"cosine_with_min_lr",
"warmup_stable_decay",
]
),
default="cosine",
)
# Plugins
cfg["plugins"] = []
# Whether to turn on cut cross entropy
if is_ampere_or_newer:
# Note: This may error if users don't have CCE installed
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
CUT_CROSS_ENTROPY_MODEL_MAPPING,
)
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
cut_cross_entropy = click.confirm(
"Do you want to turn on cut cross entropy? This will save VRAM if the model has a large vocab size.",
default=True,
)
if cut_cross_entropy:
cfg["plugins"].append(
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
)
cfg["cut_cross_entropy"] = True
use_liger_kernel = click.confirm(
"Do you want to use the liger kernel? This will speed up training and save VRAM.",
default=True,
)
if use_liger_kernel:
cfg["plugins"].append("axolotl.integrations.liger.LigerPlugin")
cfg["liger_rope"] = click.confirm(
"Do you want to enable liger rope?",
default=True,
)
cfg["liger_rms_norm"] = click.confirm(
"Do you want to enable liger rms norm?",
default=True,
)
cfg["liger_glu_activation"] = click.confirm(
"Do you want to enable liger glu activation?",
default=True,
)
cfg["liger_layer_norm"] = click.confirm(
"Do you want to enable liger layer norm?",
default=True,
)
if cfg["cut_cross_entropy"] is not True:
cfg["liger_fused_linear_cross_entropy"] = click.confirm(
"Do you want to enable liger fused linear cross entropy?",
default=True,
)
# TODO: lora kernels (but they auto enable via validator already)
# TODO: is there incompat between torch compile and liger?
cfg["torch_compile"] = click.confirm(
"Do you want to enable torch compile?",
default=True,
)
# Multi-gpu
if num_gpus > 1:
# Ask whether to use DDP/Deepspeed/FSDP
multi_gpu_mode = click.prompt(
"Which multi-gpu mode do you want to use?",
type=click.Choice(["ddp", "deepspeed", "fsdp"]),
default="ddp",
)
if multi_gpu_mode == "deepspeed":
# Ask which deepspeed config to use
cfg["deepspeed"] = click.prompt(
"Which deepspeed config do you want to use? The higher the number, the more VRAM you will save, but the slower it will run.",
type=click.Choice(
[
"zero1.json",
"zero1_torch_compile.json",
"zero2.json",
"zero3.json",
"zero3_bf16.json",
"zero3_bf16_cpuoffload_all.json",
"zero3_bf16_cpuoffload_params.json",
]
),
default="zero1.json",
)
elif multi_gpu_mode == "fsdp":
fsdp_version = click.prompt(
"Which fsdp version do you want to use?",
type=click.Choice([1, 2]),
default=2,
)
# TODO: Handle FSDP config
if fsdp_version == 1:
cfg["fsdp"] = ["full_shard", "auto_wrap"]
# Ask which state dict type to use
fsdp_state_dict_type = click.prompt(
"Which fsdp state dict type do you want to use?",
type=click.Choice(["FULL_STATE_DICT", "SHARDED_STATE_DICT"]),
default="FULL_STATE_DICT",
)
fsdp_offload_params = click.confirm(
"Do you want to offload parameters?",
default=True,
)
# TODO: can we load the model class and auto pull a default for this?
fsdp_transformer_layer_cls_to_wrap = click.prompt(
"Which transformer layer class to wrap? It is usually the Decoder layer class.",
type=str,
)
# TODO: add other options
cfg["fsdp_config"] = {
"state_dict_type": fsdp_state_dict_type,
"offload_params": fsdp_offload_params,
"transformer_layer_cls_to_wrap": fsdp_transformer_layer_cls_to_wrap,
}
elif fsdp_version == 2:
raise NotImplementedError()
# Training mode (sft or rl)
training_mode = click.prompt(
"Which training mode do you want to use?",
type=click.Choice(["sft", "rl"]),
default="sft",
)
if training_mode == "rl":
cfg["rl"] = click.prompt(
"Which rl mode do you want to use?",
type=click.Choice(["dpo", "ipo", "orpo", "kto", "grpo", "simpo"]),
)
# TODO: handle RL options
# Whether to use adapter
# Get batch/grad accu
# Get learning rate
# Get weight decay
# Get max grad norm
# Get num train epochs
# Get warmup ratio
# Get save ratio
# Get eval ratio
# Get dataset config
# Load metric tracker
# Save config to yaml
# TODO: improve output yaml formatting. Need to add comments to help separate sections
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(cfg.to_dict(), f, sort_keys=False)

View File

@@ -289,18 +289,16 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
else:
if cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
trainer.accelerator.wait_for_everyone()
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin

View File

@@ -5,11 +5,8 @@ from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.offload_cpu import (
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.utils.gradient_checkpointing.offload_disk import (
Disco,
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -29,31 +26,12 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return CPU_Offloaded_Gradient_Checkpointer.apply(
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return CPU_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)
def hf_grad_checkpoint_disk_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Disco.apply(
decoder_layer,
*args,
)
return Disco.apply(
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)

View File

@@ -1,531 +0,0 @@
"""
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
"""
# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
import tempfile
import threading
import time
import uuid
from collections import deque
from concurrent.futures import Future
from typing import Dict
import torch
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
# Setup logger
logger = logging.getLogger(__name__)
class DiskOffloadManager:
"""
Manages offloaded tensors and handles prefetching in a separate thread.
Includes synchronization to prevent race conditions.
"""
def __init__(
self,
prefetch_size: int = 3,
prefetch_to_gpu: bool = True,
save_workers: int = 4,
):
"""
Args:
prefetch_size: Maximum number of tensors to prefetch in the background.
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
save_workers: Maximum number of concurrent save operations.
"""
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
# Track tensor paths and their status
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
self.file_locks: Dict[str, threading.Lock] = (
{}
) # Maps file_path -> threading.Lock()
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
self.file_status: Dict[str, str] = {}
self.max_prefetch = prefetch_size
self.prefetch_to_gpu = prefetch_to_gpu
# Thread synchronization
self.manager_lock = threading.RLock() # Used for thread-safe operations
# Prefetch queue and cache
self.prefetch_queue: queue.Queue = queue.Queue()
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
# Save queue and thread pool
self.save_queue: queue.Queue = queue.Queue()
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
self.save_futures: Dict[str, Future] = {}
self.save_semaphore = threading.Semaphore(
save_workers * 2
) # Limit concurrent save operations
# Start prefetch worker thread
self.stop_event = threading.Event()
# start multiple threads for prefetching
self.prefetch_worker_count = 2
self.prefetch_workers = []
for _ in range(self.prefetch_worker_count):
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
worker.start()
self.prefetch_workers.append(worker)
# Start save worker thread
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
self.save_worker.start()
self.idx = 0
atexit.register(self.cleanup)
def _save_worker(self):
"""Background thread that processes the save queue"""
while not self.stop_event.is_set():
try:
save_item = self.save_queue.get(timeout=0.5)
if save_item is None:
continue
tensor, file_path = save_item
# Submit the save task to the thread pool
future = self.save_pool.submit(
self._save_tensor_to_disk, tensor, file_path
)
with self.manager_lock:
self.save_futures[file_path] = future
self.save_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
"""Actually save the tensor to disk"""
try:
# Save tensor to disk
cpu_tensor = tensor.detach().cpu()
torch.save(cpu_tensor, file_path)
del cpu_tensor
with self.manager_lock:
# Mark file as ready
self.file_status[file_path] = "ready"
# Release semaphore
self.save_semaphore.release()
return True
except FileNotFoundError as e:
logger.error(f"Error saving tensor to {file_path}: {e}")
with self.manager_lock:
self.file_status[file_path] = "error"
# Release semaphore
self.save_semaphore.release()
return False
def _prefetch_worker(self):
"""Background thread that loads tensors from disk ahead of time"""
while not self.stop_event.is_set():
try:
file_path = self.prefetch_queue.get(timeout=0.5)
if file_path is None:
continue
# Check if file is available and not already in cache
with self.manager_lock:
if (
file_path not in self.file_status
or self.file_status[file_path] == "deleted"
):
self.prefetch_queue.task_done()
if file_path in self.prefetch_cache:
self.prefetch_queue.task_done()
continue
# If file is still being saved, wait for it
if (
self.file_status[file_path] == "saving"
and file_path in self.save_futures
):
# Re-queue this prefetch request with a little delay
self.prefetch_queue.task_done()
time.sleep(0.1)
self.prefetch_queue.put(file_path)
continue
# Mark file as being prefetched
self.file_status[file_path] = "prefetching"
# Load tensor from disk and store in cache
try:
if os.path.exists(file_path):
if self.prefetch_to_gpu:
tensor = torch.load(
file_path,
map_location=torch.device("cuda"),
weights_only=True,
)
else:
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.prefetch_cache[file_path] = tensor
self.file_status[file_path] = "ready"
else:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(
f"Prefetch error: File not found {file_path}"
)
self.file_status[file_path] = "missing"
except FileNotFoundError as e:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(f"Prefetch error for {file_path}: {e}")
self.file_status[file_path] = "error"
self.prefetch_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def save_tensor(self, tensor: torch.Tensor):
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
# Generate unique file path
self.idx += 1
file_path: str = os.path.join(
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
)
with self.manager_lock:
# Mark file as being saved
self.file_locks[file_path] = threading.Lock()
self.file_status[file_path] = "saving"
# Add to history
self.tensor_paths.append(file_path)
# Acquire semaphore to limit concurrent save operations
self.save_semaphore.acquire() # pylint: disable=consider-using-with
# Queue tensor for saving in background
self.save_queue.put((tensor.detach(), file_path))
return file_path
def wait_for_save(self, file_path, timeout=None) -> None:
"""Wait for a tensor to be saved to disk"""
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
with self.manager_lock:
if self.file_status.get(file_path) == "ready":
return
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
return
if file_path in self.save_futures:
future = self.save_futures[file_path]
if future.done():
return
# Small sleep to prevent CPU spinning
time.sleep(0.01)
# Timeout
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
return
def load_tensor(self, file_path, target_device="cuda"):
"""Load tensor from disk or prefetch cache with proper synchronization"""
# Wait for tensor to be saved if it's still in progress
self.wait_for_save(file_path)
tensor = None
# Try to get from cache first
with self.manager_lock:
# Check if tensor is already in cache
if file_path in self.prefetch_cache:
tensor = self.prefetch_cache[file_path]
del self.prefetch_cache[file_path]
self.file_status[file_path] = "loaded"
if tensor is not None:
# Ensure tensor is on correct device
if target_device != "cpu" and tensor.device.type == "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
# If not in cache, load directly from disk
try:
if not os.path.exists(file_path):
logger.error(f"File not found for loading: {file_path}")
raise FileNotFoundError(f"File not found: {file_path}")
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.file_status[file_path] = "loaded"
if target_device != "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
except Exception as e:
logger.error(f"Error loading tensor from {file_path}: {e}")
raise
def _safe_delete_file(self, file_path):
"""Safely delete a file with proper synchronization"""
with self.manager_lock:
# Make sure any save operation is completed
if file_path in self.save_futures:
future = self.save_futures[file_path]
try:
if not future.done():
future.cancel()
del self.save_futures[file_path]
except FileNotFoundError as e:
logger.warning(
f"Error canceling save operation for {file_path}: {e}"
)
# Only delete if file exists and is not being prefetched
status = self.file_status.get(file_path)
if status in ["ready", "loaded", "error", "missing"]:
try:
if os.path.exists(file_path):
os.remove(file_path)
self.file_status[file_path] = "deleted"
return True
except FileNotFoundError as e:
logger.warning(f"Error deleting file {file_path}: {e}")
return False
def trigger_prefetch(self, n=None):
"""Trigger prefetching of the next N tensors with proper synchronization"""
if n is None:
n = self.max_prefetch
prefetch_paths = []
with self.manager_lock:
# Find files that are ready to be prefetched (not already in cache or being prefetched)
for path in reversed(self.tensor_paths):
if (
path not in self.prefetch_cache
and self.file_status.get(path) == "ready"
):
prefetch_paths.append(path)
if len(prefetch_paths) >= n:
break
# Queue files for prefetching
for path in prefetch_paths:
self.prefetch_queue.put(path)
def cleanup_tensor(self, file_path: str):
"""Clean up a specific tensor file after it's been used"""
with self.manager_lock:
if file_path in self.tensor_paths:
self.tensor_paths.remove(file_path)
# Remove from prefetch cache if present
if file_path in self.prefetch_cache:
del self.prefetch_cache[file_path]
# Remove from save futures if present
if file_path in self.save_futures:
future = self.save_futures[file_path]
if not future.done():
future.cancel()
del self.save_futures[file_path]
# Try to delete the file
self._safe_delete_file(file_path)
def cleanup(self):
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
self.stop_event.set()
# Cancel all pending save operations
with self.manager_lock:
for _, future in self.save_futures.items():
if not future.done():
future.cancel()
self.save_futures.clear()
# Drain the save queue
while not self.save_queue.empty():
try:
self.save_queue.get_nowait()
self.save_queue.task_done()
except queue.Empty:
break
# Shutdown the save pool
self.save_pool.shutdown(wait=False)
# Join the save worker thread
if self.save_worker.is_alive():
self.save_worker.join(timeout=2.0)
# Join the prefetch worker threads
for thread in self.prefetch_workers:
if thread.is_alive():
thread.join(timeout=2.0)
# Clear cache and remove all temporary files
with self.manager_lock:
self.prefetch_cache.clear()
paths_to_delete = list(self.tensor_paths)
self.tensor_paths.clear()
# Delete all temporary files
for path in paths_to_delete:
self._safe_delete_file(path)
# Remove temp directory
try:
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir, ignore_errors=True)
except FileNotFoundError as e:
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
class Disco(torch.autograd.Function):
"""
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
Advanced disk-based gradient checkpointer with prefetching.
"""
# Shared manager instance across all checkpointing operations
_manager = None
@staticmethod
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
"""Get or create the offload manager"""
if Disco._manager is None:
Disco._manager = DiskOffloadManager(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
return Disco._manager
@staticmethod
@torch_cuda_amp_custom_fwd
def forward(
ctx,
forward_function,
hidden_states,
*args,
prefetch_size=1,
prefetch_to_gpu=True,
save_workers=4,
):
"""Forward pass that offloads activations to disk asynchronously"""
# Get or create the manager
manager = Disco.get_instance(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
# Save tensor to disk asynchronously
file_path = manager.save_tensor(hidden_states)
# Run forward pass immediately without waiting for save to complete
with torch.no_grad():
output = forward_function(hidden_states, *args)
# Store what we need for backward
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
ctx.file_path = file_path
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch_cuda_amp_custom_bwd
def backward(ctx, *grad_outputs):
"""Backward pass that loads activations from disk with prefetching"""
# Get the manager
manager = Disco._manager
# Trigger prefetching for future tensors
# This happens at the start of backward, so should have time to complete
manager.trigger_prefetch()
# Load hidden states from disk or prefetch cache
file_path = ctx.file_path
try:
# Ensure the file is saved before we try to load it
manager.wait_for_save(file_path)
hidden_states = manager.load_tensor(file_path)
hidden_states.requires_grad = True
# Compute gradients
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# Handle tuple outputs properly
if isinstance(output, tuple):
if len(grad_outputs) == len(output):
torch.autograd.backward(output, grad_outputs)
else:
torch.autograd.backward(output, grad_outputs[0])
else:
torch.autograd.backward(output, grad_outputs[0])
# Clean up the file after we're done with it
manager.cleanup_tensor(file_path)
return (
(
None, # forward_function
hidden_states.grad, # hidden_states grad
)
+ (None,) * len(ctx.args) # for each arg
+ (
None, # prefetch_size
None, # prefetch_to_gpu
None, # save_workers
)
)
except Exception as e:
logger.error(f"Error in backward pass: {e}")
# Clean up the file even on error
manager.cleanup_tensor(file_path)
raise

View File

@@ -1,4 +1,4 @@
"""CPU offloaded checkpointing"""
"""Unsloth checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
@@ -26,7 +26,7 @@ else:
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""

View File

@@ -70,10 +70,7 @@ from axolotl.utils.distributed import (
is_local_main_process,
is_main_process,
)
from axolotl.utils.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
hf_grad_checkpoint_offload_wrapper,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -623,10 +620,6 @@ class ModelLoader:
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "offload_disk":
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
if self.cfg.flash_attention:
self.patch_attention()

View File

@@ -83,6 +83,7 @@ class AxolotlInputConfig(
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None
embeddings_skip_upcast: bool | None = None
random_init_weights: bool | None = None
rl: RLType | None = None
trl: TRLConfig | None = Field(
@@ -178,7 +179,7 @@ class AxolotlInputConfig(
# torch_dtype: torch.dtype | None
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None

View File

@@ -26,15 +26,10 @@ class TestActivationCheckpointing:
E2E tests for activation checkpointing
"""
@pytest.mark.parametrize(
"gradient_checkpointing",
["offload", "offload_disk"],
)
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
gradient_checkpointing,
):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -69,7 +64,7 @@ class TestActivationCheckpointing:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
"gradient_checkpointing": "offload",
}
)