feat: clean up patch application order

NanoCode012
2025-12-23 12:01:32 +07:00
parent 8a3cb223e6
commit 81cc8368a3
3 changed files with 67 additions and 75 deletions

View File

@@ -26,6 +26,26 @@ PLUGIN_MANAGER = PluginManager.get_instance()
 
 class PatchManager:
     """Manages the application of patches during the model loading process."""
 
+    @staticmethod
+    def apply_pre_tokenizer_load_patches(cfg: DictDefault):
+        """
+        Apply patches that must be set up before tokenizer loading.
+
+        This is for patches that intercept remote code loading from HuggingFace,
+        which need to be in place before AutoTokenizer.from_pretrained() is called.
+
+        Args:
+            cfg: Configuration dictionary with model and training settings.
+        """
+        # Kimi-linear tokenizer patches need to be applied before tokenizer
+        # loading because the tokenizer uses remote code.
+        # Note: model_config_type is not set yet, so check the base_model name.
+        if hasattr(cfg, "base_model") and "kimi-linear" in cfg.base_model.lower():
+            from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
+                patch_kimi_tokenizer,
+            )
+
+            patch_kimi_tokenizer()
+
     def __init__(
         self,
         cfg: DictDefault,
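
The guard above keeps the Kimi patch module off the import path unless the config actually names a kimi-linear model, and the local import defers that cost to the one branch that needs it. A minimal standalone sketch of the same dispatch-plus-lazy-import pattern (the function name and the stand-in module are illustrative, not axolotl code):

def apply_pre_load_patches(base_model: str) -> bool:
    """Return True if a model-specific pre-load patch fired."""
    # Cheap string dispatch first: unrelated models pay no import cost
    if "kimi-linear" in base_model.lower():
        # Lazy import, resolved only on this branch; "json" stands in
        # for the real patch module
        import json as patch_module  # noqa: F401

        return True
    return False

print(apply_pre_load_patches("moonshotai/Kimi-Linear-48B-A3B-Instruct"))  # True
print(apply_pre_load_patches("meta-llama/Llama-3.1-8B"))                  # False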

View File

@@ -122,6 +122,11 @@ def modify_tokenizer_files(
 
 def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
     """Load and configure the tokenizer based on the provided config."""
+    # Apply patches that need to be in place before tokenizer loading
+    from axolotl.loaders.patch_manager import PatchManager
+
+    PatchManager.apply_pre_tokenizer_load_patches(cfg)
+
     # if self.cfg.model_config_type == "kimi_linear":
     tokenizer_for_class_loading = AutoTokenizer.from_pretrained(
         cfg.tokenizer_config, trust_remote_code=True
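
The PatchManager call has to precede AutoTokenizer.from_pretrained because the hook works by rebinding a module attribute, and Python resolves that attribute at call time: a hook installed after the call changes nothing. A self-contained sketch of that ordering rule (all names are illustrative):

import types

lib = types.SimpleNamespace()          # stand-in for a library module
lib.load = lambda: "remote code"       # original behavior

def call_site():
    return lib.load()                  # attribute resolved at call time

print(call_site())                     # "remote code": no hook installed yet
lib.load = lambda: "local patch"       # install the hook first ...
print(call_site())                     # "local patch": ... then call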

View File

@@ -1,9 +1,11 @@
 import importlib.resources
 import importlib.util
 import sys
 from contextlib import contextmanager
 from pathlib import Path
 
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
 
 
 def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
     """
@@ -27,15 +29,19 @@ def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
     return None
 
 
-def patch_kimi_model():
+def _patch_get_class_in_module():
     """
-    Apply Kimi model patches by hijacking Transformers' dynamic module loading.
-    This intercepts the remote code loading and replaces it with our local patches.
+    Core patch function that hijacks Transformers' dynamic module loading.
+    This is shared between tokenizer and model patching.
     """
     from transformers.dynamic_module_utils import get_class_in_module
 
     KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
 
+    # Check if already patched to avoid double-patching
+    if hasattr(get_class_in_module, "_axolotl_patched"):
+        return
+
     # Store original function
     original_get_class_in_module = get_class_in_module
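
The _axolotl_patched sentinel added above makes the patch idempotent: without it, a second call would wrap the already-wrapped function and stack indirections. A standalone sketch of the sentinel pattern (names are illustrative):

def greet() -> str:
    return "hello"

def patch_greet() -> None:
    global greet
    if getattr(greet, "_patched", False):  # sentinel: already wrapped, bail out
        return
    original = greet

    def patched() -> str:
        return original() + " (patched)"

    patched._patched = True                # mark so the next call is a no-op
    greet = patched

patch_greet()
patch_greet()                              # safe: the sentinel short-circuits
print(greet())                             # "hello (patched)", wrapped exactly once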
@@ -43,7 +49,6 @@ def patch_kimi_model():
         """Patched version that returns our local modules instead of remote ones."""
         # Check if this is a Kimi model module
         if "modeling_kimi" in module_path:
-            print("Intercepting remote Kimi modeling module, using local patch instead")
             # Load our local modeling_kimi.py instead
             patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "modeling_kimi.py")
             if patch_path and patch_path.exists():
@@ -55,9 +60,6 @@ def patch_kimi_model():
             return getattr(module, class_name)
 
         if "tokenization_kimi" in module_path:
-            print(
-                "Intercepting remote Kimi tokenizer module, using local patch instead"
-            )
             # Load our local tokenization_kimi.py instead
             patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "tokenization_kimi.py")
             if patch_path and patch_path.exists():
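
In both branches above, the patched loader ends up executing a local .py file under the module name Transformers asked for, using the standard importlib machinery. A self-contained sketch of that load-a-file-as-a-module step (the temp file and class name stand in for the real modeling_kimi.py):

import importlib.util
import tempfile
from pathlib import Path

# Write a stand-in "patch file" so the example runs on its own
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("class KimiLinearForCausalLM:\n    source = 'local patch'\n")
    patch_path = Path(f.name)

# Load the file under a chosen module name, as the patched loader does
spec = importlib.util.spec_from_file_location("modeling_kimi", patch_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

print(module.KimiLinearForCausalLM.source)  # "local patch"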
@@ -75,17 +77,27 @@ def patch_kimi_model():
     import transformers.dynamic_module_utils
 
     transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
 
+    # Mark as patched to avoid double-patching
+    patched_get_class_in_module._axolotl_patched = True
 
-    # Also patch the resolve_trust_remote_code to handle auto_map
+
+def _patch_resolve_trust_remote_code():
+    """
+    Patch resolve_trust_remote_code to handle Kimi model auto_map.
+    This helps Transformers find our local modules instead of remote ones.
+    """
     from transformers.dynamic_module_utils import resolve_trust_remote_code
 
+    # Check if already patched to avoid double-patching
+    if hasattr(resolve_trust_remote_code, "_axolotl_patched"):
+        return
+
     original_resolve_trust_remote_code = resolve_trust_remote_code
 
     def patched_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs):
         """Patched version to handle Kimi model auto_map."""
         # Check if this is a Kimi model
         if "kimi" in repo_id.lower() or "kimi" in model_id.lower():
-            print(f"Resolving trust remote code for Kimi model: {repo_id}")
             # Get the original result
             result = original_resolve_trust_remote_code(
                 repo_id, model_id, *args, **kwargs
@@ -107,76 +119,31 @@ def patch_kimi_model():
         return original_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs)
 
     import transformers.dynamic_module_utils
 
     transformers.dynamic_module_utils.resolve_trust_remote_code = (
         patched_resolve_trust_remote_code
     )
 
-    print("Kimi model patches applied successfully!")
+    patched_resolve_trust_remote_code._axolotl_patched = True
 
 
-# The context manager code from before remains the same
-@contextmanager
-def patch_hf_imports():
+def patch_kimi_tokenizer():
     """
-    A context manager to temporarily inject custom modules into sys.modules.
-
-    Args:
-        patch_map (dict): A dictionary mapping the target module name
-            (e.g., "modeling_falcon") to the local path of the
-            custom Python file.
+    Apply Kimi tokenizer patches.
+    This must be called BEFORE tokenizer loading to intercept remote code.
     """
+    _patch_get_class_in_module()
+    _patch_resolve_trust_remote_code()
+    LOG.info("Kimi tokenizer patches applied successfully!")
 
-    KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
-
-    patches_to_apply = {
-        "modeling_kimi": "modeling_kimi.py",
-        "tokenization_kimi": "tokenization_kimi.py",
-    }
-
-    patch_map = {}
-    for target_module, filename in patches_to_apply.items():
-        patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
-        if patch_path and patch_path.exists():
-            print(f"Found patch for '{target_module}' at '{patch_path}'")
-            patch_map[target_module] = patch_path
-        else:
-            raise FileNotFoundError(
-                f"Could not find the patch file '{filename}' "
-                f"in package '{KIMI_PATCH_PACKAGE}'"
-            )
-
-    original_modules = {}
-    injected_modules = []
-    for target_module_name, patch_file_path in patch_map.items():
-        if not Path(patch_file_path).exists():
-            print(f"Warning: Patch file not found at {patch_file_path}. Skipping.")
-            continue
-
-        # If the original module is already loaded, save it for restoration
-        if target_module_name in sys.modules:
-            original_modules[target_module_name] = sys.modules[target_module_name]
-
-        # Use importlib to load our custom file as a module
-        spec = importlib.util.spec_from_file_location(
-            target_module_name, patch_file_path
-        )
-        custom_module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(custom_module)
-
-        # Inject it into sys.modules
-        sys.modules[target_module_name] = custom_module
-        injected_modules.append(target_module_name)
-
-    try:
-        # Yield control back to the 'with' block
-        yield
-    finally:
-        # Cleanup: restore original modules or remove injected ones
-        for module_name in injected_modules:
-            if module_name in original_modules:
-                # Restore the original module if it existed
-                sys.modules[module_name] = original_modules[module_name]
-            else:
-                # Otherwise, just remove our injected module
-                del sys.modules[module_name]
+
+
+def patch_kimi_model():
+    """
+    Apply Kimi model patches.
+    This is called during model loading.
+
+    Note: The core interception is already done by patch_kimi_tokenizer,
+    but we keep this for any model-specific patches that might be needed.
+    """
+    _patch_get_class_in_module()
+    _patch_resolve_trust_remote_code()
+    LOG.info("Kimi model patches applied successfully!")