From 76aeb1615645176abbf92381db1d5a4c32bcedf5 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 9 Jul 2025 12:48:22 -0400
Subject: [PATCH] tiled_mlp supports single gpu (#2891)

* tiled_mlp supports single gpu
* use checkpoint offloading for arctic training
* patch torch checkpoint too
* support for single gpu zero3
* add linkback to where it was copied from
---
 docs/multi-gpu.qmd                             |   9 +
 src/axolotl/loaders/patch_manager.py           |  18 +-
 .../gradient_checkpointing/__init__.py         |   3 +-
 .../gradient_checkpointing/offload_cpu.py      | 166 ++++++++++++++++++
 src/axolotl/monkeypatch/tiled_mlp.py           |  10 +-
 src/axolotl/utils/schemas/validation.py        |  10 +-
 src/axolotl/utils/trainer.py                   |   9 +
 7 files changed, 218 insertions(+), 7 deletions(-)

diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd
index fee7d17e5..6dc198212 100644
--- a/docs/multi-gpu.qmd
+++ b/docs/multi-gpu.qmd
@@ -66,6 +66,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
 
 :::
 
+::: {.callout-tip}
+
+Using ZeRO Stage 3 with Single-GPU training
+
+ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
+`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
+
+:::
+
 ## FSDP {#sec-fsdp}
 
 ### Basic FSDP Configuration {#sec-fsdp-config}
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 48ee78cbc..81d4e9471 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -7,6 +7,7 @@ import importlib.util
 from functools import cached_property
 
 import addict
+import torch
 import transformers
 from transformers import PretrainedConfig, PreTrainedModel
 
@@ -165,10 +166,25 @@ class PatchManager:
         """Apply patches for gradient checkpointing."""
         if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
             from axolotl.monkeypatch.gradient_checkpointing import (
+                CheckpointFunctionWithCPUOffload,
                 hf_grad_checkpoint_offload_wrapper,
             )
 
-            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+            if (
+                self.cfg.gradient_checkpointing_kwargs
+                and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
+                and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
+            ):
+                transformers.modeling_utils.checkpoint = (
+                    hf_grad_checkpoint_offload_wrapper
+                )
+            else:
+                transformers.modeling_utils.checkpoint.CheckpointFunction = (
+                    CheckpointFunctionWithCPUOffload
+                )
+                torch.utils.checkpoint.CheckpointFunction = (
+                    CheckpointFunctionWithCPUOffload
+                )
         if self.cfg.gradient_checkpointing == "offload_disk":
             from axolotl.monkeypatch.gradient_checkpointing import (
                 hf_grad_checkpoint_disk_offload_wrapper,
diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
index 5d631776b..6ca8e0240 100644
--- a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
+++ b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
@@ -5,7 +5,8 @@ from functools import partial
 
 from packaging import version
 
-from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
+from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (  # noqa: F401
+    CheckpointFunctionWithCPUOffload,
     CPU_Offloaded_Gradient_Checkpointer,
 )
 from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
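For anyone trying the single-GPU ZeRO-3 tip from the docs change above: the same environment can be forced from Python instead of the command line. A minimal sketch, assuming nothing has initialized `torch.distributed` yet; the helper name is illustrative and not part of this patch.

```python
# Illustrative only: mirrors the env vars from the docs callout, set in-process.
import os


def force_single_gpu_world() -> None:
    """Pretend to be rank 0 of a world of size 1 so DeepSpeed ZeRO-3 can init."""
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("MASTER_ADDR", "0.0.0.0")
    os.environ.setdefault("MASTER_PORT", "29500")
```

This is also roughly what `setup_deepspeed_env` in the trainer change below forces at runtime when `WORLD_SIZE` is 1.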
diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
index bbb5ad40d..432cafb35 100644
--- a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
+++ b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
@@ -13,8 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import contextlib
+import inspect
+
 import torch
 from packaging import version
+from torch.utils.checkpoint import (
+    _get_autocast_kwargs,
+    _get_device_module,
+    _infer_device_type,
+    check_backward_validity,
+    detach_variable,
+    get_device_states,
+    set_device_states,
+)
+
+# support different pytorch versions
+has_device_type = "device_type" in inspect.signature(set_device_states).parameters
 
 torch_version = version.parse(torch.__version__)
 
@@ -60,3 +76,153 @@ class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
         ) + (
             None,
         ) * len(ctx.args)
+
+
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
+class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
+    """
+    Monkey patch for the torch/utils/checkpoint.py `CheckpointFunction` that offloads
+    the first tensor argument to CPU during the forward pass and moves it back to the
+    GPU during the backward pass. This allows significant memory savings at very long
+    sequence lengths, e.g. for Llama 8B at a 100k seqlen it saves 24GB per GPU:
+    `(100_000 * 4096) * 2 * 32 / 2**30`. At 100k+ seqlens the overhead of copying
+    to/from CPU is small, because the dense quadratic attention compute dominates.
+    """
+
+    @staticmethod
+    def forward(ctx, run_function, preserve_rng_state, *args):
+        check_backward_validity(args)
+        ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
+        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
+        ctx.device_type = _infer_device_type(*args)
+        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
+            ctx.device_type
+        )
+        if preserve_rng_state:
+            ctx.fwd_cpu_state = torch.get_rng_state()
+            # Don't eagerly initialize the cuda context by accident.
+            # (If the user intends that the context is initialized later, within their
+            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
+            # we have no way to anticipate this will happen before we run the function.)
+            ctx.had_device_in_fwd = False
+            device_module = _get_device_module(ctx.device_type)
+            if getattr(device_module, "_initialized", False):
+                ctx.had_device_in_fwd = True
+                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
+
+        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
+        # to be filled out during the backward.
+        ctx.inputs = []
+        ctx.tensor_indices = []
+        tensor_inputs = []
+        for i, arg in enumerate(args):
+            if torch.is_tensor(arg):
+                # cpu-offload the first tensor only
+                # we don't want the 2nd tensor - usually it's a shared 4D attn mask
+                # which is huge [seq, seq]
+                # upstream could accept a list of arg indices to offload
+                if i == 0:
+                    ctx.x_device = arg.device
+                    ctx.x_requires_grad = arg.requires_grad
+                    t = arg.detach().cpu()
+                else:
+                    t = arg
+                tensor_inputs.append(t)
+                ctx.tensor_indices.append(i)
+                ctx.inputs.append(None)
+            else:
+                ctx.inputs.append(arg)
+
+        ctx.save_for_backward(*tensor_inputs)
+
+        with torch.no_grad():
+            outputs = run_function(*args)
+
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        if (
+            not torch.autograd._is_checkpoint_valid()  # pylint: disable=protected-access
+        ):
+            raise RuntimeError(
+                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
+                " with .grad() or passing an `inputs` parameter to .backward()."
+                " To resolve this error, you can either set use_reentrant=False,"
+                " or call .backward() without passing the `inputs` argument."
+            )
+        # Copy the list to avoid modifying original list.
+        inputs = list(ctx.inputs)
+        tensor_indices = ctx.tensor_indices
+        tensors = ctx.saved_tensors
+
+        # Fill in inputs with appropriate saved tensors.
+        for i, idx in enumerate(tensor_indices):
+            if i == 0:
+                # restore the offloaded tensor to its original device
+                t = (
+                    tensors[i]
+                    .to(ctx.x_device)
+                    .detach()
+                    .requires_grad_(ctx.x_requires_grad)
+                )
+            else:
+                t = tensors[i]
+            inputs[idx] = t
+
+        # Stash the surrounding rng state, and mimic the state that was
+        # present at this time during forward. Restore the surrounding state
+        # when we're done.
+        rng_devices = []
+        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
+            rng_devices = ctx.fwd_devices
+        with torch.random.fork_rng(
+            devices=rng_devices,
+            enabled=ctx.preserve_rng_state,
+            device_type=ctx.device_type,
+        ):
+            if ctx.preserve_rng_state:
+                torch.set_rng_state(ctx.fwd_cpu_state)
+                if ctx.had_device_in_fwd:
+                    if has_device_type:
+                        # newer pytorch (as early as 2.7)
+                        set_device_states(
+                            ctx.fwd_devices,
+                            ctx.fwd_device_states,
+                            device_type=ctx.device_type,
+                        )
+                    else:
+                        # older pytorch (at least 2.4)
+                        set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
+            detached_inputs = detach_variable(tuple(inputs))
+
+            device_autocast_ctx = (
+                torch.amp.autocast(
+                    device_type=ctx.device_type, **ctx.device_autocast_kwargs
+                )
+                if torch.amp.is_autocast_available(ctx.device_type)
+                else contextlib.nullcontext()
+            )
+            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast(
+                "cpu", **ctx.cpu_autocast_kwargs
+            ):  # type: ignore[attr-defined]
+                outputs = ctx.run_function(*detached_inputs)
+
+        if isinstance(outputs, torch.Tensor):
+            outputs = (outputs,)
+
+        # run backward() with only the tensors that require grad
+        outputs_with_grad = []
+        args_with_grad = []
+        for i in range(len(outputs)):  # pylint: disable=consider-using-enumerate
+            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
+                outputs_with_grad.append(outputs[i])
+                args_with_grad.append(args[i])
+        if len(outputs_with_grad) == 0:
+            raise RuntimeError(
+                "none of output has requires_grad=True, this checkpoint() is not necessary"
+            )
+        torch.autograd.backward(outputs_with_grad, args_with_grad)
+        grads = tuple(
+            inp.grad if isinstance(inp, torch.Tensor) else None
+            for inp in detached_inputs
+        )
+
+        return (None, None) + grads
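Since the class above keeps upstream's `CheckpointFunction.apply(run_function, preserve_rng_state, *args)` calling convention, here is a minimal sketch of direct usage. The MLP block and shapes are illustrative; in axolotl the class is installed by the `PatchManager` and reached through `torch.utils.checkpoint.checkpoint`, not called directly.

```python
import torch

from axolotl.monkeypatch.gradient_checkpointing import CheckpointFunctionWithCPUOffload

# Illustrative block and shapes; any callable taking the offloaded tensor works.
mlp = torch.nn.Sequential(
    torch.nn.Linear(4096, 11008), torch.nn.SiLU(), torch.nn.Linear(11008, 4096)
)
x = torch.randn(2, 1024, 4096, requires_grad=True)

# Same signature as torch's CheckpointFunction: (run_function, preserve_rng_state, *args).
# The first tensor arg (x) is detached to CPU after forward and restored in backward.
out = CheckpointFunctionWithCPUOffload.apply(mlp, True, x)
out.sum().backward()
assert x.grad is not None  # gradients flow despite the round trip through CPU
```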
diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py
index 4862ae78c..99a10df9c 100644
--- a/src/axolotl/monkeypatch/tiled_mlp.py
+++ b/src/axolotl/monkeypatch/tiled_mlp.py
@@ -1,6 +1,7 @@
 """Monkeypatch for Tiled MLP implementation"""
 
 import math
+import os
 
 import torch
 import torch.distributed as dist
@@ -29,15 +30,18 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
 
     mlp_forward = torch.compile(generic_mlp_forward)
 
+    is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+
     def tiled_mlp_forward(self, x):
         input_shape = x.shape
         seqlen = input_shape[-2]
         hidden = input_shape[-1]
         if cfg_num_shards is None:
             num_shards = math.ceil(seqlen / hidden)
-            num_shards_tensor = torch.tensor(num_shards, device=x.device)
-            dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
-            num_shards = num_shards_tensor.item()
+            if is_distributed:
+                num_shards_tensor = torch.tensor(num_shards, device=x.device)
+                dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
+                num_shards = num_shards_tensor.item()
         else:
             num_shards = cfg_num_shards
 
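With the `all_reduce` skipped on a single GPU, the default shard count above is pure arithmetic: `num_shards = ceil(seqlen / hidden)`. A small worked example (values illustrative):

```python
import math

hidden = 4096  # model hidden size
for seqlen in (2048, 8192, 100_000):
    # ceil(seqlen / hidden) tiles the MLP into roughly hidden-sized chunks
    print(f"seqlen={seqlen:>7} -> num_shards={math.ceil(seqlen / hidden)}")
# seqlen=   2048 -> num_shards=1
# seqlen=   8192 -> num_shards=2
# seqlen= 100000 -> num_shards=25
```

In the distributed case the MAX all-reduce only ensures every rank agrees on the same shard count when sequence lengths differ across ranks.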
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index af1341cda..c2d4b4af4 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -479,8 +479,14 @@ class TrainingValidationMixin:
     @model_validator(mode="before")
     @classmethod
     def check_tiled_mlp_deepspeed(cls, data):
-        if data.get("tiled_mlp", False) and not data.get("deepspeed"):
-            raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
+        capabilities = data.get("capabilities")
+        n_gpu = 0
+        if capabilities and capabilities.get("n_gpu", 0) >= 1:
+            n_gpu = capabilities.get("n_gpu", 0)
+        if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
+            raise ValueError(
+                "tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
+            )
         return data
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index cb597606c..2ef637232 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -546,6 +546,15 @@ def setup_deepspeed_env(cfg, stage=None):
     # NOTE(djsaunde): The distributed state cannot be initialized prior to the
     # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
     # to model load.
+    if int(os.environ.get("WORLD_SIZE", "1")) == 1:
+        os.environ["WORLD_SIZE"] = "1"  # force it in case not set
+        os.environ["LOCAL_RANK"] = "0"  # force it in case not set
+        os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
+        import deepspeed.comm as dist
+
+        dist.init_distributed(
+            dist_backend="nccl", auto_mpi_discovery=False, dist_init_required=True
+        )
     init_distributed_state()
 
     # If we don't assign this, it doesn't actually get set in the accelerate weakref
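Taken together, the relaxed validator and the forced single-rank init mean a single-GPU run with `tiled_mlp: true` and no `deepspeed` config now passes validation, while multi-GPU runs still require ZeRO. A condensed restatement of the new check as a plain function (illustrative, not the pydantic validator itself):

```python
def check_tiled_mlp_deepspeed(data: dict) -> dict:
    # Mirrors the validator above: only multi-gpu runs require DeepSpeed ZeRO.
    capabilities = data.get("capabilities") or {}
    n_gpu = capabilities.get("n_gpu", 0)
    if data.get("tiled_mlp", False) and n_gpu > 1 and not data.get("deepspeed"):
        raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu")
    return data


check_tiled_mlp_deepspeed({"tiled_mlp": True, "capabilities": {"n_gpu": 1}})  # now valid
# {"tiled_mlp": True, "capabilities": {"n_gpu": 2}} without "deepspeed" would raise
```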