From 99187cd2082594fb51eefc5d5fe36eca33088829 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 14 Jul 2025 20:10:20 -0400 Subject: [PATCH] Activation Offloading w CUDA Streams (#2900) [skip ci] * use cuda streams for activation offloading * use torch native ops * update cfg schema for streams * fix literal constructor for set * use context for training step so it doesn't affect evals * disable streams * auto gc on eval steps * use activation_offloading config arg * add docs for gradient checkpointing * handle validation for gc/ao * use cuda streams for act offloading * add more validation for AC w/o GC * fix docs * move activation_offloading lower in definition so it doesn't break args/kwargs * fix kd due to import order --- _quarto.yml | 1 + docs/gradient_checkpointing.qmd | 29 ++++ src/axolotl/core/builders/base.py | 6 +- src/axolotl/core/trainers/base.py | 2 + src/axolotl/core/trainers/mixins/__init__.py | 1 + .../mixins/activation_checkpointing.py | 37 +++++ src/axolotl/core/training_args_base.py | 5 + src/axolotl/loaders/model.py | 10 ++ src/axolotl/loaders/patch_manager.py | 28 +--- .../gradient_checkpointing/__init__.py | 1 - .../gradient_checkpointing/offload_cpu.py | 157 ------------------ src/axolotl/utils/callbacks/__init__.py | 28 +++- src/axolotl/utils/schemas/config.py | 13 +- src/axolotl/utils/schemas/validation.py | 22 +++ 14 files changed, 154 insertions(+), 186 deletions(-) create mode 100644 docs/gradient_checkpointing.qmd create mode 100644 src/axolotl/core/trainers/mixins/activation_checkpointing.py diff --git a/_quarto.yml b/_quarto.yml index 93141aa9e..3e773a748 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -276,6 +276,7 @@ website: - docs/torchao.qmd - docs/custom_integrations.qmd - docs/sequence_parallelism.qmd + - docs/gradient_checkpointing.qmd - section: "Troubleshooting" contents: diff --git a/docs/gradient_checkpointing.qmd b/docs/gradient_checkpointing.qmd new file mode 100644 index 000000000..25a887999 --- /dev/null +++ 
b/docs/gradient_checkpointing.qmd @@ -0,0 +1,29 @@ +--- +title: Gradient Checkpointing and Activation Offloading +--- + +Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning +models by reducing the memory footprint and improving computational efficiency. + +### Enabling Gradient Checkpointing + +```yaml +gradient_checkpointing: true +``` + +### Enabling Activation Offloading + +```yaml +gradient_checkpointing: true # required for activation offloading +activation_offloading: true +``` + +Activation offloading variants: + +The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams +to overlap the communications and computations when offloading. + +The `activation_offloading: legacy` naively offloads activations to CPU without additional optimizations. + +For resource-constrained environments with limited CPU memory, `activation_offloading: disk` offloads +activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory. 
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 8ded23661..e80e905b8 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -434,7 +434,11 @@ class TrainerBuilderBase(abc.ABC): training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config def _configure_gradient_checkpointing(self, training_args_kwargs: dict): - if self.cfg.gradient_checkpointing: + if self.cfg.activation_offloading is True: + # don't use the HF gradient checkpointing, manually wrap + training_args_kwargs["gradient_checkpointing"] = False + training_args_kwargs["activation_offloading"] = True + elif self.cfg.gradient_checkpointing: training_args_kwargs["gradient_checkpointing"] = ( self.cfg.gradient_checkpointing ) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 6b2d30709..b983f1076 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length from typing_extensions import override from axolotl.core.trainers.mixins import ( + ActivationOffloadingMixin, CheckpointSaveMixin, OptimizerMixin, PackingMixin, @@ -48,6 +49,7 @@ class AxolotlTrainer( OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, + ActivationOffloadingMixin, Trainer, ): """Extend the base Trainer for axolotl helpers""" diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index b73b51126..453810aac 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -3,6 +3,7 @@ # pylint: disable=unused-import # flake8: noqa +from .activation_checkpointing import ActivationOffloadingMixin from .checkpoints import CheckpointSaveMixin from .optimizer import OptimizerMixin from .packing import PackingMixin diff --git a/src/axolotl/core/trainers/mixins/activation_checkpointing.py 
b/src/axolotl/core/trainers/mixins/activation_checkpointing.py new file mode 100644 index 000000000..9488186cd --- /dev/null +++ b/src/axolotl/core/trainers/mixins/activation_checkpointing.py @@ -0,0 +1,37 @@ +""" +Trainer mixin for activation checkpointing w offloading +""" + +import contextlib + +from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from transformers import GradientCheckpointingLayer, Trainer +from trl.models.activation_offloading import get_act_offloading_ctx_manager + + +class ActivationOffloadingMixin(Trainer): + """ + Trainer mixin class for activation checkpointing w offloading + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.args.activation_offloading: + self.activation_offload_context = get_act_offloading_ctx_manager( + self.model, use_streams=True + ) + else: + self.activation_offload_context = contextlib.nullcontext() + + def training_step(self, *args, **kwargs): + with self.activation_offload_context: + return super().training_step(*args, **kwargs) + + +def ac_wrap_hf_model(model: nn.Module, **kwargs): + auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,))) + apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs) diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py index 2e1987e82..4b74676ce 100644 --- a/src/axolotl/core/training_args_base.py +++ b/src/axolotl/core/training_args_base.py @@ -217,6 +217,11 @@ class AxolotlTrainingMixins: }, ) + activation_offloading: bool | None = field( + default=None, + metadata={"help": "Use activation offloading with CUDA streams for training."}, + ) + # multi-modal section image_size: int | tuple[int, int] | None = field( diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 03678e1b4..1ce98ef31 100644 --- 
a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -198,12 +198,22 @@ class ModelLoader: ): self.model = self.model.merge_and_unload() + self._apply_activation_checkpointing() self._resize_token_embeddings() self._adjust_model_config() self._configure_embedding_dtypes() self._configure_qat() log_gpu_memory_usage(LOG, "Memory usage after model load", 0) + def _apply_activation_checkpointing(self): + if self.cfg.activation_offloading is True: + from axolotl.core.trainers.mixins.activation_checkpointing import ( + ac_wrap_hf_model, + ) + + # ^^ importing this at the module level breaks plugins + ac_wrap_hf_model(self.model) + def _resize_token_embeddings(self): """Resize token embeddings if needed.""" embeddings_len = ( diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 2544429e6..84e6b33de 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -7,7 +7,6 @@ import importlib.util from functools import cached_property import addict -import torch import transformers from transformers import PretrainedConfig, PreTrainedModel @@ -168,28 +167,19 @@ class PatchManager: def _apply_gradient_checkpointing_patches(self): """Apply patches for gradient checkpointing.""" - if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: + if ( + self.cfg.gradient_checkpointing + and self.cfg.activation_offloading == "legacy" + ): from axolotl.monkeypatch.gradient_checkpointing import ( - CheckpointFunctionWithCPUOffload, hf_grad_checkpoint_offload_wrapper, ) - if ( - self.cfg.gradient_checkpointing_kwargs - and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs - and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False - ): - transformers.modeling_utils.checkpoint = ( - hf_grad_checkpoint_offload_wrapper - ) - else: - transformers.modeling_utils.checkpoint.CheckpointFunction = ( - CheckpointFunctionWithCPUOffload - ) - torch.utils.checkpoint.CheckpointFunction = ( 
- CheckpointFunctionWithCPUOffload - ) - if self.cfg.gradient_checkpointing == "offload_disk": + transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper + elif ( + self.cfg.gradient_checkpointing + and self.cfg.activation_offloading == "offload_disk" + ): from axolotl.monkeypatch.gradient_checkpointing import ( hf_grad_checkpoint_disk_offload_wrapper, ) diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py index 6ca8e0240..3b090d5e5 100644 --- a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py @@ -6,7 +6,6 @@ from functools import partial from packaging import version from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401 - CheckpointFunctionWithCPUOffload, CPU_Offloaded_Gradient_Checkpointer, ) from axolotl.monkeypatch.gradient_checkpointing.offload_disk import ( diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py index 432cafb35..bbcfb91e6 100644 --- a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py @@ -14,18 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import inspect import torch from packaging import version from torch.utils.checkpoint import ( - _get_autocast_kwargs, - _get_device_module, - _infer_device_type, - check_backward_validity, - detach_variable, - get_device_states, set_device_states, ) @@ -76,153 +69,3 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name ) + ( None, ) * len(ctx.args) - - -# Copyright 2025 Snowflake Inc. 
-# SPDX-License-Identifier: Apache-2.0 -# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py -class CheckpointFunctionWithCPUOffload(torch.autograd.Function): - """ - This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)` - In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate. - """ - - @staticmethod - def forward(ctx, run_function, preserve_rng_state, *args): - check_backward_validity(args) - ctx.run_function = run_function - ctx.preserve_rng_state = preserve_rng_state - # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. - ctx.device_type = _infer_device_type(*args) - ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs( - ctx.device_type - ) - if preserve_rng_state: - ctx.fwd_cpu_state = torch.get_rng_state() - # Don't eagerly initialize the cuda context by accident. - # (If the user intends that the context is initialized later, within their - # run_function, we SHOULD actually stash the cuda state here. Unfortunately, - # we have no way to anticipate this will happen before we run the function.) - ctx.had_device_in_fwd = False - device_module = _get_device_module(ctx.device_type) - if getattr(device_module, "_initialized", False): - ctx.had_device_in_fwd = True - ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) - - # Save non-tensor inputs in ctx, keep a placeholder None for tensors - # to be filled out during the backward. 
- ctx.inputs = [] - ctx.tensor_indices = [] - tensor_inputs = [] - # x = None - for i, arg in enumerate(args): - if torch.is_tensor(arg): - # cpu-offload - # we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq] - # upstream could accept a list of arg indices to offload - if i == 0: - # print(f"{arg.shape=}") - ctx.x_device = arg.device - ctx.x_requires_grad = arg.requires_grad - t = arg.detach().cpu() - else: - t = arg - tensor_inputs.append(t) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - - with torch.no_grad(): - outputs = run_function(*args) - - return outputs - - @staticmethod - def backward(ctx, *args): - if ( - not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access - ): - raise RuntimeError( - "When use_reentrant=True, torch.utils.checkpoint is incompatible" - " with .grad() or passing an `inputs` parameter to .backward()." - " To resolve this error, you can either set use_reentrant=False," - " or call .backward() without passing the `inputs` argument." - ) - # Copy the list to avoid modifying original list. - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - - # Fill in inputs with appropriate saved tensors. - for i, idx in enumerate(tensor_indices): - if i == 0: - t = ( - tensors[i] - .to(ctx.x_device) - .detach() - .requires_grad_(ctx.x_requires_grad) - ) - else: - t = tensors[i] - inputs[idx] = t - - # Stash the surrounding rng state, and mimic the state that was - # present at this time during forward. Restore the surrounding state - # when we're done. 
- rng_devices = [] - if ctx.preserve_rng_state and ctx.had_device_in_fwd: - rng_devices = ctx.fwd_devices - with torch.random.fork_rng( - devices=rng_devices, - enabled=ctx.preserve_rng_state, - device_type=ctx.device_type, - ): - if ctx.preserve_rng_state: - torch.set_rng_state(ctx.fwd_cpu_state) - if ctx.had_device_in_fwd: - if has_device_type: - # newer pytorch (as early as 2.7) - set_device_states( - ctx.fwd_devices, - ctx.fwd_device_states, - device_type=ctx.device_type, - ) - else: - # older pytorch (at least 2.4) - set_device_states(ctx.fwd_devices, ctx.fwd_device_states) - detached_inputs = detach_variable(tuple(inputs)) - - device_autocast_ctx = ( - torch.amp.autocast( - device_type=ctx.device_type, **ctx.device_autocast_kwargs - ) - if torch.amp.is_autocast_available(ctx.device_type) - else contextlib.nullcontext() - ) - with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined] - outputs = ctx.run_function(*detached_inputs) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs,) - - # run backward() with only tensor that requires grad - outputs_with_grad = [] - args_with_grad = [] - for i in range(len(outputs)): # pylint: disable=consider-using-enumerate - if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: - outputs_with_grad.append(outputs[i]) - args_with_grad.append(args[i]) - if len(outputs_with_grad) == 0: - raise RuntimeError( - "none of output has requires_grad=True, this checkpoint() is not necessary" - ) - torch.autograd.backward(outputs_with_grad, args_with_grad) - grads = tuple( - inp.grad if isinstance(inp, torch.Tensor) else None - for inp in detached_inputs - ) - - return (None, None) + grads diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 2a93ceef5..5f804d6af 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -841,21 +841,35 @@ class 
SaveAxolotlConfigtoWandBCallback(TrainerCallback): class GCCallback(TrainerCallback): """Callback to garbage collect torch cache""" - def __init__(self, gc_steps=None): - self.gc_steps = gc_steps + def __init__(self, gc_steps: int | None = -1): + self.gc_steps: int = gc_steps or -1 + self.next_gc_on_begin_step: int = -1 + + def _gc(self): + torch.cuda.empty_cache() + gc.collect() + + def on_step_begin( + self, args, state, control, **kwargs # pylint: disable=unused-argument + ): + if self.next_gc_on_begin_step == state.global_step: + self._gc() def on_step_end( self, args, state, control, **kwargs # pylint: disable=unused-argument ): - if self.gc_steps > 0 and state.global_step % self.gc_steps == 0: - torch.cuda.empty_cache() - gc.collect() + if control.should_evaluate: + # automatically GC before evals so the eval memory spike from the CEL doesn't OOM the trainer + self._gc() + # also GC on the start of the next step after the eval + self.next_gc_on_begin_step = state.global_step + 1 + elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0: + self._gc() def on_epoch_end( self, args, state, control, **kwargs # pylint: disable=unused-argument ): - torch.cuda.empty_cache() - gc.collect() + self._gc() def colab_inference_post_train_callback(trainer: Trainer): diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 6668380bf..f757cc5b0 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -320,7 +320,12 @@ class AxolotlInputConfig( }, ) - gc_steps: int | None = None + gc_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled)." 
+ }, + ) bf16: Literal["auto"] | bool | None = Field( default="auto", @@ -360,6 +365,12 @@ class AxolotlInputConfig( "description": "Additional kwargs to pass to the trainer for gradient checkpointing" }, ) + activation_offloading: Literal["legacy", "disk"] | bool | None = Field( + default=False, + json_schema_extra={ + "description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'." + }, + ) unfrozen_parameters: list[str] | None = None diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index bf2bc9070..db3fd0a1c 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1017,6 +1017,28 @@ class ModelCompatibilityValidationMixin: self.gradient_checkpointing = "offload" return self + @model_validator(mode="after") + def check_gradient_checkpointing_w_offload(self): + if self.gradient_checkpointing == "offload": + LOG.warning( + "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`" + ) + self.gradient_checkpointing = True + self.activation_offloading = True + if self.gradient_checkpointing == "offload_disk": + LOG.warning( + "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`" + ) + self.gradient_checkpointing = True + self.activation_offloading = "disk" + return self + + @model_validator(mode="after") + def check_activation_offloading_wo_gc(self): + if self.activation_offloading and not self.gradient_checkpointing: + raise ValueError("activation_offloading requires gradient_checkpointing") + return self + @model_validator(mode="after") def check_better_transformers(self): if self.flash_optimum is True: