patch for ds_grads_remaining in deepspeed (#3102) [skip ci]
* patch deepspeed * deepspeed patch for ds_grads_remaining * patch in Patchmanager * chore: lint * deepseed utils * chore2 * patch ds_grads_remaining chore * chore lint * chore lint * remove torch.nn patch * lint * Update src/axolotl/monkeypatch/utils.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * patched with checkpointwarapper * lint * only apply deepspeed patch when using activation offloading --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
Applies pre- and post-model load patches for various fixes and optimizations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import importlib.util
|
||||
from functools import cached_property
|
||||
|
||||
@@ -66,6 +67,7 @@ class PatchManager:
|
||||
self._apply_mistral_cross_entropy_patch()
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_fsdp2_bnb_patches()
|
||||
self._apply_patch_deepspeed_zero3()
|
||||
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
@@ -471,3 +473,16 @@ class PatchManager:
|
||||
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
||||
|
||||
apply_lora_kernel_patches(model=model, cfg=self.cfg)
|
||||
|
||||
def _apply_patch_deepspeed_zero3(self):
|
||||
try:
|
||||
from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
if self.cfg.activation_offloading is True and (
|
||||
is_deepspeed_zero3_enabled()
|
||||
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
||||
):
|
||||
apply_deepspeed_patches()
|
||||
except ImportError as e:
|
||||
LOG.warning(f"DeepSpeed patches not applied: {e}")
|
||||
|
||||
66
src/axolotl/monkeypatch/deepspeed_utils.py
Normal file
66
src/axolotl/monkeypatch/deepspeed_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patch_checkpoint_wrapper_setattr():
|
||||
"""
|
||||
Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules.
|
||||
|
||||
This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes
|
||||
(like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed
|
||||
ZeRO-3 to fail when gradient checkpointing is enabled.
|
||||
|
||||
This issue occurs specifically with:
|
||||
- QLoRA + DeepSpeed ZeRO-3
|
||||
- gradient_checkpointing: true
|
||||
- activation_offloading: true
|
||||
|
||||
References:
|
||||
- https://github.com/deepspeedai/DeepSpeed/issues/7203
|
||||
- https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458
|
||||
- https://github.com/axolotl-ai-cloud/axolotl/pull/3102
|
||||
"""
|
||||
|
||||
try:
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointWrapper,
|
||||
)
|
||||
|
||||
# Check if already patched
|
||||
if hasattr(CheckpointWrapper, "_axolotl_setattr_patched"):
|
||||
LOG.debug("CheckpointWrapper already patched")
|
||||
return
|
||||
|
||||
original_setattr = CheckpointWrapper.__setattr__
|
||||
|
||||
def new_setattr(self, name: str, value) -> None:
|
||||
if name.startswith("ds_") and hasattr(self, "_checkpoint_wrapped_module"):
|
||||
setattr(self._checkpoint_wrapped_module, name, value)
|
||||
LOG.debug(
|
||||
f"Forwarded {name} to wrapped module {type(self._checkpoint_wrapped_module).__name__}"
|
||||
)
|
||||
else:
|
||||
original_setattr(self, name, value)
|
||||
|
||||
CheckpointWrapper.__setattr__ = new_setattr
|
||||
CheckpointWrapper._axolotl_setattr_patched = True
|
||||
|
||||
LOG.info("CheckpointWrapper patched to forward DeepSpeed attributes")
|
||||
|
||||
except ImportError as e:
|
||||
LOG.debug(f"CheckpointWrapper not available: {e}")
|
||||
except Exception as e:
|
||||
LOG.warning(f"Failed to patch CheckpointWrapper: {e}")
|
||||
|
||||
|
||||
def apply_deepspeed_patches():
|
||||
"""
|
||||
Apply DeepSpeed-related patches
|
||||
"""
|
||||
if importlib.util.find_spec("deepspeed") is not None:
|
||||
patch_checkpoint_wrapper_setattr()
|
||||
else:
|
||||
LOG.debug("DeepSpeed not available, skipping patches")
|
||||
Reference in New Issue
Block a user