feat: cleaned up patches order
This commit is contained in:
@@ -26,6 +26,26 @@ PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
class PatchManager:
|
||||
"""Manages the application of patches during the model loading process."""
|
||||
|
||||
@staticmethod
|
||||
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
|
||||
"""
|
||||
Apply patches that must be set up before tokenizer loading.
|
||||
This is for patches that intercept remote code loading from HuggingFace,
|
||||
which needs to be in place before AutoTokenizer.from_pretrained() is called.
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary with model and training settings.
|
||||
"""
|
||||
# Kimi-linear tokenizer patches need to be applied before tokenizer loading
|
||||
# because the tokenizer uses remote code.
|
||||
# Note: model_config_type is not set yet, so check base_model name
|
||||
if hasattr(cfg, "base_model") and "kimi-linear" in cfg.base_model.lower():
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_tokenizer,
|
||||
)
|
||||
|
||||
patch_kimi_tokenizer()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -122,6 +122,11 @@ def modify_tokenizer_files(
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
|
||||
# Apply patches that need to be in place before tokenizer loading
|
||||
from axolotl.loaders.patch_manager import PatchManager
|
||||
|
||||
PatchManager.apply_pre_tokenizer_load_patches(cfg)
|
||||
|
||||
# if self.cfg.model_config_type == "kimi_linear":
|
||||
tokenizer_for_class_loading = AutoTokenizer.from_pretrained(
|
||||
cfg.tokenizer_config, trust_remote_code=True
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import importlib.resources
|
||||
import importlib.util
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
|
||||
"""
|
||||
@@ -27,15 +29,19 @@ def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
|
||||
return None
|
||||
|
||||
|
||||
def patch_kimi_model():
|
||||
def _patch_get_class_in_module():
|
||||
"""
|
||||
Apply Kimi model patches by hijacking Transformers' dynamic module loading.
|
||||
This intercepts the remote code loading and replaces it with our local patches.
|
||||
Core patch function that hijacks Transformers' dynamic module loading.
|
||||
This is shared between tokenizer and model patching.
|
||||
"""
|
||||
from transformers.dynamic_module_utils import get_class_in_module
|
||||
|
||||
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
|
||||
|
||||
# Check if already patched to avoid double-patching
|
||||
if hasattr(get_class_in_module, "_axolotl_patched"):
|
||||
return
|
||||
|
||||
# Store original function
|
||||
original_get_class_in_module = get_class_in_module
|
||||
|
||||
@@ -43,7 +49,6 @@ def patch_kimi_model():
|
||||
"""Patched version that returns our local modules instead of remote ones."""
|
||||
# Check if this is a Kimi model module
|
||||
if "modeling_kimi" in module_path:
|
||||
print("Intercepting remote Kimi modeling module, using local patch instead")
|
||||
# Load our local modeling_kimi.py instead
|
||||
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "modeling_kimi.py")
|
||||
if patch_path and patch_path.exists():
|
||||
@@ -55,9 +60,6 @@ def patch_kimi_model():
|
||||
return getattr(module, class_name)
|
||||
|
||||
if "tokenization_kimi" in module_path:
|
||||
print(
|
||||
"Intercepting remote Kimi tokenizer module, using local patch instead"
|
||||
)
|
||||
# Load our local tokenization_kimi.py instead
|
||||
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "tokenization_kimi.py")
|
||||
if patch_path and patch_path.exists():
|
||||
@@ -75,17 +77,27 @@ def patch_kimi_model():
|
||||
import transformers.dynamic_module_utils
|
||||
|
||||
transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
|
||||
# Mark as patched to avoid double-patching
|
||||
patched_get_class_in_module._axolotl_patched = True
|
||||
|
||||
# Also patch the resolve_trust_remote_code to handle auto_map
|
||||
|
||||
def _patch_resolve_trust_remote_code():
|
||||
"""
|
||||
Patch resolve_trust_remote_code to handle Kimi model auto_map.
|
||||
This helps Transformers find our local modules instead of remote ones.
|
||||
"""
|
||||
from transformers.dynamic_module_utils import resolve_trust_remote_code
|
||||
|
||||
# Check if already patched to avoid double-patching
|
||||
if hasattr(resolve_trust_remote_code, "_axolotl_patched"):
|
||||
return
|
||||
|
||||
original_resolve_trust_remote_code = resolve_trust_remote_code
|
||||
|
||||
def patched_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs):
|
||||
"""Patched version to handle Kimi model auto_map."""
|
||||
# Check if this is a Kimi model
|
||||
if "kimi" in repo_id.lower() or "kimi" in model_id.lower():
|
||||
print(f"Resolving trust remote code for Kimi model: {repo_id}")
|
||||
# Get the original result
|
||||
result = original_resolve_trust_remote_code(
|
||||
repo_id, model_id, *args, **kwargs
|
||||
@@ -107,76 +119,31 @@ def patch_kimi_model():
|
||||
|
||||
return original_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs)
|
||||
|
||||
import transformers.dynamic_module_utils
|
||||
|
||||
transformers.dynamic_module_utils.resolve_trust_remote_code = (
|
||||
patched_resolve_trust_remote_code
|
||||
)
|
||||
|
||||
print("Kimi model patches applied successfully!")
|
||||
patched_resolve_trust_remote_code._axolotl_patched = True
|
||||
|
||||
|
||||
# The context manager code from before remains the same
|
||||
@contextmanager
|
||||
def patch_hf_imports():
|
||||
def patch_kimi_tokenizer():
|
||||
"""
|
||||
A context manager to temporarily inject custom modules into sys.modules.
|
||||
|
||||
Args:
|
||||
patch_map (dict): A dictionary mapping the target module name
|
||||
(e.g., "modeling_falcon") to the local path of the
|
||||
custom Python file.
|
||||
Apply Kimi tokenizer patches.
|
||||
This must be called BEFORE tokenizer loading to intercept remote code.
|
||||
"""
|
||||
_patch_get_class_in_module()
|
||||
_patch_resolve_trust_remote_code()
|
||||
LOG.info("Kimi tokenizer patches applied successfully!")
|
||||
|
||||
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
|
||||
|
||||
patches_to_apply = {
|
||||
"modeling_kimi": "modeling_kimi.py",
|
||||
"tokenization_kimi": "tokenization_kimi.py",
|
||||
}
|
||||
|
||||
patch_map = {}
|
||||
for target_module, filename in patches_to_apply.items():
|
||||
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
|
||||
if patch_path and patch_path.exists():
|
||||
print(f"Found patch for '{target_module}' at '{patch_path}'")
|
||||
patch_map[target_module] = patch_path
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Could not find the patch file '{filename}' "
|
||||
f"in package '{KIMI_PATCH_PACKAGE}'"
|
||||
)
|
||||
|
||||
original_modules = {}
|
||||
injected_modules = []
|
||||
|
||||
for target_module_name, patch_file_path in patch_map.items():
|
||||
if not Path(patch_file_path).exists():
|
||||
print(f"Warning: Patch file not found at {patch_file_path}. Skipping.")
|
||||
continue
|
||||
|
||||
# If the original module is already loaded, save it for restoration
|
||||
if target_module_name in sys.modules:
|
||||
original_modules[target_module_name] = sys.modules[target_module_name]
|
||||
|
||||
# Use importlib to load our custom file as a module
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
target_module_name, patch_file_path
|
||||
)
|
||||
custom_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(custom_module)
|
||||
|
||||
# Inject it into sys.modules
|
||||
sys.modules[target_module_name] = custom_module
|
||||
injected_modules.append(target_module_name)
|
||||
|
||||
try:
|
||||
# Yield control back to the 'with' block
|
||||
yield
|
||||
finally:
|
||||
# Cleanup: restore original modules or remove injected ones
|
||||
for module_name in injected_modules:
|
||||
if module_name in original_modules:
|
||||
# Restore the original module if it existed
|
||||
sys.modules[module_name] = original_modules[module_name]
|
||||
else:
|
||||
# Otherwise, just remove our injected module
|
||||
del sys.modules[module_name]
|
||||
def patch_kimi_model():
|
||||
"""
|
||||
Apply Kimi model patches.
|
||||
This is called during model loading.
|
||||
Note: The core interception is already done by patch_kimi_tokenizer,
|
||||
but we keep this for any model-specific patches that might be needed.
|
||||
"""
|
||||
_patch_get_class_in_module()
|
||||
_patch_resolve_trust_remote_code()
|
||||
LOG.info("Kimi model patches applied successfully!")
|
||||
|
||||
Reference in New Issue
Block a user