From afb8218c67c6a09bf115bfe2d9124b9a9d861a89 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 19 Nov 2024 02:12:33 -0500 Subject: [PATCH] fix the monkeypatch --- .../monkeypatch/modeling_zero3_int8_lora.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/axolotl/monkeypatch/modeling_zero3_int8_lora.py b/src/axolotl/monkeypatch/modeling_zero3_int8_lora.py index 29a09e347..c70a34b9a 100644 --- a/src/axolotl/monkeypatch/modeling_zero3_int8_lora.py +++ b/src/axolotl/monkeypatch/modeling_zero3_int8_lora.py @@ -3,12 +3,11 @@ fix for zero3 8-bit lora see https://github.com/huggingface/transformers/pull/32943/files """ import inspect +import logging -import transformers -import transformers.modeling_utils -from accelerate.logging import get_logger +from transformers import modeling_utils -LOG = get_logger("axolotl.monkeypatch.modeling_zero3_int8_lora") +LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora") ORIGINAL_LOAD_CODE = """ if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): @@ -38,7 +37,7 @@ PATCHED_LOAD_CODE = """ def get_modeling_state_dict_code() -> str: load_code = inspect.getsource( - transformers.modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access + modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access ) return load_code @@ -54,7 +53,7 @@ def patch_modeling_state_dict_code(): """ load_code = get_modeling_state_dict_code() - transformers.modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access + modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access load_code ) assert ( @@ -69,7 +68,7 @@ def patch_modeling_state_dict_code(): ) items_to_import = [] - for item in dir(transformers.modeling_utils): + for item in dir(modeling_utils): if item in load_code: items_to_import.append(item) @@ -80,5 +79,5 @@ def patch_modeling_state_dict_code(): globals(), ) exec(load_code, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching _load_state_dict_into_meta_model", main_process_only=True) - transformers.modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821 + LOG.info("patching _load_state_dict_into_meta_model") + modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821