patch for ds_grads_remaining in deepspeed (#3102) [skip ci]

* patch deepspeed

* deepspeed patch for ds_grads_remaining

* patch in PatchManager

* chore: lint

* deepspeed utils

* chore2

* patch ds_grads_remaining chore

* chore lint

* chore lint

* remove torch.nn patch

* lint

* Update src/axolotl/monkeypatch/utils.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patched with CheckpointWrapper

* lint

* only apply deepspeed patch when using activation offloading

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
VED
2025-08-29 21:42:09 +05:30
committed by GitHub
parent 6afba3871d
commit 5b6ec2820f
2 changed files with 81 additions and 0 deletions

View File

@@ -3,6 +3,7 @@
Applies pre- and post-model load patches for various fixes and optimizations.
"""
import os
import importlib.util
from functools import cached_property
@@ -66,6 +67,7 @@ class PatchManager:
        self._apply_mistral_cross_entropy_patch()
        self._apply_self_attention_lora_patch()
        self._apply_fsdp2_bnb_patches()
        self._apply_patch_deepspeed_zero3()

    def apply_post_plugin_pre_model_load_patches(self):
        """Apply post plugin-pre_model_load load patches based on config."""
@@ -471,3 +473,16 @@ class PatchManager:
            from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches

            apply_lora_kernel_patches(model=model, cfg=self.cfg)

    def _apply_patch_deepspeed_zero3(self):
        try:
            from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
            from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

            if self.cfg.activation_offloading is True and (
                is_deepspeed_zero3_enabled()
                or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
            ):
                apply_deepspeed_patches()
        except ImportError as e:
            LOG.warning(f"DeepSpeed patches not applied: {e}")
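For reference, here is a minimal standalone sketch of the gating logic above, not part of the diff: the patch is only applied when activation offloading is enabled and ZeRO-3 is detected, either via transformers' helper or the Accelerate environment variable. The helper name should_apply_deepspeed_zero3_patch is hypothetical.

import os


def should_apply_deepspeed_zero3_patch(activation_offloading: bool) -> bool:
    # Hypothetical helper mirroring the gate in _apply_patch_deepspeed_zero3.
    try:
        from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

        zero3 = is_deepspeed_zero3_enabled()
    except ImportError:
        # transformers not importable; fall back to the Accelerate env var only
        zero3 = False
    return activation_offloading and (
        zero3 or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
    )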

View File

@@ -0,0 +1,66 @@
import importlib
import importlib.util

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def patch_checkpoint_wrapper_setattr():
    """
    Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules.

    This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes
    (like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed
    ZeRO-3 to fail when gradient checkpointing is enabled.

    This issue occurs specifically with:
    - QLoRA + DeepSpeed ZeRO-3
    - gradient_checkpointing: true
    - activation_offloading: true

    References:
    - https://github.com/deepspeedai/DeepSpeed/issues/7203
    - https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458
    - https://github.com/axolotl-ai-cloud/axolotl/pull/3102
    """
    try:
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointWrapper,
        )

        # Check if already patched
        if hasattr(CheckpointWrapper, "_axolotl_setattr_patched"):
            LOG.debug("CheckpointWrapper already patched")
            return

        original_setattr = CheckpointWrapper.__setattr__

        def new_setattr(self, name: str, value) -> None:
            if name.startswith("ds_") and hasattr(self, "_checkpoint_wrapped_module"):
                setattr(self._checkpoint_wrapped_module, name, value)
                LOG.debug(
                    f"Forwarded {name} to wrapped module {type(self._checkpoint_wrapped_module).__name__}"
                )
            else:
                original_setattr(self, name, value)

        CheckpointWrapper.__setattr__ = new_setattr
        CheckpointWrapper._axolotl_setattr_patched = True

        LOG.info("CheckpointWrapper patched to forward DeepSpeed attributes")

    except ImportError as e:
        LOG.debug(f"CheckpointWrapper not available: {e}")
    except Exception as e:
        LOG.warning(f"Failed to patch CheckpointWrapper: {e}")


def apply_deepspeed_patches():
    """
    Apply DeepSpeed-related patches
    """
    if importlib.util.find_spec("deepspeed") is not None:
        patch_checkpoint_wrapper_setattr()
    else:
        LOG.debug("DeepSpeed not available, skipping patches")