patch for ds_grads_remaining in deepspeed (#3102) [skip ci]
* patch deepspeed * deepspeed patch for ds_grads_remaining * patch in PatchManager * chore: lint * deepspeed utils * chore2 * patch ds_grads_remaining chore * chore lint * chore lint * remove torch.nn patch * lint * Update src/axolotl/monkeypatch/utils.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * patched with CheckpointWrapper * lint * only apply deepspeed patch when using activation offloading --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
Applies pre- and post-model load patches for various fixes and optimizations.
|
Applies pre- and post-model load patches for various fixes and optimizations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
@@ -66,6 +67,7 @@ class PatchManager:
|
|||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_fsdp2_bnb_patches()
|
self._apply_fsdp2_bnb_patches()
|
||||||
|
self._apply_patch_deepspeed_zero3()
|
||||||
|
|
||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
@@ -471,3 +473,16 @@ class PatchManager:
|
|||||||
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
||||||
|
|
||||||
apply_lora_kernel_patches(model=model, cfg=self.cfg)
|
apply_lora_kernel_patches(model=model, cfg=self.cfg)
|
||||||
|
|
||||||
|
def _apply_patch_deepspeed_zero3(self):
    """Apply the DeepSpeed patches when activation offloading runs under ZeRO-3.

    The patches are applied only if ``self.cfg.activation_offloading`` is
    exactly ``True`` and ZeRO stage 3 is detected, either through
    ``is_deepspeed_zero3_enabled()`` or through the
    ``ACCELERATE_DEEPSPEED_ZERO_STAGE`` environment variable being ``"3"``.
    Any ``ImportError`` raised along the way is logged as a warning instead
    of propagating.
    """
    try:
        from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
        from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

        # Nothing to do unless activation offloading is explicitly enabled;
        # the ZeRO-3 probes below are not evaluated in that case.
        if self.cfg.activation_offloading is not True:
            return

        zero3_active = (
            is_deepspeed_zero3_enabled()
            or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
        )
        if zero3_active:
            apply_deepspeed_patches()
    except ImportError as e:
        LOG.warning(f"DeepSpeed patches not applied: {e}")
|
||||||
|
|||||||
66
src/axolotl/monkeypatch/deepspeed_utils.py
Normal file
66
src/axolotl/monkeypatch/deepspeed_utils.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_checkpoint_wrapper_setattr():
    """
    Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules.

    This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes
    (like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed
    ZeRO-3 to fail when gradient checkpointing is enabled.

    This issue occurs specifically with:
    - QLoRA + DeepSpeed ZeRO-3
    - gradient_checkpointing: true
    - activation_offloading: true

    The patch replaces ``CheckpointWrapper.__setattr__`` once per process; a
    class-level marker (``_axolotl_setattr_patched``) makes repeated calls a
    no-op. Failures are logged rather than raised, so this is best-effort.

    References:
    - https://github.com/deepspeedai/DeepSpeed/issues/7203
    - https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458
    - https://github.com/axolotl-ai-cloud/axolotl/pull/3102
    """
    try:
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointWrapper,
        )

        # Check if already patched
        if hasattr(CheckpointWrapper, "_axolotl_setattr_patched"):
            LOG.debug("CheckpointWrapper already patched")
            return

        original_setattr = CheckpointWrapper.__setattr__

        def new_setattr(self, name: str, value) -> None:
            # Only ds_* attributes are redirected, and only once the wrapper
            # actually holds its wrapped module; everything else keeps the
            # original assignment behavior.
            if name.startswith("ds_") and hasattr(self, "_checkpoint_wrapped_module"):
                wrapped = self._checkpoint_wrapped_module
                setattr(wrapped, name, value)
                # This hook runs on every attribute assignment on a wrapper,
                # so pass lazy %-args: the message is only formatted when
                # DEBUG logging is actually enabled.
                LOG.debug(
                    "Forwarded %s to wrapped module %s",
                    name,
                    type(wrapped).__name__,
                )
            else:
                original_setattr(self, name, value)

        CheckpointWrapper.__setattr__ = new_setattr
        CheckpointWrapper._axolotl_setattr_patched = True

        LOG.info("CheckpointWrapper patched to forward DeepSpeed attributes")

    except ImportError as e:
        LOG.debug(f"CheckpointWrapper not available: {e}")
    except Exception as e:
        # Deliberately broad: a failed monkeypatch must not abort model loading.
        LOG.warning(f"Failed to patch CheckpointWrapper: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_deepspeed_patches():
    """
    Apply DeepSpeed-related patches.

    Patches CheckpointWrapper attribute forwarding when the ``deepspeed``
    package can be located; otherwise logs a debug message and does nothing.
    """
    deepspeed_installed = importlib.util.find_spec("deepspeed") is not None
    if not deepspeed_installed:
        LOG.debug("DeepSpeed not available, skipping patches")
        return

    patch_checkpoint_wrapper_setattr()
|
||||||
Reference in New Issue
Block a user