fix: add guard for _initialize_missing_keys patch (#3469) [skip ci]
This commit is contained in:
@@ -498,6 +498,9 @@ def patch_initialize_missing_keys_for_fsdp():
|
|||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
|
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
|
||||||
|
|
||||||
|
if getattr(PreTrainedModel._initialize_missing_keys, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
_original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
|
_original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
|
||||||
|
|
||||||
def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
|
def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
|
||||||
@@ -510,6 +513,7 @@ def patch_initialize_missing_keys_for_fsdp():
|
|||||||
_original_initialize_missing_keys(self, is_quantized)
|
_original_initialize_missing_keys(self, is_quantized)
|
||||||
|
|
||||||
PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
|
PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
|
||||||
|
PreTrainedModel._initialize_missing_keys._axolotl_patched = True
|
||||||
|
|
||||||
|
|
||||||
def patch_accelerate_fsdp2():
|
def patch_accelerate_fsdp2():
|
||||||
|
|||||||
Reference in New Issue
Block a user