monkey-patch transformers to simplify monkey-patching modeling code (#1877)

* monkey-patch transformers so that monkey-patched modeling code doesn't get overwritten

* unnecessary now

* add comment
This commit is contained in:
Aman Gupta Karmani
2024-08-27 17:22:26 -07:00
committed by GitHub
parent 1e43660701
commit 159b8b9a74
3 changed files with 56 additions and 2 deletions

View File

@@ -94,5 +94,3 @@ def patch_remote(model_name, config_name, modeling_name):
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
modeling_arch = importlib.import_module(module_name)
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
# workaround to make the patch stick
modeling_arch._axolotl_multipack_patch = True # pylint: disable=protected-access

View File

@@ -0,0 +1,51 @@
"""Patch transformers.dynamic_module_utils.get_class_in_module to avoid reloading models from disk"""
import importlib
import os
import sys
import typing
from pathlib import Path
from transformers.file_utils import HF_MODULES_CACHE
def _patched_get_class_in_module(
    class_name: str, module_path: typing.Union[str, os.PathLike]
) -> typing.Type:
    """
    Import a module from the HF modules cache directory and extract a class
    from it, reusing an already-imported module instead of reloading it from
    disk (which would discard any monkey-patches applied to it).

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to
            import, relative to `HF_MODULES_CACHE`.

    Returns:
        `typing.Type`: The class looked for.

    Raises:
        ImportError: If the module file cannot be located or has no loader.
        AttributeError: If the loaded module does not define `class_name`.
    """
    # Derive a dotted module name from the relative file path.
    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    module = sys.modules.get(name)
    if module is None:
        # Only build the spec when the module actually needs loading: cached
        # modules skip all disk access, which is the whole point of the patch.
        module_spec = importlib.util.spec_from_file_location(
            name, location=Path(HF_MODULES_CACHE) / module_path
        )
        if module_spec is None or module_spec.loader is None:
            raise ImportError(f"Cannot load module {name} from {module_path}")
        module = importlib.util.module_from_spec(module_spec)
        # Insert into sys.modules before executing so circular imports inside
        # the remote-code module can resolve the partially-initialized module.
        sys.modules[name] = module
        module_spec.loader.exec_module(module)
    return getattr(module, class_name)
def patch_transformers_dynamic_module_utils():
    """
    Install ``_patched_get_class_in_module`` in place of
    ``transformers.dynamic_module_utils.get_class_in_module``.

    Recent transformers versions reload remote modeling code from disk for
    models marked ``trust_remote_code=True``, which wipes out the multipack
    and liger monkey-patches. The replacement reuses the module already
    present in ``sys.modules`` instead of re-executing it from disk.

    See https://github.com/huggingface/transformers/pull/30370#pullrequestreview-2264361581
    """
    from transformers import dynamic_module_utils

    dynamic_module_utils.get_class_in_module = _patched_get_class_in_module

View File

@@ -43,6 +43,9 @@ from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.monkeypatch.transformers_dynamic_module_utils import (
patch_transformers_dynamic_module_utils,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
@@ -54,6 +57,8 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod
LOG = logging.getLogger("axolotl")
# Apply the dynamic-module patch at import time, before any model is loaded,
# so trust_remote_code modeling code is not reloaded from disk (which would
# discard the multipack/liger monkey-patches).
patch_transformers_dynamic_module_utils()
# copied from accelerator.FullyShardedDataParallelPlugin
def get_module_class_from_name(module, name):