diff --git a/docs/gradient_checkpointing.qmd b/docs/gradient_checkpointing.qmd
index 25a887999..54c53899c 100644
--- a/docs/gradient_checkpointing.qmd
+++ b/docs/gradient_checkpointing.qmd
@@ -1,5 +1,5 @@
 ---
-title: Gradient Checkpointing and Activation Offloading
+title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
 ---
 
 Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
 
 For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads activations
 to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
+
+### Enabling Layer Offloading
+
+```yaml
+layer_offloading: true
+```
+
+Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
+and streaming them back to GPU one layer at a time during the forward and backward passes. This is
+particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
+trainable adapter weights stay on GPU permanently.
+
+During training, forward and backward hooks on each decoder layer handle the transfer automatically:
+
+- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
+  prefetched asynchronously on a separate CUDA stream for overlap.
+- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
+  previous layer is prefetched.
+
+After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
+
+This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
+is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
+that is kept on GPU at any given time.
+
+**Requirements:**
+
+- CUDA GPU (CPU-only training is not supported for this feature)
+- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
+- Best combined with LoRA/QLoRA where most parameters are frozen
diff --git a/docs/optimizations.qmd b/docs/optimizations.qmd
index 624738de7..b180387ed 100644
--- a/docs/optimizations.qmd
+++ b/docs/optimizations.qmd
@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
 - Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
 - Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
 
+### Layer Offloading
+
+Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
+
+- **Config:** `layer_offloading: true`
+- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
+
 ### Cut Cross Entropy (CCE)
 
 Reduces VRAM usage by using an optimized cross-entropy loss calculation.
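To make the docs change concrete, a minimal sketch of a QLoRA-style config exercising the new flag alongside the existing memory options; the model and adapter values below are illustrative and not part of this diff:

```yaml
# hypothetical config for review purposes only — values are illustrative
base_model: meta-llama/Llama-3.1-8B
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
  - q_proj
  - v_proj

gradient_checkpointing: true
layer_offloading: true   # new flag introduced in this PR

micro_batch_size: 1
sequence_len: 4096
```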
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 5752a0584..90c813927 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -508,6 +508,8 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["accelerator_config"] = AcceleratorConfig()
 
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
+        if self.cfg.layer_offloading:
+            training_args_kwargs["layer_offloading"] = True
         if self.cfg.activation_offloading is True:
             # don't use the HF gradient checkpointing, manually wrap
             training_args_kwargs["gradient_checkpointing"] = False
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index cae9b7f27..8dc1a0239 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -34,6 +34,7 @@ from axolotl.core.trainers.mixins import (
     ActivationOffloadingMixin,
     CheckpointSaveMixin,
     DistributedParallelMixin,
+    LayerOffloadingMixin,
     OptimizerMixin,
     PackingMixin,
     RngLoaderMixin,
@@ -66,6 +67,7 @@ class AxolotlTrainer(
     OptimizerMixin,
     RngLoaderMixin,
     CheckpointSaveMixin,
+    LayerOffloadingMixin,
     ActivationOffloadingMixin,
     DistributedParallelMixin,
     Trainer,
diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py
index 5fced1692..241694e44 100644
--- a/src/axolotl/core/trainers/mixins/__init__.py
+++ b/src/axolotl/core/trainers/mixins/__init__.py
@@ -4,6 +4,7 @@
 
 from .activation_checkpointing import ActivationOffloadingMixin
 from .checkpoints import CheckpointSaveMixin
 from .distributed_parallel import DistributedParallelMixin
+from .layer_offloading import LayerOffloadingMixin
 from .optimizer import OptimizerMixin
 from .packing import PackingMixin
diff --git a/src/axolotl/core/trainers/mixins/layer_offloading.py b/src/axolotl/core/trainers/mixins/layer_offloading.py
new file mode 100644
index 000000000..83a9feff5
--- /dev/null
+++ b/src/axolotl/core/trainers/mixins/layer_offloading.py
@@ -0,0 +1,304 @@
+"""
+Trainer mixin for layer-wise parameter offloading to CPU.
+
+Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
+forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
+stream prefetching. Trainable parameters (e.g. LoRA weights) always stay on GPU.
+
+Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on the
+    transfer stream), post-hook offloads layer N-1's frozen params.
+Backward: same in reverse order.
+"""
+
+import contextlib
+
+import torch
+import torch.nn as nn
+from transformers import Trainer
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
+    """Recursively search the model for the decoder layer ModuleList.
+
+    Finds any ModuleList whose children have 'DecoderLayer' or 'TransformerBlock'
+    in their class name. Handles all common HF architectures including VLM
+    wrappers (e.g. Qwen3.5-MoE where layers are at model.language_model.layers).
+ """ + # BFS to find the first ModuleList containing decoder layers + queue = [model] + while queue: + m = queue.pop(0) + for _name, child in m.named_children(): + if isinstance(child, nn.ModuleList) and len(child) > 0: + first_type = type(child[0]).__name__ + if "DecoderLayer" in first_type or "TransformerBlock" in first_type: + layer_types = list({type(layer).__name__ for layer in child}) + return child, layer_types + else: + queue.append(child) + + return None, [] + + +def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]: + """Get all non-trainable parameters in a layer.""" + return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad] + + +class LayerOffloadManager: + """Manages offloading frozen decoder layer params to CPU and streaming + them back during forward/backward with CUDA stream overlap. + + Only frozen (requires_grad=False) parameters are offloaded. + Trainable parameters (LoRA weights, etc.) remain on GPU at all times. + """ + + def __init__( + self, + model: nn.Module, + num_prefetch: int = 1, + ): + self.model = model + self.num_prefetch = num_prefetch + self._hooks: list = [] + self._device = None + + # Find decoder layers + self.layers, layer_types = _find_decoder_layers(model) + if self.layers is None: + LOG.warning( + "LayerOffloadManager: no decoder layers found, offloading disabled" + ) + self.enabled = False + return + + self.enabled = True + self.n_layers = len(self.layers) + LOG.info( + f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})" + ) + + # Determine GPU device + for p in model.parameters(): + if p.device.type == "cuda": + self._device = p.device + break + if self._device is None: + LOG.warning("LayerOffloadManager: no CUDA parameters found") + self.enabled = False + return + + # Transfer stream for async prefetch + self._transfer_stream = torch.cuda.Stream(device=self._device) + + # Track which layers have their frozen params on GPU + self._on_gpu: set[int] = set(range(self.n_layers)) + + # Cache: frozen param references per layer (list of (name, param) tuples) + self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [ + _get_frozen_params(self.layers[i]) for i in range(self.n_layers) + ] + + # CPU storage: pinned tensors for each layer's frozen params + # Populated on first offload + self._cpu_data: list[dict[str, torch.Tensor]] = [ + {} for _ in range(self.n_layers) + ] + + # Offload all layers upfront + self._offload_all() + + # Release cached memory blocks back to the driver + torch.cuda.empty_cache() + + def _offload_all(self): + """Move all frozen params in all decoder layers to CPU.""" + mem_before = torch.cuda.memory_allocated(self._device) + for i in range(self.n_layers): + self._offload_layer(i) + mem_after = torch.cuda.memory_allocated(self._device) + freed = (mem_before - mem_after) / 1e6 + LOG.info( + f"Layer offloading: offloaded frozen params from {self.n_layers} layers, " + f"freed {freed:.0f} MB GPU memory" + ) + + def _offload_layer(self, idx: int): + """Move frozen params of layer idx to CPU pinned memory.""" + if idx not in self._on_gpu: + return + for name, param in self._frozen_params[idx]: + if param.device.type != "cuda": + continue + # Allocate pinned CPU tensor on first offload + if name not in self._cpu_data[idx]: + self._cpu_data[idx][name] = torch.empty_like( + param.data, device="cpu", pin_memory=True + ) + cpu_buf = self._cpu_data[idx][name] + # Async copy GPU -> CPU (on transfer stream for overlap) + cpu_buf.copy_(param.data, non_blocking=True) + # Point 
+            param.data = cpu_buf
+        self._on_gpu.discard(idx)
+
+    def _load_layer(self, idx: int, stream=None):
+        """Move frozen params of layer idx back to GPU."""
+        if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
+            return
+        ctx = (
+            torch.cuda.stream(stream)
+            if stream is not None
+            else contextlib.nullcontext()
+        )
+        with ctx:
+            for _name, param in self._frozen_params[idx]:
+                if param.device.type == "cuda":
+                    continue
+                gpu_data = param.data.to(self._device, non_blocking=True)
+                if stream is not None:
+                    # allocated on the transfer stream but consumed on the
+                    # default stream; record that use so the caching allocator
+                    # does not hand the block back to the transfer stream early
+                    gpu_data.record_stream(torch.cuda.default_stream(self._device))
+                param.data = gpu_data
+        self._on_gpu.add(idx)
+
+    def _prefetch_layer(self, idx: int):
+        """Async prefetch layer idx on the transfer stream."""
+        if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
+            return
+        self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
+        self._load_layer(idx, stream=self._transfer_stream)
+
+    def _wait_transfer(self):
+        """Make default stream wait for any in-flight transfers."""
+        torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
+
+    def setup_hooks(self):
+        """Register forward and backward hooks on each decoder layer."""
+        if not self.enabled:
+            return
+
+        for idx in range(self.n_layers):
+            layer = self.layers[idx]
+
+            def make_pre_fwd(i):
+                def hook(module, args):
+                    # Ensure this layer is on GPU
+                    if i not in self._on_gpu:
+                        self._load_layer(i)
+                    self._wait_transfer()
+                    # Prefetch next layer(s)
+                    for offset in range(1, self.num_prefetch + 1):
+                        self._prefetch_layer(i + offset)
+
+                return hook
+
+            def make_post_fwd(i):
+                def hook(module, args, output):
+                    # Offload previous layer (no longer needed in forward)
+                    if i > 0:
+                        self._offload_layer(i - 1)
+                    # Offload last layer after forward
+                    if i == self.n_layers - 1:
+                        self._offload_layer(i)
+
+                return hook
+
+            def make_pre_bwd(i):
+                def hook(module, grad_output):
+                    # Load this layer for backward
+                    if i not in self._on_gpu:
+                        self._load_layer(i)
+                    self._wait_transfer()
+                    # Prefetch previous layer(s)
+                    for offset in range(1, self.num_prefetch + 1):
+                        self._prefetch_layer(i - offset)
+
+                return hook
+
+            def make_post_bwd(i):
+                def hook(module, grad_input, grad_output):
+                    # Offload the layer above
+                    if i < self.n_layers - 1:
+                        self._offload_layer(i + 1)
+                    # Offload first layer after backward
+                    if i == 0:
+                        self._offload_layer(i)
+
+                return hook
+
+            h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
+            h2 = layer.register_forward_hook(make_post_fwd(idx))
+            h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
+            h4 = layer.register_full_backward_hook(make_post_bwd(idx))
+            self._hooks.extend([h1, h2, h3, h4])
+
+    def remove_hooks(self):
+        """Remove all hooks and restore layers to GPU."""
+        for h in self._hooks:
+            h.remove()
+        self._hooks.clear()
+        if self.enabled:
+            for i in range(self.n_layers):
+                if i not in self._on_gpu:
+                    self._load_layer(i)
+
+    def pre_step(self):
+        """Called before each training step — ensure layers start offloaded."""
+        if not self.enabled:
+            return
+        for i in list(self._on_gpu):
+            self._offload_layer(i)
+        # Prefetch layer 0 for forward
+        self._prefetch_layer(0)
+
+    def post_step(self):
+        """Called after each training step — ensure layers are offloaded."""
+        if not self.enabled:
+            return
+        for i in list(self._on_gpu):
+            self._offload_layer(i)
+        # Prefetch layer 0 for next step
+        self._prefetch_layer(0)
+
+
+class _LayerOffloadContext:
+    """Context manager wrapping pre_step / post_step around a training step."""
+
+    def __init__(self, manager: LayerOffloadManager):
+        self.manager = manager
+
+    def __enter__(self):
+        self.manager.pre_step()
+        return self
+
+    def __exit__(self, *args):
+        self.manager.post_step()
+
+
+class LayerOffloadingMixin(Trainer):
+    """
+    Trainer mixin class for layer-wise parameter offloading to CPU.
+
+    Offloads frozen decoder layer params to CPU at init, then streams them
+    on/off GPU one layer at a time during each training step.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if getattr(self.args, "layer_offloading", False):
+            LOG.info("Layer parameter offloading enabled")
+            self._layer_offload_manager = LayerOffloadManager(
+                model=self.model,
+                num_prefetch=1,
+            )
+            self._layer_offload_manager.setup_hooks()
+            self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
+        else:
+            self._layer_offload_manager = None
+            self._layer_offload_ctx = contextlib.nullcontext()
+
+    def training_step(self, *args, **kwargs):
+        with self._layer_offload_ctx:
+            return super().training_step(*args, **kwargs)
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
index 41ee8e91e..427a80a46 100644
--- a/src/axolotl/core/training_args_base.py
+++ b/src/axolotl/core/training_args_base.py
@@ -235,6 +235,13 @@ class AxolotlTrainingMixins:
         metadata={"help": "Use activation offloading with CUDA streams for training."},
     )
 
+    layer_offloading: bool | None = field(
+        default=None,
+        metadata={
+            "help": "Offload frozen decoder layer parameters to CPU and stream them back to GPU layer by layer during the forward and backward passes."
+        },
+    )
+
     # multi-modal section
     image_size: int | tuple[int, int] | None = field(
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index a4eadf5cf..97a9c923e 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -433,6 +433,12 @@ class AxolotlInputConfig(
             "description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
         },
    )
+    layer_offloading: bool | None = Field(
+        default=False,
+        json_schema_extra={
+            "description": "Offload frozen decoder layer parameters to CPU and stream them back to GPU layer by layer during the forward and backward passes."
+        },
+    )
     unfrozen_parameters: list[str] | None = Field(
         default=None,
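The hook-driven streaming pattern the new mixin implements can be illustrated standalone. Below is a minimal sketch, assuming a CUDA device; the toy model, the `cpu_copies` map, and the hook names are illustrative and not part of this diff:

```python
# Minimal sketch of layer streaming via forward hooks: frozen weights live in
# pinned CPU memory and are copied to GPU just-in-time by a pre-hook, with the
# next layer prefetched on a side stream, then dropped again by a post-hook.
import torch
import torch.nn as nn

device = torch.device("cuda")

# A toy stack standing in for frozen decoder layers.
layers = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)])
for p in layers.parameters():
    p.requires_grad_(False)

# Pinned CPU copy per parameter; start with everything offloaded.
cpu_copies = {p: p.data.to("cpu").pin_memory() for p in layers.parameters()}
for p in layers.parameters():
    p.data = cpu_copies[p]

transfer = torch.cuda.Stream(device=device)


def load_params(module):
    """Copy a layer's params to GPU on the current stream."""
    for p in module.parameters():
        if p.device.type != "cuda":
            gpu = p.data.to(device, non_blocking=True)
            # Record use on the default stream so the caching allocator does
            # not recycle the block under in-flight compute.
            gpu.record_stream(torch.cuda.default_stream(device))
            p.data = gpu


def pre_hook(i):
    def hook(module, args):
        # Wait for any in-flight prefetch of this layer, then make it resident.
        torch.cuda.default_stream(device).wait_stream(transfer)
        load_params(module)
        # Prefetch the next layer on the side stream so the copy overlaps compute.
        if i + 1 < len(layers):
            transfer.wait_stream(torch.cuda.default_stream(device))
            with torch.cuda.stream(transfer):
                load_params(layers[i + 1])

    return hook


def post_hook(module, args, output):
    # Drop the GPU copies; the pinned CPU copies still hold the weights.
    for p in module.parameters():
        p.data = cpu_copies[p]


for i, layer in enumerate(layers):
    layer.register_forward_pre_hook(pre_hook(i))
    layer.register_forward_hook(post_hook)

x = torch.randn(4, 1024, device=device)
with torch.no_grad():
    for layer in layers:
        x = layer(x)

# Peak allocation stays near one or two layers' worth instead of all eight.
print(f"peak GPU memory: {torch.cuda.max_memory_allocated(device) / 1e6:.0f} MB")
```

This mirrors the manager's design choice of eager Python-side `param.data` reassignment plus stream-level synchronization: a prefetched layer already "looks" resident when its pre-hook fires, so the hook only has to wait on the transfer stream rather than re-issue the copy.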