From 56162f71db17fbb665ad108e7dc701a79d9e066b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 6 Mar 2026 09:10:58 -0500
Subject: [PATCH] monkeypatch fix for fsdp with cpu ram efficient loading
 (#3464) [skip ci]

---
 src/axolotl/loaders/patch_manager.py        |  7 +++++
 src/axolotl/monkeypatch/accelerate/fsdp2.py | 33 +++++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 87520c06f..51bcaeba4 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -166,6 +166,13 @@ class PatchManager:
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
+        if self.cfg.fsdp_config:
+            from axolotl.monkeypatch.accelerate.fsdp2 import (
+                patch_initialize_missing_keys_for_fsdp,
+            )
+
+            patch_initialize_missing_keys_for_fsdp()
+
         if self.cfg.context_parallel_size > 1 or (
             self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
         ):
diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index dd3deb19a..1fa589f07 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -479,6 +479,39 @@ def patch_tied_keys_for_meta_device():
     )
 
 
+def patch_initialize_missing_keys_for_fsdp():
+    """Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0.
+
+    When using cpu_ram_efficient_loading, non-rank-0 processes load weights on
+    meta device and move them to CPU as empty tensors. Without this patch,
+    initialize_weights() re-initializes ALL parameters (via guarded init
+    functions), which is slow and uses extra RAM per process.
+
+    The fix marks all params/buffers with _is_hf_initialized=True before calling
+    the original method, so guarded init functions (init.normal_, init.zeros_,
+    etc.) become no-ops on non-rank-0 processes. The real weights arrive later
+    via FSDP broadcast from rank 0.
+
+    Upstream fix: https://github.com/huggingface/transformers/pull/44473
+    Remove this patch once transformers includes the fix in a stable release.
+    """
+    from transformers import PreTrainedModel
+    from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
+
+    _original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
+
+    def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
+        if is_fsdp_enabled() and not is_local_dist_rank_0():
+            for key in self.state_dict():
+                param_or_buffer = self.get_parameter_or_buffer(key)
+                param_or_buffer._is_hf_initialized = True
+            self._is_hf_initialized = True
+
+        _original_initialize_missing_keys(self, is_quantized)
+
+    PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
+
+
 def patch_accelerate_fsdp2():
     import accelerate
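
For reference, below is a minimal, self-contained sketch of the guard mechanism the docstring relies on; it is not the transformers implementation. The `guarded_normal_` helper is hypothetical; only the `_is_hf_initialized` attribute name and the skip-on-flag behavior come from the patch above. Once a parameter carries the flag, a guarded initializer short-circuits, which is why flagging every entry in `state_dict()` on non-rank-0 processes saves both time and RAM:

```python
# Minimal sketch of the _is_hf_initialized guard (assumptions: guarded_normal_
# is a hypothetical stand-in for transformers' guarded wrappers around
# torch.nn.init functions; the flag name matches the patch above).
import torch
import torch.nn as nn


def guarded_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0):
    """Initialize a tensor in place, unless it is already flagged as initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor  # no-op: real weights arrive later via FSDP broadcast
    return nn.init.normal_(tensor, mean=mean, std=std)


linear = nn.Linear(4, 4)

# Simulate what the patch does on FSDP non-rank-0 processes: flag every
# parameter so guarded init functions leave it untouched.
for param in linear.parameters():
    param._is_hf_initialized = True

before = linear.weight.clone()
guarded_normal_(linear.weight)
assert torch.equal(before, linear.weight)  # initialization was skipped
```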
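
The patch manager applies this monkeypatch whenever `fsdp_config` is set, but the scenario it targets is FSDP with CPU-RAM-efficient loading. A hedged config sketch follows; the exact key spellings are an assumption drawn from axolotl's FSDP2 conventions, while `fsdp_version` and `fsdp_config` are the fields the patched `_apply_fsdp_patches` actually checks:

```yaml
# Hypothetical axolotl config excerpt; treat key names as assumptions.
fsdp_version: 2
fsdp_config:
  # Only local rank 0 loads real weights; other ranks hold empty tensors
  # until FSDP broadcasts the shards, which is the case this patch speeds up.
  cpu_ram_efficient_loading: true
```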