monkeypatch fix for fsdp with cpu ram efficient loading (#3464) [skip ci]

Wing Lian
2026-03-06 09:10:58 -05:00
committed by GitHub
parent 6c44afaea1
commit 56162f71db
2 changed files with 40 additions and 0 deletions
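
For context, the code path this commit touches is hit when training with FSDP and accelerate's cpu_ram_efficient_loading, roughly as in the sketch below (an illustration, assuming accelerate's FullyShardedDataParallelPlugin exposes the fsdp_version and cpu_ram_efficient_loading fields; not code from this commit):

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# With cpu_ram_efficient_loading, only local rank 0 materializes the
# checkpoint; other ranks load on the meta device and receive the real
# weights later via FSDP broadcast.
fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    cpu_ram_efficient_loading=True,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)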


@@ -166,6 +166,13 @@ class PatchManager:
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
+        if self.cfg.fsdp_config:
+            from axolotl.monkeypatch.accelerate.fsdp2 import (
+                patch_initialize_missing_keys_for_fsdp,
+            )
+
+            patch_initialize_missing_keys_for_fsdp()
+
         if self.cfg.context_parallel_size > 1 or (
             self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
         ):
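
For illustration, the same patch applied outside PatchManager (a minimal sketch; only patch_initialize_missing_keys_for_fsdp comes from this commit, the surrounding usage is assumed). The method is swapped on the PreTrainedModel class, so the call has to happen before the model is loaded:

from axolotl.monkeypatch.accelerate.fsdp2 import (
    patch_initialize_missing_keys_for_fsdp,
)

# Patch first: _initialize_missing_keys is resolved on PreTrainedModel
# inside from_pretrained(), so later loads pick up the patched method.
patch_initialize_missing_keys_for_fsdp()

# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(...)  # non-rank-0 skips re-init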


@@ -479,6 +479,39 @@ def patch_tied_keys_for_meta_device():
     )
 
 
+def patch_initialize_missing_keys_for_fsdp():
+    """Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0.
+
+    When using cpu_ram_efficient_loading, non-rank-0 processes load weights on
+    meta device and move them to CPU as empty tensors. Without this patch,
+    initialize_weights() re-initializes ALL parameters (via guarded init
+    functions), which is slow and uses extra RAM per process.
+
+    The fix marks all params/buffers with _is_hf_initialized=True before calling
+    the original method, so guarded init functions (init.normal_, init.zeros_,
+    etc.) become no-ops on non-rank-0 processes. The real weights arrive later
+    via FSDP broadcast from rank 0.
+
+    Upstream fix: https://github.com/huggingface/transformers/pull/44473
+    Remove this patch once transformers includes the fix in a stable release.
+    """
+    from transformers import PreTrainedModel
+    from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
+
+    _original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
+
+    def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
+        if is_fsdp_enabled() and not is_local_dist_rank_0():
+            for key in self.state_dict():
+                param_or_buffer = self.get_parameter_or_buffer(key)
+                param_or_buffer._is_hf_initialized = True
+            self._is_hf_initialized = True
+
+        _original_initialize_missing_keys(self, is_quantized)
+
+    PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
+
+
 def patch_accelerate_fsdp2():
     import accelerate
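
To see why marking tensors with _is_hf_initialized is sufficient, here is an illustrative sketch (not transformers' actual code) of the guarded-init pattern the docstring refers to: a wrapper around a torch.nn.init function that becomes a no-op on tensors already marked as initialized.

import torch
from torch.nn import init

_real_normal_ = init.normal_

def guarded_normal_(tensor, mean=0.0, std=1.0):
    # Tensors the FSDP patch marked as initialized are left untouched;
    # their real values arrive via FSDP broadcast from rank 0.
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    return _real_normal_(tensor, mean=mean, std=std)

w = torch.nn.Parameter(torch.empty(4, 4))
w._is_hf_initialized = True
guarded_normal_(w)  # no-op: w keeps its uninitialized storage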