monkeypatch fix for fsdp with cpu ram efficient loading (#3464) [skip ci]
@@ -166,6 +166,13 @@ class PatchManager:
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
+        if self.cfg.fsdp_config:
+            from axolotl.monkeypatch.accelerate.fsdp2 import (
+                patch_initialize_missing_keys_for_fsdp,
+            )
+
+            patch_initialize_missing_keys_for_fsdp()
+
         if self.cfg.context_parallel_size > 1 or (
             self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
         ):
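The hunk above wires the patch into axolotl's PatchManager whenever an FSDP config is present. As a rough standalone sketch of the same call order (hypothetical script, not part of this commit; assumes a run under `accelerate launch` with FSDP and cpu_ram_efficient_loading enabled, and the model id is a placeholder):

    # Install the wrapper before from_pretrained(), so that
    # _initialize_missing_keys is already patched when weights load.
    from transformers import AutoModelForCausalLM

    from axolotl.monkeypatch.accelerate.fsdp2 import (
        patch_initialize_missing_keys_for_fsdp,
    )

    patch_initialize_missing_keys_for_fsdp()
    # Non-rank-0 processes now skip per-parameter re-initialization here.
    model = AutoModelForCausalLM.from_pretrained("some-org/some-model")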
@@ -479,6 +479,39 @@ def patch_tied_keys_for_meta_device():
     )
 
 
+def patch_initialize_missing_keys_for_fsdp():
+    """Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0.
+
+    When using cpu_ram_efficient_loading, non-rank-0 processes load weights on
+    meta device and move them to CPU as empty tensors. Without this patch,
+    initialize_weights() re-initializes ALL parameters (via guarded init
+    functions), which is slow and uses extra RAM per process.
+
+    The fix marks all params/buffers with _is_hf_initialized=True before calling
+    the original method, so guarded init functions (init.normal_, init.zeros_,
+    etc.) become no-ops on non-rank-0 processes. The real weights arrive later
+    via FSDP broadcast from rank 0.
+
+    Upstream fix: https://github.com/huggingface/transformers/pull/44473
+    Remove this patch once transformers includes the fix in a stable release.
+    """
+    from transformers import PreTrainedModel
+    from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
+
+    _original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
+
+    def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
+        if is_fsdp_enabled() and not is_local_dist_rank_0():
+            for key in self.state_dict():
+                param_or_buffer = self.get_parameter_or_buffer(key)
+                param_or_buffer._is_hf_initialized = True
+            self._is_hf_initialized = True
+
+        _original_initialize_missing_keys(self, is_quantized)
+
+    PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
+
+
 def patch_accelerate_fsdp2():
     import accelerate
 
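The guard the docstring refers to is transformers' _is_hf_initialized flag: initialization helpers skip any module or tensor already carrying it. A self-contained sketch of that pattern (a simplified illustration of the flag's effect, not the transformers implementation):

    import torch
    from torch import nn

    def guarded_init(module: nn.Module, init_fn) -> None:
        # Skip anything already flagged as initialized.
        if getattr(module, "_is_hf_initialized", False):
            return
        init_fn(module)
        module._is_hf_initialized = True

    linear = nn.Linear(4, 4)
    linear._is_hf_initialized = True  # what the patch sets on non-rank-0
    before = linear.weight.clone()
    guarded_init(linear, lambda m: nn.init.zeros_(m.weight))
    assert torch.equal(before, linear.weight)  # init became a no-op

On rank 0 the flag is never pre-set, so weights initialize and load normally; the other ranks then receive the real values via FSDP broadcast, as described in the docstring above.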