diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 1fa589f07..87b537655 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -498,6 +498,9 @@ def patch_initialize_missing_keys_for_fsdp(): from transformers import PreTrainedModel from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 + if getattr(PreTrainedModel._initialize_missing_keys, "_axolotl_patched", False): + return + _original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys def _patched_initialize_missing_keys(self, is_quantized: bool) -> None: @@ -510,6 +513,7 @@ def patch_initialize_missing_keys_for_fsdp(): _original_initialize_missing_keys(self, is_quantized) PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys + PreTrainedModel._initialize_missing_keys._axolotl_patched = True def patch_accelerate_fsdp2():