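Fix the zero3 int8 LoRA monkeypatch: import modeling_utils directly from transformers and use the standard logging module instead of accelerate's get_logger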

This commit is contained in:
Wing Lian
2024-11-19 02:12:33 -05:00
parent 1ff78d6347
commit afb8218c67

View File

@@ -3,12 +3,11 @@ fix for zero3 8-bit lora
see https://github.com/huggingface/transformers/pull/32943/files see https://github.com/huggingface/transformers/pull/32943/files
""" """
import inspect import inspect
import logging
import transformers from transformers import modeling_utils
import transformers.modeling_utils
from accelerate.logging import get_logger
LOG = get_logger("axolotl.monkeypatch.modeling_zero3_int8_lora") LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora")
ORIGINAL_LOAD_CODE = """ ORIGINAL_LOAD_CODE = """
if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
@@ -38,7 +37,7 @@ PATCHED_LOAD_CODE = """
def get_modeling_state_dict_code() -> str: def get_modeling_state_dict_code() -> str:
load_code = inspect.getsource( load_code = inspect.getsource(
transformers.modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
) )
return load_code return load_code
@@ -54,7 +53,7 @@ def patch_modeling_state_dict_code():
""" """
load_code = get_modeling_state_dict_code() load_code = get_modeling_state_dict_code()
transformers.modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
load_code load_code
) )
assert ( assert (
@@ -69,7 +68,7 @@ def patch_modeling_state_dict_code():
) )
items_to_import = [] items_to_import = []
for item in dir(transformers.modeling_utils): for item in dir(modeling_utils):
if item in load_code: if item in load_code:
items_to_import.append(item) items_to_import.append(item)
@@ -80,5 +79,5 @@ def patch_modeling_state_dict_code():
globals(), globals(),
) )
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102 exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _load_state_dict_into_meta_model", main_process_only=True) LOG.info("patching _load_state_dict_into_meta_model")
transformers.modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821 modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821