fix the monkeypatch
This commit is contained in:
@@ -3,12 +3,11 @@ fix for zero3 8-bit lora
|
|||||||
see https://github.com/huggingface/transformers/pull/32943/files
|
see https://github.com/huggingface/transformers/pull/32943/files
|
||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
import transformers
|
from transformers import modeling_utils
|
||||||
import transformers.modeling_utils
|
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.modeling_zero3_int8_lora")
|
LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora")
|
||||||
|
|
||||||
ORIGINAL_LOAD_CODE = """
|
ORIGINAL_LOAD_CODE = """
|
||||||
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||||
@@ -38,7 +37,7 @@ PATCHED_LOAD_CODE = """
|
|||||||
|
|
||||||
def get_modeling_state_dict_code() -> str:
|
def get_modeling_state_dict_code() -> str:
|
||||||
load_code = inspect.getsource(
|
load_code = inspect.getsource(
|
||||||
transformers.modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
|
modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
return load_code
|
return load_code
|
||||||
|
|
||||||
@@ -54,7 +53,7 @@ def patch_modeling_state_dict_code():
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
load_code = get_modeling_state_dict_code()
|
load_code = get_modeling_state_dict_code()
|
||||||
transformers.modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
|
modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
|
||||||
load_code
|
load_code
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
@@ -69,7 +68,7 @@ def patch_modeling_state_dict_code():
|
|||||||
)
|
)
|
||||||
|
|
||||||
items_to_import = []
|
items_to_import = []
|
||||||
for item in dir(transformers.modeling_utils):
|
for item in dir(modeling_utils):
|
||||||
if item in load_code:
|
if item in load_code:
|
||||||
items_to_import.append(item)
|
items_to_import.append(item)
|
||||||
|
|
||||||
@@ -80,5 +79,5 @@ def patch_modeling_state_dict_code():
|
|||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
|
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
LOG.info("patching _load_state_dict_into_meta_model", main_process_only=True)
|
LOG.info("patching _load_state_dict_into_meta_model")
|
||||||
transformers.modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
|
modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
|
||||||
|
|||||||
Reference in New Issue
Block a user