fix the monkeypatch
This commit is contained in:
@@ -3,12 +3,11 @@ fix for zero3 8-bit lora
|
||||
see https://github.com/huggingface/transformers/pull/32943/files
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
from accelerate.logging import get_logger
|
||||
from transformers import modeling_utils
|
||||
|
||||
LOG = get_logger("axolotl.monkeypatch.modeling_zero3_int8_lora")
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora")
|
||||
|
||||
ORIGINAL_LOAD_CODE = """
|
||||
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||
@@ -38,7 +37,7 @@ PATCHED_LOAD_CODE = """
|
||||
|
||||
def get_modeling_state_dict_code() -> str:
|
||||
load_code = inspect.getsource(
|
||||
transformers.modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
|
||||
modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
|
||||
)
|
||||
return load_code
|
||||
|
||||
@@ -54,7 +53,7 @@ def patch_modeling_state_dict_code():
|
||||
"""
|
||||
|
||||
load_code = get_modeling_state_dict_code()
|
||||
transformers.modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
|
||||
modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
|
||||
load_code
|
||||
)
|
||||
assert (
|
||||
@@ -69,7 +68,7 @@ def patch_modeling_state_dict_code():
|
||||
)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.modeling_utils):
|
||||
for item in dir(modeling_utils):
|
||||
if item in load_code:
|
||||
items_to_import.append(item)
|
||||
|
||||
@@ -80,5 +79,5 @@ def patch_modeling_state_dict_code():
|
||||
globals(),
|
||||
)
|
||||
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _load_state_dict_into_meta_model", main_process_only=True)
|
||||
transformers.modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
|
||||
LOG.info("patching _load_state_dict_into_meta_model")
|
||||
modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
|
||||
|
||||
Reference in New Issue
Block a user