From 5b6ec2820f26ce4b50c624b41453d93b2a9c6063 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Fri, 29 Aug 2025 21:42:09 +0530
Subject: [PATCH] patch for ds_grads_remaining in deepspeed (#3102) [skip ci]

* patch deepspeed
* deepspeed patch for ds_grads_remaining
* patch in PatchManager
* chore: lint
* deepspeed utils
* chore2
* patch ds_grads_remaining chore
* chore lint
* chore lint
* remove torch.nn patch
* lint
* Update src/axolotl/monkeypatch/utils.py

Co-authored-by: NanoCode012

* patched with CheckpointWrapper
* lint
* only apply deepspeed patch when using activation offloading

---------

Co-authored-by: NanoCode012
Co-authored-by: Wing Lian
---
 src/axolotl/loaders/patch_manager.py       | 15 +++++
 src/axolotl/monkeypatch/deepspeed_utils.py | 66 ++++++++++++++++++++++
 2 files changed, 81 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/deepspeed_utils.py

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 4959bd6ba..eafe89d29 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -3,6 +3,7 @@
 Applies pre- and post-model load patches for various fixes and optimizations.
 """
 
+import os
 import importlib.util
 from functools import cached_property
 
@@ -66,6 +67,7 @@ class PatchManager:
         self._apply_mistral_cross_entropy_patch()
         self._apply_self_attention_lora_patch()
         self._apply_fsdp2_bnb_patches()
+        self._apply_patch_deepspeed_zero3()
 
     def apply_post_plugin_pre_model_load_patches(self):
         """Apply post plugin-pre_model_load load patches based on config."""
@@ -471,3 +473,16 @@ class PatchManager:
         from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
 
         apply_lora_kernel_patches(model=model, cfg=self.cfg)
+
+    def _apply_patch_deepspeed_zero3(self):
+        try:
+            from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
+            from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+
+            if self.cfg.activation_offloading is True and (
+                is_deepspeed_zero3_enabled()
+                or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
+            ):
+                apply_deepspeed_patches()
+        except ImportError as e:
+            LOG.warning(f"DeepSpeed patches not applied: {e}")
diff --git a/src/axolotl/monkeypatch/deepspeed_utils.py b/src/axolotl/monkeypatch/deepspeed_utils.py
new file mode 100644
index 000000000..6740f556b
--- /dev/null
+++ b/src/axolotl/monkeypatch/deepspeed_utils.py
@@ -0,0 +1,66 @@
+import importlib
+import importlib.util
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def patch_checkpoint_wrapper_setattr():
+    """
+    Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules.
+
+    This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes
+    (like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed
+    ZeRO-3 to fail when gradient checkpointing is enabled.
+
+    This issue occurs specifically with:
+    - QLoRA + DeepSpeed ZeRO-3
+    - gradient_checkpointing: true
+    - activation_offloading: true
+
+    References:
+    - https://github.com/deepspeedai/DeepSpeed/issues/7203
+    - https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458
+    - https://github.com/axolotl-ai-cloud/axolotl/pull/3102
+    """
+
+    try:
+        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+            CheckpointWrapper,
+        )
+
+        # Check if already patched
+        if hasattr(CheckpointWrapper, "_axolotl_setattr_patched"):
+            LOG.debug("CheckpointWrapper already patched")
+            return
+
+        original_setattr = CheckpointWrapper.__setattr__
+
+        def new_setattr(self, name: str, value) -> None:
+            if name.startswith("ds_") and hasattr(self, "_checkpoint_wrapped_module"):
+                setattr(self._checkpoint_wrapped_module, name, value)
+                LOG.debug(
+                    f"Forwarded {name} to wrapped module {type(self._checkpoint_wrapped_module).__name__}"
+                )
+            else:
+                original_setattr(self, name, value)
+
+        CheckpointWrapper.__setattr__ = new_setattr
+        CheckpointWrapper._axolotl_setattr_patched = True
+
+        LOG.info("CheckpointWrapper patched to forward DeepSpeed attributes")
+
+    except ImportError as e:
+        LOG.debug(f"CheckpointWrapper not available: {e}")
+    except Exception as e:
+        LOG.warning(f"Failed to patch CheckpointWrapper: {e}")
+
+
+def apply_deepspeed_patches():
+    """
+    Apply DeepSpeed-related patches.
+    """
+    if importlib.util.find_spec("deepspeed") is not None:
+        patch_checkpoint_wrapper_setattr()
+    else:
+        LOG.debug("DeepSpeed not available, skipping patches")
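A minimal sketch (not part of the patch) of the forwarding behavior the patched
__setattr__ provides. It assumes torch's private checkpoint_wrapper API and calls
patch_checkpoint_wrapper_setattr() directly so DeepSpeed itself is not required;
ds_grads_remaining stands in for the bookkeeping counters DeepSpeed ZeRO-3 writes
onto each module:

    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
    )

    from axolotl.monkeypatch.deepspeed_utils import patch_checkpoint_wrapper_setattr

    patch_checkpoint_wrapper_setattr()

    wrapped = checkpoint_wrapper(nn.Linear(4, 4))

    # ZeRO-3's hooks set ds_* attributes on what they believe is the leaf
    # module; with the patch, the write lands on the inner nn.Linear instead
    # of staying on the CheckpointWrapper shell.
    wrapped.ds_grads_remaining = 0
    assert wrapped._checkpoint_wrapped_module.ds_grads_remaining == 0

Reads already resolve through the wrapper, since torch's ActivationWrapper
forwards missing attribute lookups to the wrapped module; it is only attribute
writes that needed forwarding, which is why the patch hooks __setattr__ alone.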