Compare commits
2 Commits
wait-distr
...
feat/wizar
| Author | SHA1 | Date |
|---|---|---|
| | 22684ec98f | |
| | 6db60ac520 | |
@@ -139,8 +139,7 @@ quartodoc:
        - utils.optimizers.adopt
        - utils.data.pretraining
        - utils.data.sft
-        - utils.gradient_checkpointing.offload_cpu
-        - utils.gradient_checkpointing.offload_disk
+        - utils.gradient_checkpointing.unsloth
    - title: Schemas
      desc: Pydantic data models for Axolotl config
      contents:
@@ -539,7 +539,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false

-# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
+# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
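For orientation, a minimal sketch (values are placeholders, not taken from the diff) of how the documented keys look when emitted from Python; the use_reentrant kwarg is only an example of what gradient_checkpointing_kwargs might carry.

# Illustrative fragment only; keys mirror the documented options above.
import yaml

config_fragment = {
    "group_by_length": False,
    "gradient_checkpointing": "offload",  # other documented values: true, false (plus "offload_disk" on the other branch)
    "gradient_checkpointing_kwargs": {"use_reentrant": False},  # extra kwargs passed through to the trainer
}
print(yaml.dump(config_fragment, sort_keys=False))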
@@ -342,6 +342,13 @@ def delinearize_llama4(model: str, output: str) -> None:
    do_delinearize_llama4(model, output)


+@cli.command()
+def wizard():
+    from axolotl.cli.wizard import do_wizard
+
+    do_wizard()
+
+
cli.add_command(lm_eval)
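A minimal sketch of exercising the new subcommand non-interactively with Click's test runner; the module path axolotl.cli.main and the prompt answers passed via input= are assumptions for illustration, not part of the diff.

from click.testing import CliRunner

from axolotl.cli.main import cli  # assumed location of the Click group shown above

runner = CliRunner()
# Each line of `input` answers one wizard prompt; answers here are placeholders
# and truncated (the real wizard asks many more questions).
result = runner.invoke(cli, ["wizard"], input="my-config.yaml\nsome-org/some-model\n")
print(result.output)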
429 src/axolotl/cli/wizard.py Normal file
@@ -0,0 +1,429 @@
"""Wizard for creating yaml configs."""

import click
import torch
import yaml
from packaging import version
from transformers.training_args import OptimizerNames

from axolotl.cli.art import print_axolotl_text_art
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers


def do_wizard():
    print_axolotl_text_art()

    # Ask where to save the config
    cfg = DictDefault({})
    config_path = click.prompt(
        "Where do you want to save the config?", type=str, default="config.yaml"
    )

    # Ask base model
    base_model = click.prompt("What base model do you want to use?", type=str)
    cfg["base_model"] = base_model.strip()

    # Ask whether want to enable Vision model
    # TODO: check if model has vision layers instead of asking user
    train_vision_model = click.confirm(
        "If this model has vision layers, do you want to train them?", default=False
    )

    if train_vision_model:
        cfg["processor_type"] = "AutoProcessor"
        cfg["skip_prepare_dataset"] = True
        cfg["remove_unused_columns"] = False
        cfg["sample_packing"] = False

    # Ask whether they want to set any advanced model features (custom tokenizer, custom config, etc)
    advanced_model_features = click.confirm(
        "Do you want to set any advanced model features? (custom tokenizer, custom config, remote code etc)",
        default=False,
    )

    if advanced_model_features:
        # Ask whether they want to use a custom config
        base_model_config = click.prompt(
            "What model config do you want to use? (leave blank for default)",
            type=str,
            default="",
        )

        if base_model_config:
            cfg["base_model_config"] = base_model_config

        # Ask whether they want to use a specific revision of the model
        revision_of_model = click.prompt(
            "What revision of the model do you want to use? (leave blank for default)",
            type=str,
            default="",
        )

        if revision_of_model:
            cfg["revision_of_model"] = revision_of_model

        # Ask whether they want to use a custom tokenizer
        tokenizer_config = click.prompt(
            "What tokenizer do you want to use? (leave blank for default)",
            type=str,
            default="",
        )

        if tokenizer_config:
            cfg["tokenizer_config"] = tokenizer_config

        # Ask whether they want to use remote code
        trust_remote_code = click.confirm(
            "Do you want to use remote code?", default=False
        )

        if trust_remote_code:
            cfg["trust_remote_code"] = trust_remote_code

        # Whether to resize token embeddings
        resize_token_embeddings_to_32x = click.confirm(
            "Do you want to resize token embeddings to 32x?", default=False
        )

        if resize_token_embeddings_to_32x:
            cfg["resize_token_embeddings_to_32x"] = resize_token_embeddings_to_32x

        # Whether to shrink embeddings to len(tokenizer)
        shrink_embeddings = click.confirm(
            "Do you want to shrink embeddings to len(tokenizer)?", default=False
        )

        if shrink_embeddings:
            cfg["shrink_embeddings"] = shrink_embeddings

        # Whether to skip upcast embeddings
        embeddings_skip_upcast = click.confirm(
            "Do you want to skip upcast embeddings?", default=False
        )

        if embeddings_skip_upcast:
            cfg["embeddings_skip_upcast"] = embeddings_skip_upcast

        # Whether to random init weights
        random_init_weights = click.confirm(
            "Do you want to random init weights?", default=False
        )

        if random_init_weights:
            cfg["random_init_weights"] = random_init_weights

    # Get model type
    config = load_model_config(cfg)
    model_type = config.model_type

    # Ask sequence length
    sequence_length = click.prompt("What sequence length do you want to use?", type=int)
    cfg["sequence_length"] = sequence_length

    # Whether to turn on sample packing
    if cfg["sample_packing"] is None:
        cfg["sample_packing"] = click.confirm(
            "Do you want to turn on sample packing? This will speed up training by packing multiple samples into a single batch.",
            default=True,
        )

    if cfg["sample_packing"]:
        cfg["pad_to_sequence_len"] = True

        # Whether to turn off eval sample packing
        no_eval_sample_packing = click.confirm(
            "Do you want to turn off eval sample packing? This will slow down evaluation but is recommended if you are using a small validation set.",
            default=False,
        )

        if no_eval_sample_packing:
            cfg["eval_sample_packing"] = False

    # Hardware check
    try:
        is_ampere_or_newer = torch.cuda.get_device_capability()[0] >= 8
    except RuntimeError:
        is_ampere_or_newer = False
    except AssertionError:  # this is raised if no cuda is available
        is_ampere_or_newer = False

    # Get num gpus
    try:
        num_gpus = torch.cuda.device_count()
    except RuntimeError:
        num_gpus = 0

    # Get torch version
    torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

    is_torch_2_6_or_newer = version.parse(torch_version) >= version.parse("2.6.0")

    # Whether to turn on attention
    opt = ["xformers", "sdp"]

    if is_ampere_or_newer:
        opt.append("flash")

    if is_torch_2_6_or_newer:
        opt.append("flex")

    if cfg["sample_packing"]:
        if "flash" in opt:
            default_opt = "flash"
        elif "flex" in opt:
            default_opt = "flex"
        else:
            default_opt = opt[0]

        attention = click.prompt(
            "Which attention backend do you want to use? Sample packing requires an attention backend to be set.",
            type=click.Choice(opt),
            default=default_opt,
        )
    else:
        # non-sample packing supports no attention and S2
        opt.extend(["none", "s2"])

        attention = click.prompt(
            "Which attention backend do you want to use?",
            type=click.Choice(opt),
            default="none",
        )

        if attention == "none":
            attention = None

    # TODO: if xformers, check if FA is installed
    # TODO: flex doc mentioned requiring seq len to be divisible by 128. Unclear if limitation still exists

    # TODO: requires #2489
    cfg["attention"] = attention

    # Whether to turn on gradient checkpointing
    # TODO: need to wait for offload_disk PR to be merged
    gradient_checkpointing = click.prompt(
        "Which gradient checkpointing strategy do you want to use?",
        type=click.Choice(["none", "true", "offload", "offload_disk"]),
        default="true",
    )

    if gradient_checkpointing == "none":
        gradient_checkpointing = False
    elif gradient_checkpointing == "true":
        gradient_checkpointing = True

    # Ask whether to set use_reentrant
    # TODO: get correct defaults based on SFT/RL mode and single/multigpu
    # use_reentrant = click.confirm(
    #     "Do you want to set use_reentrant?",
    #     default=True,
    # )

    # if use_reentrant:
    #     cfg["use_reentrant"] = use_reentrant

    # Optimizer
    cfg["optimizer"] = click.prompt(
        "Which optimizer do you want to use?",
        type=click.Choice((OptimizerNames | CustomSupportedOptimizers)),
        default=OptimizerNames.ADAMW_TORCH_FUSED,
    )

    cfg["lr_scheduler"] = click.prompt(
        "Which learning rate scheduler do you want to use?",
        type=click.Choice(
            [
                "cosine",
                "one_cycle",
                "rex",
                "log_sweep",
                "linear",
                "cosine_with_restarts",
                "polynomial",
                "constant",
                "constant_with_warmup",
                "inverse_sqrt",
                "reduce_lr_on_plateau",
                "cosine_with_min_lr",
                "warmup_stable_decay",
            ]
        ),
        default="cosine",
    )

    # Plugins
    cfg["plugins"] = []

    # Whether to turn on cut cross entropy
    if is_ampere_or_newer:
        # Note: This may error if users don't have CCE installed
        from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
            CUT_CROSS_ENTROPY_MODEL_MAPPING,
        )

        if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
            cut_cross_entropy = click.confirm(
                "Do you want to turn on cut cross entropy? This will save VRAM if the model has a large vocab size.",
                default=True,
            )

            if cut_cross_entropy:
                cfg["plugins"].append(
                    "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
                )

                cfg["cut_cross_entropy"] = True

    use_liger_kernel = click.confirm(
        "Do you want to use the liger kernel? This will speed up training and save VRAM.",
        default=True,
    )

    if use_liger_kernel:
        cfg["plugins"].append("axolotl.integrations.liger.LigerPlugin")

        cfg["liger_rope"] = click.confirm(
            "Do you want to enable liger rope?",
            default=True,
        )

        cfg["liger_rms_norm"] = click.confirm(
            "Do you want to enable liger rms norm?",
            default=True,
        )

        cfg["liger_glu_activation"] = click.confirm(
            "Do you want to enable liger glu activation?",
            default=True,
        )

        cfg["liger_layer_norm"] = click.confirm(
            "Do you want to enable liger layer norm?",
            default=True,
        )

        if cfg["cut_cross_entropy"] is not True:
            cfg["liger_fused_linear_cross_entropy"] = click.confirm(
                "Do you want to enable liger fused linear cross entropy?",
                default=True,
            )

    # TODO: lora kernels (but they auto enable via validator already)

    # TODO: is there incompat between torch compile and liger?
    cfg["torch_compile"] = click.confirm(
        "Do you want to enable torch compile?",
        default=True,
    )

    # Multi-gpu
    if num_gpus > 1:
        # Ask whether to use DDP/Deepspeed/FSDP
        multi_gpu_mode = click.prompt(
            "Which multi-gpu mode do you want to use?",
            type=click.Choice(["ddp", "deepspeed", "fsdp"]),
            default="ddp",
        )

        if multi_gpu_mode == "deepspeed":
            # Ask which deepspeed config to use
            cfg["deepspeed"] = click.prompt(
                "Which deepspeed config do you want to use? The higher the number, the more VRAM you will save, but the slower it will run.",
                type=click.Choice(
                    [
                        "zero1.json",
                        "zero1_torch_compile.json",
                        "zero2.json",
                        "zero3.json",
                        "zero3_bf16.json",
                        "zero3_bf16_cpuoffload_all.json",
                        "zero3_bf16_cpuoffload_params.json",
                    ]
                ),
                default="zero1.json",
            )
        elif multi_gpu_mode == "fsdp":
            fsdp_version = click.prompt(
                "Which fsdp version do you want to use?",
                type=click.Choice([1, 2]),
                default=2,
            )

            # TODO: Handle FSDP config

            if fsdp_version == 1:
                cfg["fsdp"] = ["full_shard", "auto_wrap"]

                # Ask which state dict type to use
                fsdp_state_dict_type = click.prompt(
                    "Which fsdp state dict type do you want to use?",
                    type=click.Choice(["FULL_STATE_DICT", "SHARDED_STATE_DICT"]),
                    default="FULL_STATE_DICT",
                )

                fsdp_offload_params = click.confirm(
                    "Do you want to offload parameters?",
                    default=True,
                )

                # TODO: can we load the model class and auto pull a default for this?
                fsdp_transformer_layer_cls_to_wrap = click.prompt(
                    "Which transformer layer class to wrap? It is usually the Decoder layer class.",
                    type=str,
                )

                # TODO: add other options

                cfg["fsdp_config"] = {
                    "state_dict_type": fsdp_state_dict_type,
                    "offload_params": fsdp_offload_params,
                    "transformer_layer_cls_to_wrap": fsdp_transformer_layer_cls_to_wrap,
                }

            elif fsdp_version == 2:
                raise NotImplementedError()

    # Training mode (sft or rl)
    training_mode = click.prompt(
        "Which training mode do you want to use?",
        type=click.Choice(["sft", "rl"]),
        default="sft",
    )

    if training_mode == "rl":
        cfg["rl"] = click.prompt(
            "Which rl mode do you want to use?",
            type=click.Choice(["dpo", "ipo", "orpo", "kto", "grpo", "simpo"]),
        )

        # TODO: handle RL options

    # Whether to use adapter

    # Get batch/grad accu

    # Get learning rate

    # Get weight decay

    # Get max grad norm

    # Get num train epochs

    # Get warmup ratio

    # Get save ratio

    # Get eval ratio

    # Get dataset config

    # Load metric tracker

    # Save config to yaml
    # TODO: improve output yaml formatting. Need to add comments to help separate sections
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(cfg.to_dict(), f, sort_keys=False)
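do_wizard ends by dumping cfg.to_dict() with key order preserved, so the saved file is a flat YAML config. A minimal sketch of that final step with placeholder answers (the model name, lengths, and file name below are illustrative only, not values from the diff):

import yaml

# Placeholder answers standing in for what a user might type at the prompts.
cfg = {
    "base_model": "some-org/some-base-model",
    "sequence_length": 2048,
    "sample_packing": True,
    "pad_to_sequence_len": True,
    "attention": "flash",
    "gradient_checkpointing": True,
    "optimizer": "adamw_torch_fused",
    "lr_scheduler": "cosine",
    "plugins": ["axolotl.integrations.liger.LigerPlugin"],
    "torch_compile": True,
}

with open("config.yaml", "w", encoding="utf-8") as f:
    yaml.dump(cfg, f, sort_keys=False)  # mirrors the wizard's final yaml.dump call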
@@ -289,18 +289,16 @@ save_trained_model(
            os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
        except FileNotFoundError:
            pass
-    else:
-        if cfg.local_rank == 0:
-            if cfg.flash_optimum and BetterTransformer:
-                model = BetterTransformer.reverse(model)
+    elif cfg.local_rank == 0:
+        if cfg.flash_optimum and BetterTransformer:
+            model = BetterTransformer.reverse(model)

-            if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-                trainer.model.save_pretrained(
-                    cfg.output_dir, safe_serialization=safe_serialization
-                )
+        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+            trainer.model.save_pretrained(
+                cfg.output_dir, safe_serialization=safe_serialization
+            )

-            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-            trainer.accelerator.wait_for_everyone()
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

    if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
        # TODO: add integration support so this can be implemented completely within the plugin
@@ -5,11 +5,8 @@ from functools import partial

from packaging import version

-from axolotl.utils.gradient_checkpointing.offload_cpu import (
-    CPU_Offloaded_Gradient_Checkpointer,
-)
-from axolotl.utils.gradient_checkpointing.offload_disk import (
-    Disco,
-)
+from axolotl.utils.gradient_checkpointing.unsloth import (
+    Unsloth_Offloaded_Gradient_Checkpointer,
+)

transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -29,31 +26,12 @@ def hf_grad_checkpoint_offload_wrapper(
    decoder_layer, *args, use_reentrant=None
):  # pylint: disable=unused-argument
    if uses_gc_layers(decoder_layer):
-        return CPU_Offloaded_Gradient_Checkpointer.apply(
+        return Unsloth_Offloaded_Gradient_Checkpointer.apply(
            decoder_layer,
            *args,
        )

-    return CPU_Offloaded_Gradient_Checkpointer.apply(
-        (
-            decoder_layer.func.__self__
-            if isinstance(decoder_layer, partial)
-            else decoder_layer.__self__
-        ),
-        *args,
-    )
-
-
-def hf_grad_checkpoint_disk_offload_wrapper(
-    decoder_layer, *args, use_reentrant=None
-):  # pylint: disable=unused-argument
-    if uses_gc_layers(decoder_layer):
-        return Disco.apply(
-            decoder_layer,
-            *args,
-        )
-
-    return Disco.apply(
+    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        (
            decoder_layer.func.__self__
            if isinstance(decoder_layer, partial)
@@ -1,531 +0,0 @@
"""
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
"""

# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
import tempfile
import threading
import time
import uuid
from collections import deque
from concurrent.futures import Future
from typing import Dict

import torch

torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")

# Setup logger
logger = logging.getLogger(__name__)


class DiskOffloadManager:
    """
    Manages offloaded tensors and handles prefetching in a separate thread.
    Includes synchronization to prevent race conditions.
    """

    def __init__(
        self,
        prefetch_size: int = 3,
        prefetch_to_gpu: bool = True,
        save_workers: int = 4,
    ):
        """
        Args:
            prefetch_size: Maximum number of tensors to prefetch in the background.
            prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
            save_workers: Maximum number of concurrent save operations.
        """
        self.temp_dir = tempfile.mkdtemp(prefix="disco_")

        # Track tensor paths and their status
        self.tensor_paths: deque = deque()  # Ordered history of tensor paths (LIFO)
        self.file_locks: Dict[str, threading.Lock] = (
            {}
        )  # Maps file_path -> threading.Lock()
        # Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
        self.file_status: Dict[str, str] = {}

        self.max_prefetch = prefetch_size
        self.prefetch_to_gpu = prefetch_to_gpu

        # Thread synchronization
        self.manager_lock = threading.RLock()  # Used for thread-safe operations

        # Prefetch queue and cache
        self.prefetch_queue: queue.Queue = queue.Queue()
        self.prefetch_cache: Dict[str, torch.Tensor] = {}  # Maps file_path -> tensor

        # Save queue and thread pool
        self.save_queue: queue.Queue = queue.Queue()
        self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
        self.save_futures: Dict[str, Future] = {}
        self.save_semaphore = threading.Semaphore(
            save_workers * 2
        )  # Limit concurrent save operations

        # Start prefetch worker thread
        self.stop_event = threading.Event()
        # start multiple threads for prefetching
        self.prefetch_worker_count = 2
        self.prefetch_workers = []
        for _ in range(self.prefetch_worker_count):
            worker = threading.Thread(target=self._prefetch_worker, daemon=True)
            worker.start()
            self.prefetch_workers.append(worker)

        # Start save worker thread
        self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
        self.save_worker.start()
        self.idx = 0

        atexit.register(self.cleanup)

    def _save_worker(self):
        """Background thread that processes the save queue"""
        while not self.stop_event.is_set():
            try:
                save_item = self.save_queue.get(timeout=0.5)
                if save_item is None:
                    continue

                tensor, file_path = save_item

                # Submit the save task to the thread pool
                future = self.save_pool.submit(
                    self._save_tensor_to_disk, tensor, file_path
                )
                with self.manager_lock:
                    self.save_futures[file_path] = future

                self.save_queue.task_done()

            except queue.Empty:
                time.sleep(0.01)  # Small sleep to prevent CPU spinning
                continue

    def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
        """Actually save the tensor to disk"""
        try:
            # Save tensor to disk
            cpu_tensor = tensor.detach().cpu()
            torch.save(cpu_tensor, file_path)
            del cpu_tensor

            with self.manager_lock:
                # Mark file as ready
                self.file_status[file_path] = "ready"

            # Release semaphore
            self.save_semaphore.release()

            return True
        except FileNotFoundError as e:
            logger.error(f"Error saving tensor to {file_path}: {e}")
            with self.manager_lock:
                self.file_status[file_path] = "error"

            # Release semaphore
            self.save_semaphore.release()

            return False

    def _prefetch_worker(self):
        """Background thread that loads tensors from disk ahead of time"""
        while not self.stop_event.is_set():
            try:
                file_path = self.prefetch_queue.get(timeout=0.5)
                if file_path is None:
                    continue

                # Check if file is available and not already in cache
                with self.manager_lock:
                    if (
                        file_path not in self.file_status
                        or self.file_status[file_path] == "deleted"
                    ):
                        self.prefetch_queue.task_done()
                    if file_path in self.prefetch_cache:
                        self.prefetch_queue.task_done()
                        continue

                    # If file is still being saved, wait for it
                    if (
                        self.file_status[file_path] == "saving"
                        and file_path in self.save_futures
                    ):
                        # Re-queue this prefetch request with a little delay
                        self.prefetch_queue.task_done()
                        time.sleep(0.1)
                        self.prefetch_queue.put(file_path)
                        continue

                    # Mark file as being prefetched
                    self.file_status[file_path] = "prefetching"

                # Load tensor from disk and store in cache
                try:
                    if os.path.exists(file_path):
                        if self.prefetch_to_gpu:
                            tensor = torch.load(
                                file_path,
                                map_location=torch.device("cuda"),
                                weights_only=True,
                            )
                        else:
                            tensor = torch.load(file_path, weights_only=True)

                        with self.manager_lock:
                            self.prefetch_cache[file_path] = tensor
                            self.file_status[file_path] = "ready"
                    else:
                        with self.manager_lock:
                            if self.file_status.get(file_path) != "deleted":
                                logger.warning(
                                    f"Prefetch error: File not found {file_path}"
                                )
                                self.file_status[file_path] = "missing"

                except FileNotFoundError as e:
                    with self.manager_lock:
                        if self.file_status.get(file_path) != "deleted":
                            logger.warning(f"Prefetch error for {file_path}: {e}")
                            self.file_status[file_path] = "error"

                self.prefetch_queue.task_done()

            except queue.Empty:
                time.sleep(0.01)  # Small sleep to prevent CPU spinning
                continue

    def save_tensor(self, tensor: torch.Tensor):
        """Save tensor to disk asynchronously and return file path with thread-safe operations"""
        # Generate unique file path
        self.idx += 1
        file_path: str = os.path.join(
            self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
        )

        with self.manager_lock:
            # Mark file as being saved
            self.file_locks[file_path] = threading.Lock()
            self.file_status[file_path] = "saving"
            # Add to history
            self.tensor_paths.append(file_path)

        # Acquire semaphore to limit concurrent save operations
        self.save_semaphore.acquire()  # pylint: disable=consider-using-with
        # Queue tensor for saving in background
        self.save_queue.put((tensor.detach(), file_path))

        return file_path

    def wait_for_save(self, file_path, timeout=None) -> None:
        """Wait for a tensor to be saved to disk"""
        start_time = time.time()
        while timeout is None or time.time() - start_time < timeout:
            with self.manager_lock:
                if self.file_status.get(file_path) == "ready":
                    return
                if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
                    return

                if file_path in self.save_futures:
                    future = self.save_futures[file_path]
                    if future.done():
                        return

            # Small sleep to prevent CPU spinning
            time.sleep(0.01)

        # Timeout
        logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
        return

    def load_tensor(self, file_path, target_device="cuda"):
        """Load tensor from disk or prefetch cache with proper synchronization"""
        # Wait for tensor to be saved if it's still in progress
        self.wait_for_save(file_path)

        tensor = None

        # Try to get from cache first
        with self.manager_lock:
            # Check if tensor is already in cache
            if file_path in self.prefetch_cache:
                tensor = self.prefetch_cache[file_path]
                del self.prefetch_cache[file_path]
                self.file_status[file_path] = "loaded"

        if tensor is not None:
            # Ensure tensor is on correct device
            if target_device != "cpu" and tensor.device.type == "cpu":
                tensor = tensor.to(target_device, non_blocking=True)
            return tensor

        # If not in cache, load directly from disk
        try:
            if not os.path.exists(file_path):
                logger.error(f"File not found for loading: {file_path}")
                raise FileNotFoundError(f"File not found: {file_path}")

            tensor = torch.load(file_path, weights_only=True)

            with self.manager_lock:
                self.file_status[file_path] = "loaded"

            if target_device != "cpu":
                tensor = tensor.to(target_device, non_blocking=True)

            return tensor

        except Exception as e:
            logger.error(f"Error loading tensor from {file_path}: {e}")
            raise

    def _safe_delete_file(self, file_path):
        """Safely delete a file with proper synchronization"""
        with self.manager_lock:
            # Make sure any save operation is completed
            if file_path in self.save_futures:
                future = self.save_futures[file_path]
                try:
                    if not future.done():
                        future.cancel()
                    del self.save_futures[file_path]
                except FileNotFoundError as e:
                    logger.warning(
                        f"Error canceling save operation for {file_path}: {e}"
                    )

            # Only delete if file exists and is not being prefetched
            status = self.file_status.get(file_path)
            if status in ["ready", "loaded", "error", "missing"]:
                try:
                    if os.path.exists(file_path):
                        os.remove(file_path)
                    self.file_status[file_path] = "deleted"
                    return True
                except FileNotFoundError as e:
                    logger.warning(f"Error deleting file {file_path}: {e}")
        return False

    def trigger_prefetch(self, n=None):
        """Trigger prefetching of the next N tensors with proper synchronization"""
        if n is None:
            n = self.max_prefetch

        prefetch_paths = []
        with self.manager_lock:
            # Find files that are ready to be prefetched (not already in cache or being prefetched)
            for path in reversed(self.tensor_paths):
                if (
                    path not in self.prefetch_cache
                    and self.file_status.get(path) == "ready"
                ):
                    prefetch_paths.append(path)
                    if len(prefetch_paths) >= n:
                        break

        # Queue files for prefetching
        for path in prefetch_paths:
            self.prefetch_queue.put(path)

    def cleanup_tensor(self, file_path: str):
        """Clean up a specific tensor file after it's been used"""
        with self.manager_lock:
            if file_path in self.tensor_paths:
                self.tensor_paths.remove(file_path)

            # Remove from prefetch cache if present
            if file_path in self.prefetch_cache:
                del self.prefetch_cache[file_path]

            # Remove from save futures if present
            if file_path in self.save_futures:
                future = self.save_futures[file_path]
                if not future.done():
                    future.cancel()
                del self.save_futures[file_path]

        # Try to delete the file
        self._safe_delete_file(file_path)

    def cleanup(self):
        """Clean up all temp files and stop prefetch thread with proper synchronization"""
        self.stop_event.set()

        # Cancel all pending save operations
        with self.manager_lock:
            for _, future in self.save_futures.items():
                if not future.done():
                    future.cancel()
            self.save_futures.clear()

        # Drain the save queue
        while not self.save_queue.empty():
            try:
                self.save_queue.get_nowait()
                self.save_queue.task_done()
            except queue.Empty:
                break

        # Shutdown the save pool
        self.save_pool.shutdown(wait=False)

        # Join the save worker thread
        if self.save_worker.is_alive():
            self.save_worker.join(timeout=2.0)

        # Join the prefetch worker threads
        for thread in self.prefetch_workers:
            if thread.is_alive():
                thread.join(timeout=2.0)

        # Clear cache and remove all temporary files
        with self.manager_lock:
            self.prefetch_cache.clear()
            paths_to_delete = list(self.tensor_paths)
            self.tensor_paths.clear()

        # Delete all temporary files
        for path in paths_to_delete:
            self._safe_delete_file(path)

        # Remove temp directory
        try:
            if os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir, ignore_errors=True)
        except FileNotFoundError as e:
            logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")


class Disco(torch.autograd.Function):
    """
    Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
    Advanced disk-based gradient checkpointer with prefetching.
    """

    # Shared manager instance across all checkpointing operations
    _manager = None

    @staticmethod
    def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
        """Get or create the offload manager"""
        if Disco._manager is None:
            Disco._manager = DiskOffloadManager(
                prefetch_size=prefetch_size,
                prefetch_to_gpu=prefetch_to_gpu,
                save_workers=save_workers,
            )
        return Disco._manager

    @staticmethod
    @torch_cuda_amp_custom_fwd
    def forward(
        ctx,
        forward_function,
        hidden_states,
        *args,
        prefetch_size=1,
        prefetch_to_gpu=True,
        save_workers=4,
    ):
        """Forward pass that offloads activations to disk asynchronously"""
        # Get or create the manager
        manager = Disco.get_instance(
            prefetch_size=prefetch_size,
            prefetch_to_gpu=prefetch_to_gpu,
            save_workers=save_workers,
        )

        # Save tensor to disk asynchronously
        file_path = manager.save_tensor(hidden_states)

        # Run forward pass immediately without waiting for save to complete
        with torch.no_grad():
            output = forward_function(hidden_states, *args)

        # Store what we need for backward
        ctx.save_for_backward(torch.tensor([0]))  # Dummy tensor
        ctx.file_path = file_path
        ctx.forward_function = forward_function
        ctx.args = args

        return output

    @staticmethod
    @torch_cuda_amp_custom_bwd
    def backward(ctx, *grad_outputs):
        """Backward pass that loads activations from disk with prefetching"""
        # Get the manager
        manager = Disco._manager

        # Trigger prefetching for future tensors
        # This happens at the start of backward, so should have time to complete
        manager.trigger_prefetch()

        # Load hidden states from disk or prefetch cache
        file_path = ctx.file_path
        try:
            # Ensure the file is saved before we try to load it
            manager.wait_for_save(file_path)

            hidden_states = manager.load_tensor(file_path)
            hidden_states.requires_grad = True

            # Compute gradients
            with torch.enable_grad():
                output = ctx.forward_function(hidden_states, *ctx.args)

            # Handle tuple outputs properly
            if isinstance(output, tuple):
                if len(grad_outputs) == len(output):
                    torch.autograd.backward(output, grad_outputs)
                else:
                    torch.autograd.backward(output, grad_outputs[0])
            else:
                torch.autograd.backward(output, grad_outputs[0])

            # Clean up the file after we're done with it
            manager.cleanup_tensor(file_path)

            return (
                (
                    None,  # forward_function
                    hidden_states.grad,  # hidden_states grad
                )
                + (None,) * len(ctx.args)  # for each arg
                + (
                    None,  # prefetch_size
                    None,  # prefetch_to_gpu
                    None,  # save_workers
                )
            )

        except Exception as e:
            logger.error(f"Error in backward pass: {e}")
            # Clean up the file even on error
            manager.cleanup_tensor(file_path)
            raise
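For reference, a minimal sketch (assumed wiring, not taken from the diff) of applying the Disco checkpointer around a single layer: the activation is written to a temporary file during forward and reloaded, with prefetching, when backward recomputes the layer. The layer and tensor shapes are placeholders, and a CUDA device is assumed.

import torch
import torch.nn as nn

from axolotl.utils.gradient_checkpointing.offload_disk import Disco

layer = nn.Linear(16, 16).cuda()
hidden = torch.randn(2, 16, device="cuda", requires_grad=True)

# Forward runs under no_grad; `hidden` is queued for saving to a temp file in the background.
out = Disco.apply(layer.forward, hidden)

# Backward reloads `hidden` from disk (or the prefetch cache), recomputes the layer
# under enable_grad, and routes the gradient back to `hidden` and the layer weights.
out.sum().backward()
print(hidden.grad.shape)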
@@ -1,4 +1,4 @@
-"""CPU offloaded checkpointing"""
+"""Unsloth checkpointing"""

# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
@@ -26,7 +26,7 @@ else:
    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")


-class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
    torch.autograd.Function
):
    """
@@ -70,10 +70,7 @@ from axolotl.utils.distributed import (
    is_local_main_process,
    is_main_process,
)
-from axolotl.utils.gradient_checkpointing import (
-    hf_grad_checkpoint_disk_offload_wrapper,
-    hf_grad_checkpoint_offload_wrapper,
-)
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -623,10 +620,6 @@ class ModelLoader:

        if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
-        if self.cfg.gradient_checkpointing == "offload_disk":
-            transformers.modeling_utils.checkpoint = (
-                hf_grad_checkpoint_disk_offload_wrapper
-            )

        if self.cfg.flash_attention:
            self.patch_attention()
@@ -83,6 +83,7 @@ class AxolotlInputConfig(
    # optionally shrink the embeddings when the tokenizer vocab size is smaller
    shrink_embeddings: bool | None = None
    embeddings_skip_upcast: bool | None = None
+    random_init_weights: bool | None = None

    rl: RLType | None = None
    trl: TRLConfig | None = Field(
@@ -178,7 +179,7 @@ class AxolotlInputConfig(

    # torch_dtype: torch.dtype | None

-    gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
+    gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
        default=False
    )
    gradient_checkpointing_kwargs: dict[str, Any] | None = None
@@ -26,15 +26,10 @@ class TestActivationCheckpointing:
    E2E tests for activation checkpointing
    """

-    @pytest.mark.parametrize(
-        "gradient_checkpointing",
-        ["offload", "offload_disk"],
-    )
    def test_activation_checkpointing_offload(
        self,
        temp_dir,
        fix_checkpoint_after_test,  # pylint: disable=unused-argument,redefined-outer-name
-        gradient_checkpointing,
    ):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
@@ -69,7 +64,7 @@ class TestActivationCheckpointing:
                "sample_packing": True,
                "bf16": True,
                "save_safetensors": True,
-                "gradient_checkpointing": gradient_checkpointing,
+                "gradient_checkpointing": "offload",
            }
        )