Compare commits


1 Commit

Author: Wing Lian
Commit: c620a218b8 - tiled_mlp supports single gpu (#2891)
* tiled_mlp supports single gpu

* use checkpoint offloading for arctic training

* patch torch checkpoint too

* support for single gpu zero3

* add linkback to where it was copied from
Date: 2025-07-09 12:48:51 -04:00
7 changed files with 218 additions and 7 deletions

View File

@@ -66,6 +66,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
 :::
 
+::: {.callout-tip}
+
+## Using ZeRO Stage 3 with Single-GPU training
+
+ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
+
+`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
+:::
+
 ## FSDP {#sec-fsdp}
 
 ### Basic FSDP Configuration {#sec-fsdp-config}
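The same variables can also be set from Python before training starts; a minimal sketch (the address and port values mirror the callout above, and `setdefault` leaves any launcher-provided values intact):

    import os

    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("MASTER_ADDR", "0.0.0.0")
    os.environ.setdefault("MASTER_PORT", "29500")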

View File

@@ -7,6 +7,7 @@ import importlib.util
 from functools import cached_property
 
 import addict
+import torch
 import transformers
 from transformers import PretrainedConfig, PreTrainedModel
@@ -165,10 +166,25 @@ class PatchManager:
         """Apply patches for gradient checkpointing."""
         if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
             from axolotl.monkeypatch.gradient_checkpointing import (
+                CheckpointFunctionWithCPUOffload,
                 hf_grad_checkpoint_offload_wrapper,
             )
-            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+            if (
+                self.cfg.gradient_checkpointing_kwargs
+                and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
+                and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
+            ):
+                transformers.modeling_utils.checkpoint = (
+                    hf_grad_checkpoint_offload_wrapper
+                )
+            else:
+                transformers.modeling_utils.checkpoint.CheckpointFunction = (
+                    CheckpointFunctionWithCPUOffload
+                )
+                torch.utils.checkpoint.CheckpointFunction = (
+                    CheckpointFunctionWithCPUOffload
+                )
 
         if self.cfg.gradient_checkpointing == "offload_disk":
             from axolotl.monkeypatch.gradient_checkpointing import (
                 hf_grad_checkpoint_disk_offload_wrapper,
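The dispatch above keys off `use_reentrant`: the non-reentrant path wraps HF's checkpoint entry point, while the reentrant (default) path swaps out the `CheckpointFunction` class in both transformers and torch. A hypothetical pair of config fragments, shown as dicts purely for illustration, that would exercise each branch:

    # non-reentrant: hf_grad_checkpoint_offload_wrapper path
    cfg_wrapper = {
        "gradient_checkpointing": "offload",
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
    }

    # reentrant (no use_reentrant=False): CheckpointFunctionWithCPUOffload path
    cfg_patch = {"gradient_checkpointing": "offload"}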

View File

@@ -5,7 +5,8 @@ from functools import partial
 
 from packaging import version
 
-from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
+from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (  # noqa: F401
+    CheckpointFunctionWithCPUOffload,
     CPU_Offloaded_Gradient_Checkpointer,
 )
 from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (

View File

@@ -13,8 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
+import inspect
+
 import torch
 from packaging import version
+from torch.utils.checkpoint import (
+    _get_autocast_kwargs,
+    _get_device_module,
+    _infer_device_type,
+    check_backward_validity,
+    detach_variable,
+    get_device_states,
+    set_device_states,
+)
+
+# support different pytorch versions
+has_device_type = "device_type" in inspect.signature(set_device_states).parameters
+
 torch_version = version.parse(torch.__version__)
@@ -60,3 +76,153 @@ class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
         ) + (
             None,
         ) * len(ctx.args)
+
+
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
+class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
+    """
+    This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads
+    the first tensor to cpu during forward and back to cuda during backward. This
+    allows significant memory savings when using a very long seqlen, e.g. for llama
+    8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`.
+
+    In the case of a very long seqlen (100k+) the copying to/from cpu overhead is not
+    big, because dense quadratic attention compute will dominate.
+    """
+
+    @staticmethod
+    def forward(ctx, run_function, preserve_rng_state, *args):
+        check_backward_validity(args)
+        ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
+        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
+        ctx.device_type = _infer_device_type(*args)
+        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
+            ctx.device_type
+        )
+        if preserve_rng_state:
+            ctx.fwd_cpu_state = torch.get_rng_state()
+            # Don't eagerly initialize the cuda context by accident.
+            # (If the user intends that the context is initialized later, within their
+            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
+            # we have no way to anticipate this will happen before we run the function.)
+            ctx.had_device_in_fwd = False
+            device_module = _get_device_module(ctx.device_type)
+            if getattr(device_module, "_initialized", False):
+                ctx.had_device_in_fwd = True
+                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
+
+        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
+        # to be filled out during the backward.
+        ctx.inputs = []
+        ctx.tensor_indices = []
+        tensor_inputs = []
+        for i, arg in enumerate(args):
+            if torch.is_tensor(arg):
+                # cpu-offload
+                # we don't want the 2nd tensor - usually it's a shared 4D attn mask
+                # which is huge [seq,seq]
+                # upstream could accept a list of arg indices to offload
+                if i == 0:
+                    ctx.x_device = arg.device
+                    ctx.x_requires_grad = arg.requires_grad
+                    t = arg.detach().cpu()
+                else:
+                    t = arg
+                tensor_inputs.append(t)
+                ctx.tensor_indices.append(i)
+                ctx.inputs.append(None)
+            else:
+                ctx.inputs.append(arg)
+
+        ctx.save_for_backward(*tensor_inputs)
+
+        with torch.no_grad():
+            outputs = run_function(*args)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        if (
+            not torch.autograd._is_checkpoint_valid()  # pylint: disable=protected-access
+        ):
+            raise RuntimeError(
+                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
+                " with .grad() or passing an `inputs` parameter to .backward()."
+                " To resolve this error, you can either set use_reentrant=False,"
+                " or call .backward() without passing the `inputs` argument."
+            )
+        # Copy the list to avoid modifying original list.
+        inputs = list(ctx.inputs)
+        tensor_indices = ctx.tensor_indices
+        tensors = ctx.saved_tensors
+
+        # Fill in inputs with appropriate saved tensors, moving the offloaded
+        # first tensor back to its original device.
+        for i, idx in enumerate(tensor_indices):
+            if i == 0:
+                t = (
+                    tensors[i]
+                    .to(ctx.x_device)
+                    .detach()
+                    .requires_grad_(ctx.x_requires_grad)
+                )
+            else:
+                t = tensors[i]
+            inputs[idx] = t
+
+        # Stash the surrounding rng state, and mimic the state that was
+        # present at this time during forward. Restore the surrounding state
+        # when we're done.
+        rng_devices = []
+        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
+            rng_devices = ctx.fwd_devices
+        with torch.random.fork_rng(
+            devices=rng_devices,
+            enabled=ctx.preserve_rng_state,
+            device_type=ctx.device_type,
+        ):
+            if ctx.preserve_rng_state:
+                torch.set_rng_state(ctx.fwd_cpu_state)
+                if ctx.had_device_in_fwd:
+                    if has_device_type:
+                        # newer pytorch (as early as 2.7)
+                        set_device_states(
+                            ctx.fwd_devices,
+                            ctx.fwd_device_states,
+                            device_type=ctx.device_type,
+                        )
+                    else:
+                        # older pytorch (at least 2.4)
+                        set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
+            detached_inputs = detach_variable(tuple(inputs))
+
+            device_autocast_ctx = (
+                torch.amp.autocast(
+                    device_type=ctx.device_type, **ctx.device_autocast_kwargs
+                )
+                if torch.amp.is_autocast_available(ctx.device_type)
+                else contextlib.nullcontext()
+            )
+            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast(
+                "cpu", **ctx.cpu_autocast_kwargs
+            ):  # type: ignore[attr-defined]
+                outputs = ctx.run_function(*detached_inputs)
+
+        if isinstance(outputs, torch.Tensor):
+            outputs = (outputs,)
+
+        # run backward() with only tensor that requires grad
+        outputs_with_grad = []
+        args_with_grad = []
+        for i in range(len(outputs)):  # pylint: disable=consider-using-enumerate
+            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
+                outputs_with_grad.append(outputs[i])
+                args_with_grad.append(args[i])
+        if len(outputs_with_grad) == 0:
+            raise RuntimeError(
+                "none of output has requires_grad=True, this checkpoint() is not necessary"
+            )
+        torch.autograd.backward(outputs_with_grad, args_with_grad)
+        grads = tuple(
+            inp.grad if isinstance(inp, torch.Tensor) else None
+            for inp in detached_inputs
+        )
+
+        return (None, None) + grads
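As a sanity check on the docstring's savings estimate, a quick back-of-the-envelope calculation (assuming 2-byte bf16 activations, hidden size 4096, and 32 layers, per the llama 8b example in the docstring):

    seqlen, hidden, bytes_per_elem, n_layers = 100_000, 4096, 2, 32
    saved_gib = seqlen * hidden * bytes_per_elem * n_layers / 2**30
    print(f"~{saved_gib:.1f} GiB of activations offloaded per GPU")  # ~24.4 GiB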

View File

@@ -1,6 +1,7 @@
 """Monkeypatch for Tiled MLP implementation"""
 
 import math
+import os
 
 import torch
 import torch.distributed as dist
@@ -29,15 +30,18 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
     mlp_forward = torch.compile(generic_mlp_forward)
 
+    is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+
     def tiled_mlp_forward(self, x):
         input_shape = x.shape
         seqlen = input_shape[-2]
         hidden = input_shape[-1]
         if cfg_num_shards is None:
             num_shards = math.ceil(seqlen / hidden)
-            num_shards_tensor = torch.tensor(num_shards, device=x.device)
-            dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
-            num_shards = num_shards_tensor.item()
+            if is_distributed:
+                num_shards_tensor = torch.tensor(num_shards, device=x.device)
+                dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
+                num_shards = num_shards_tensor.item()
         else:
             num_shards = cfg_num_shards
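A small sketch of the shard-count rule above, with illustrative values (in a real run `seqlen` and `hidden` come from the activation shape, and the `all_reduce` MAX that keeps ranks in agreement only runs when more than one process is active):

    import math

    seqlen, hidden = 100_000, 4096
    num_shards = math.ceil(seqlen / hidden)  # 25 shards for this shape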

View File

@@ -479,8 +479,14 @@ class TrainingValidationMixin:
     @model_validator(mode="before")
     @classmethod
     def check_tiled_mlp_deepspeed(cls, data):
-        if data.get("tiled_mlp", False) and not data.get("deepspeed"):
-            raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
+        capabilities = data.get("capabilities")
+        n_gpu = 0
+        if capabilities and capabilities.get("n_gpu", 0) >= 1:
+            n_gpu = capabilities.get("n_gpu", 0)
+        if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
+            raise ValueError(
+                "tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
+            )
         return data
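To illustrate the relaxed rule, two hypothetical validator inputs (minimal dicts, field names as used above): a single-GPU config with `tiled_mlp` and no `deepspeed` now passes, while a multi-GPU one still raises:

    ok = {"tiled_mlp": True, "capabilities": {"n_gpu": 1}}        # accepted
    rejected = {"tiled_mlp": True, "capabilities": {"n_gpu": 8}}  # raises ValueError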

View File

@@ -546,6 +546,15 @@ def setup_deepspeed_env(cfg, stage=None):
     # NOTE(djsaunde): The distributed state cannot be initialized prior to the
     # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
     # to model load.
+    if int(os.environ.get("WORLD_SIZE", "1")) == 1:
+        os.environ["WORLD_SIZE"] = "1"  # force it in case not set
+        os.environ["LOCAL_RANK"] = "0"  # force it in case not set
+        os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
+        import deepspeed.comm as dist
+
+        dist.init_distributed(
+            dist_backend="nccl", auto_mpi_discovery=False, dist_init_required=True
+        )
     init_distributed_state()
 
     # If we don't assign this, it doesn't actually get set in the accelerate weakref