fix the monkeypatch

This commit is contained in:
Wing Lian
2024-11-19 02:12:33 -05:00
parent 1ff78d6347
commit afb8218c67

View File

@@ -3,12 +3,11 @@ fix for zero3 8-bit lora
see https://github.com/huggingface/transformers/pull/32943/files
"""
import inspect
import logging
import transformers
import transformers.modeling_utils
from accelerate.logging import get_logger
from transformers import modeling_utils
LOG = get_logger("axolotl.monkeypatch.modeling_zero3_int8_lora")
LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora")
ORIGINAL_LOAD_CODE = """
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
@@ -38,7 +37,7 @@ PATCHED_LOAD_CODE = """
def get_modeling_state_dict_code() -> str:
load_code = inspect.getsource(
transformers.modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
)
return load_code
@@ -54,7 +53,7 @@ def patch_modeling_state_dict_code():
"""
load_code = get_modeling_state_dict_code()
transformers.modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
load_code
)
assert (
@@ -69,7 +68,7 @@ def patch_modeling_state_dict_code():
)
items_to_import = []
for item in dir(transformers.modeling_utils):
for item in dir(modeling_utils):
if item in load_code:
items_to_import.append(item)
@@ -80,5 +79,5 @@ def patch_modeling_state_dict_code():
globals(),
)
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _load_state_dict_into_meta_model", main_process_only=True)
transformers.modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
LOG.info("patching _load_state_dict_into_meta_model")
modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821