diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 529c42a8f..44fc4cb47 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -94,5 +94,3 @@ def patch_remote(model_name, config_name, modeling_name):
     module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
     modeling_arch = importlib.import_module(module_name)
     modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
-    # workaround to make the patch stick
-    modeling_arch._axolotl_multipack_patch = True  # pylint: disable=protected-access
diff --git a/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py
new file mode 100644
index 000000000..dfc3e29c5
--- /dev/null
+++ b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py
@@ -0,0 +1,67 @@
+"""Patch transformers.dynamic_module_utils.get_class_in_module to avoid reloading models from disk"""
+
+import importlib.util
+import os
+import sys
+import typing
+from pathlib import Path
+
+from transformers.file_utils import HF_MODULES_CACHE
+
+
+def _patched_get_class_in_module(
+    class_name: str, module_path: typing.Union[str, os.PathLike]
+) -> typing.Type:
+    """
+    Import a module on the cache directory for modules and extract a class from it.
+
+    A module already present in `sys.modules` is reused as-is rather than being
+    re-executed from disk, so monkey-patches applied to it are preserved.
+
+    Args:
+        class_name (`str`): The name of the class to import.
+        module_path (`str` or `os.PathLike`): The path to the module to import.
+
+    Returns:
+        `typing.Type`: The class looked for.
+
+    Raises:
+        ImportError: If no import spec (with a loader) can be built for `module_path`.
+    """
+    name = os.path.normpath(module_path)
+    if name.endswith(".py"):
+        name = name[:-3]
+    name = name.replace(os.path.sep, ".")
+    module = sys.modules.get(name)
+    if module is None:
+        # only build the spec when the module actually has to be loaded
+        module_spec = importlib.util.spec_from_file_location(
+            name, location=Path(HF_MODULES_CACHE) / module_path
+        )
+        if module_spec is None or module_spec.loader is None:
+            raise ImportError(f"Cannot build an import spec for {module_path}")
+        module = importlib.util.module_from_spec(module_spec)
+        # insert it into sys.modules before any loading begins
+        sys.modules[name] = module
+        try:
+            # load in initial case only
+            module_spec.loader.exec_module(module)
+        except BaseException:
+            # drop the half-initialised module so a later call retries the load
+            # instead of reusing a broken cached entry
+            sys.modules.pop(name, None)
+            raise
+    return getattr(module, class_name)
+
+
+def patch_transformers_dynamic_module_utils():
+    """
+    Recently, transformers started reloading modeling code from disk for models marked trust_remote_code=True.
+    This causes monkey-patches for multipack and liger to be removed.
+    We replace the original function with a version that does not reload the module from disk.
+    See https://github.com/huggingface/transformers/pull/30370#pullrequestreview-2264361581
+    """
+    # import the submodule explicitly so the attribute is guaranteed to be loaded
+    import transformers.dynamic_module_utils
+
+    transformers.dynamic_module_utils.get_class_in_module = _patched_get_class_in_module
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e18330199..e0526fb04 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -43,6 +43,9 @@ from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
     patch_for_multipack,
 )
+from axolotl.monkeypatch.transformers_dynamic_module_utils import (
+    patch_transformers_dynamic_module_utils,
+)
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
@@ -54,6 +57,8 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod
 
 LOG = logging.getLogger("axolotl")
 
+patch_transformers_dynamic_module_utils()
+
 
 # copied from accelerator.FullyShardedDataParallelPlugin
 def get_module_class_from_name(module, name):