Compare commits: testingci ... v0.11.0.po

1 commit: c620a218b8
@@ -66,6 +66,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
 :::

+::: {.callout-tip}
+
+Using ZeRO Stage 3 with Single-GPU training
+
+ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
+
+`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
+
+:::

 ## FSDP {#sec-fsdp}

 ### Basic FSDP Configuration {#sec-fsdp-config}

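For readers who launch training from a Python entrypoint rather than a shell, a minimal sketch of the same single-GPU setup (the variable values come from the callout above and from the setup_deepspeed_env hunk later in this diff; driving this from Python is an assumption, not part of the change):

```python
import os

# Mirror the single-GPU environment from the docs callout so that
# torch.distributed / DeepSpeed see exactly one rank.
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("MASTER_ADDR", "0.0.0.0")
os.environ.setdefault("MASTER_PORT", "29500")
```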
@@ -7,6 +7,7 @@ import importlib.util
 from functools import cached_property

 import addict
+import torch
 import transformers
 from transformers import PretrainedConfig, PreTrainedModel

@@ -165,10 +166,25 @@ class PatchManager:
         """Apply patches for gradient checkpointing."""
         if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
             from axolotl.monkeypatch.gradient_checkpointing import (
+                CheckpointFunctionWithCPUOffload,
                 hf_grad_checkpoint_offload_wrapper,
             )

-            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+            if (
+                self.cfg.gradient_checkpointing_kwargs
+                and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
+                and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
+            ):
+                transformers.modeling_utils.checkpoint = (
+                    hf_grad_checkpoint_offload_wrapper
+                )
+            else:
+                transformers.modeling_utils.checkpoint.CheckpointFunction = (
+                    CheckpointFunctionWithCPUOffload
+                )
+                torch.utils.checkpoint.CheckpointFunction = (
+                    CheckpointFunctionWithCPUOffload
+                )
         if self.cfg.gradient_checkpointing == "offload_disk":
             from axolotl.monkeypatch.gradient_checkpointing import (
                 hf_grad_checkpoint_disk_offload_wrapper,
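The intent of the branch above: with `use_reentrant` explicitly set to `False`, the existing HF wrapper patch is kept; otherwise the reentrant `CheckpointFunction` itself is swapped for the CPU-offload variant. A rough sketch of how the reentrant path would then pick up the patched class, based on how stock PyTorch dispatches reentrant checkpointing (toy function, not part of this diff):

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(x):
    # toy stand-in for a transformer layer forward
    return torch.nn.functional.gelu(x)


x = torch.randn(4, 16, requires_grad=True)
# With use_reentrant=True, torch.utils.checkpoint goes through the module-level
# CheckpointFunction, which the patch above replaces with
# CheckpointFunctionWithCPUOffload, so the first tensor argument is offloaded
# to CPU between forward and backward.
out = checkpoint(block, x, use_reentrant=True)
out.sum().backward()
```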
@@ -5,7 +5,8 @@ from functools import partial

 from packaging import version

-from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
+from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (  # noqa: F401
+    CheckpointFunctionWithCPUOffload,
     CPU_Offloaded_Gradient_Checkpointer,
 )
 from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
@@ -13,8 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
+import inspect
+
 import torch
 from packaging import version
+from torch.utils.checkpoint import (
+    _get_autocast_kwargs,
+    _get_device_module,
+    _infer_device_type,
+    check_backward_validity,
+    detach_variable,
+    get_device_states,
+    set_device_states,
+)
+
+# support different pytorch versions
+has_device_type = "device_type" in inspect.signature(set_device_states).parameters

 torch_version = version.parse(torch.__version__)
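The `has_device_type` probe above is a capability check: instead of comparing version numbers, it inspects whether the installed torch exposes the newer keyword. The same pattern shown on its own, with toy functions that are not from this diff:

```python
import inspect


def set_states_old(devices, states):
    """Stand-in for the pre-2.7-style signature."""


def set_states_new(devices, states, *, device_type="cuda"):
    """Stand-in for the newer signature that accepts device_type."""


has_device_type = "device_type" in inspect.signature(set_states_new).parameters
print(has_device_type)  # True -> pass device_type=...; False -> call the old way
```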
@@ -60,3 +76,153 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
         ) + (
             None,
         ) * len(ctx.args)
+
+
+# Copyright 2025 Snowflake Inc.
+# SPDX-License-Identifier: Apache-2.0
+# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
+class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
+    """
+    This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
+
+    In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
+    """
+
+    @staticmethod
+    def forward(ctx, run_function, preserve_rng_state, *args):
+        check_backward_validity(args)
+        ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
+        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
+        ctx.device_type = _infer_device_type(*args)
+        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
+            ctx.device_type
+        )
+        if preserve_rng_state:
+            ctx.fwd_cpu_state = torch.get_rng_state()
+            # Don't eagerly initialize the cuda context by accident.
+            # (If the user intends that the context is initialized later, within their
+            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
+            # we have no way to anticipate this will happen before we run the function.)
+            ctx.had_device_in_fwd = False
+            device_module = _get_device_module(ctx.device_type)
+            if getattr(device_module, "_initialized", False):
+                ctx.had_device_in_fwd = True
+                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
+
+        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
+        # to be filled out during the backward.
+        ctx.inputs = []
+        ctx.tensor_indices = []
+        tensor_inputs = []
+        # x = None
+        for i, arg in enumerate(args):
+            if torch.is_tensor(arg):
+                # cpu-offload
+                # we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
+                # upstream could accept a list of arg indices to offload
+                if i == 0:
+                    # print(f"{arg.shape=}")
+                    ctx.x_device = arg.device
+                    ctx.x_requires_grad = arg.requires_grad
+                    t = arg.detach().cpu()
+                else:
+                    t = arg
+                tensor_inputs.append(t)
+                ctx.tensor_indices.append(i)
+                ctx.inputs.append(None)
+            else:
+                ctx.inputs.append(arg)
+
+        ctx.save_for_backward(*tensor_inputs)
+
+        with torch.no_grad():
+            outputs = run_function(*args)
+
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        if (
+            not torch.autograd._is_checkpoint_valid()  # pylint: disable=protected-access
+        ):
+            raise RuntimeError(
+                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
+                " with .grad() or passing an `inputs` parameter to .backward()."
+                " To resolve this error, you can either set use_reentrant=False,"
+                " or call .backward() without passing the `inputs` argument."
+            )
+        # Copy the list to avoid modifying original list.
+        inputs = list(ctx.inputs)
+        tensor_indices = ctx.tensor_indices
+        tensors = ctx.saved_tensors
+
+        # Fill in inputs with appropriate saved tensors.
+        for i, idx in enumerate(tensor_indices):
+            if i == 0:
+                t = (
+                    tensors[i]
+                    .to(ctx.x_device)
+                    .detach()
+                    .requires_grad_(ctx.x_requires_grad)
+                )
+            else:
+                t = tensors[i]
+            inputs[idx] = t
+
+        # Stash the surrounding rng state, and mimic the state that was
+        # present at this time during forward. Restore the surrounding state
+        # when we're done.
+        rng_devices = []
+        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
+            rng_devices = ctx.fwd_devices
+        with torch.random.fork_rng(
+            devices=rng_devices,
+            enabled=ctx.preserve_rng_state,
+            device_type=ctx.device_type,
+        ):
+            if ctx.preserve_rng_state:
+                torch.set_rng_state(ctx.fwd_cpu_state)
+                if ctx.had_device_in_fwd:
+                    if has_device_type:
+                        # newer pytorch (as early as 2.7)
+                        set_device_states(
+                            ctx.fwd_devices,
+                            ctx.fwd_device_states,
+                            device_type=ctx.device_type,
+                        )
+                    else:
+                        # older pytorch (at least 2.4)
+                        set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
+            detached_inputs = detach_variable(tuple(inputs))
+
+            device_autocast_ctx = (
+                torch.amp.autocast(
+                    device_type=ctx.device_type, **ctx.device_autocast_kwargs
+                )
+                if torch.amp.is_autocast_available(ctx.device_type)
+                else contextlib.nullcontext()
+            )
+            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
+                outputs = ctx.run_function(*detached_inputs)
+
+        if isinstance(outputs, torch.Tensor):
+            outputs = (outputs,)
+
+        # run backward() with only tensor that requires grad
+        outputs_with_grad = []
+        args_with_grad = []
+        for i in range(len(outputs)):  # pylint: disable=consider-using-enumerate
+            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
+                outputs_with_grad.append(outputs[i])
+                args_with_grad.append(args[i])
+        if len(outputs_with_grad) == 0:
+            raise RuntimeError(
+                "none of output has requires_grad=True, this checkpoint() is not necessary"
+            )
+        torch.autograd.backward(outputs_with_grad, args_with_grad)
+        grads = tuple(
+            inp.grad if isinstance(inp, torch.Tensor) else None
+            for inp in detached_inputs
+        )
+
+        return (None, None) + grads
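A toy usage sketch of the new autograd function; the import path follows the `__init__.py` hunk earlier in this diff, and the worked number at the end only spells out the estimate already in the docstring:

```python
import torch

from axolotl.monkeypatch.gradient_checkpointing import CheckpointFunctionWithCPUOffload


def block(hidden_states):
    # toy stand-in for a checkpointed layer forward
    return torch.nn.functional.silu(hidden_states) * hidden_states


hidden_states = torch.randn(2, 8, 16, requires_grad=True)
# apply(run_function, preserve_rng_state, *args): the first tensor argument is
# detached to CPU during forward and moved back to its device during backward.
out = CheckpointFunctionWithCPUOffload.apply(block, True, hidden_states)
out.sum().backward()

# Docstring estimate, spelled out: 100k tokens x 4096 hidden x 2 bytes (bf16)
# x 32 layers / 2**30 bytes per GiB ~= 24.4 GB of activations kept off-GPU.
print((100_000 * 4096) * 2 * 32 / 2**30)
```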
@@ -1,6 +1,7 @@
 """Monkeypatch for Tiled MLP implementation"""

 import math
+import os

 import torch
 import torch.distributed as dist
@@ -29,15 +30,18 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):

     mlp_forward = torch.compile(generic_mlp_forward)

+    is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+
     def tiled_mlp_forward(self, x):
         input_shape = x.shape
         seqlen = input_shape[-2]
         hidden = input_shape[-1]
         if cfg_num_shards is None:
             num_shards = math.ceil(seqlen / hidden)
-            num_shards_tensor = torch.tensor(num_shards, device=x.device)
-            dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
-            num_shards = num_shards_tensor.item()
+            if is_distributed:
+                num_shards_tensor = torch.tensor(num_shards, device=x.device)
+                dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
+                num_shards = num_shards_tensor.item()
         else:
             num_shards = cfg_num_shards

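For context, the shard-count heuristic that this hunk leaves unchanged tiles the MLP roughly once per `hidden` tokens of sequence length; the change only skips the cross-rank max-reduction when there is a single process. A quick worked example with assumed shapes:

```python
import math

seqlen, hidden = 32_768, 4_096           # assumed example shapes
num_shards = math.ceil(seqlen / hidden)  # -> 8 MLP tiles at 32k context
print(num_shards)
```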
@@ -479,8 +479,14 @@ class TrainingValidationMixin:
     @model_validator(mode="before")
     @classmethod
     def check_tiled_mlp_deepspeed(cls, data):
-        if data.get("tiled_mlp", False) and not data.get("deepspeed"):
-            raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
+        capabilities = data.get("capabilities")
+        n_gpu = 0
+        if capabilities and capabilities.get("n_gpu", 0) >= 1:
+            n_gpu = capabilities.get("n_gpu", 0)
+        if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
+            raise ValueError(
+                "tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
+            )
         return data


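Restated outside the mixin for illustration (a hypothetical standalone copy of the validator body; the real method stays on `TrainingValidationMixin`), the relaxed rule only rejects tiled_mlp without DeepSpeed when more than one GPU is reported:

```python
def check_tiled_mlp_deepspeed(data: dict) -> dict:
    # Standalone restatement of the validator above, for illustration only.
    capabilities = data.get("capabilities")
    n_gpu = 0
    if capabilities and capabilities.get("n_gpu", 0) >= 1:
        n_gpu = capabilities.get("n_gpu", 0)
    if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
        raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu")
    return data


check_tiled_mlp_deepspeed({"tiled_mlp": True, "capabilities": {"n_gpu": 1}})  # now passes
try:
    check_tiled_mlp_deepspeed({"tiled_mlp": True, "capabilities": {"n_gpu": 4}})
except ValueError as err:
    print(err)  # multi-gpu without deepspeed is still rejected
```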
@@ -546,6 +546,15 @@ def setup_deepspeed_env(cfg, stage=None):
     # NOTE(djsaunde): The distribued state cannot be initialized prior to the
     # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
     # to model load.
+    if int(os.environ.get("WORLD_SIZE", "1")) == 1:
+        os.environ["WORLD_SIZE"] = "1"  # force it in case not set
+        os.environ["LOCAL_RANK"] = "0"  # force it in case not set
+        os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
+        import deepspeed.comm as dist
+
+        dist.init_distributed(
+            dist_backend="nccl", auto_mpi_discovery=False, dist_init_required=True
+        )
     init_distributed_state()

     # If we don't assign this, it doesn't actually get set in the accelerate weakref