From 81cc8368a3e42735e812e25bdf4dad0cff19ec01 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 23 Dec 2025 12:01:32 +0700
Subject: [PATCH] feat: clean up patch application order

---
 src/axolotl/loaders/patch_manager.py          |  20 +++
 src/axolotl/loaders/tokenizer.py              |   5 +
 .../models/kimi_linear/patch_kimi_linear.py   | 117 +++++++-----------
 3 files changed, 67 insertions(+), 75 deletions(-)

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 569a071a4..cabfa0758 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -26,6 +26,26 @@ PLUGIN_MANAGER = PluginManager.get_instance()
 class PatchManager:
     """Manages the application of patches during the model loading process."""
 
+    @staticmethod
+    def apply_pre_tokenizer_load_patches(cfg: DictDefault):
+        """
+        Apply patches that must be set up before tokenizer loading.
+        This is for patches that intercept remote code loading from Hugging Face,
+        which need to be in place before AutoTokenizer.from_pretrained() is called.
+
+        Args:
+            cfg: Configuration dictionary with model and training settings.
+        """
+        # Kimi-linear tokenizer patches need to be applied before tokenizer loading
+        # because the tokenizer uses remote code.
+        # Note: model_config_type is not set yet, so check the base_model name.
+        # DictDefault returns None for missing keys, so a truthiness check is
+        # needed here (hasattr() would pass and None.lower() would raise).
+        if cfg.base_model and "kimi-linear" in cfg.base_model.lower():
+            from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
+                patch_kimi_tokenizer,
+            )
+
+            patch_kimi_tokenizer()
+
     def __init__(
         self,
         cfg: DictDefault,
diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index cf76577d3..c30ccdfaf 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -122,6 +122,11 @@ def modify_tokenizer_files(
 
 def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
     """Load and configure the tokenizer based on the provided config."""
+    # Apply patches that need to be in place before tokenizer loading.
+    from axolotl.loaders.patch_manager import PatchManager
+
+    PatchManager.apply_pre_tokenizer_load_patches(cfg)
+
     # if self.cfg.model_config_type == "kimi_linear":
     tokenizer_for_class_loading = AutoTokenizer.from_pretrained(
         cfg.tokenizer_config, trust_remote_code=True
diff --git a/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py b/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py
index b985b4853..f2f2bff4e 100644
--- a/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py
+++ b/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py
@@ -1,9 +1,11 @@
 import importlib.resources
 import importlib.util
-import sys
-from contextlib import contextmanager
 from pathlib import Path
 
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
 
 def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
     """
@@ -27,15 +29,19 @@ def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
     return None
 
 
-def patch_kimi_model():
+def _patch_get_class_in_module():
     """
-    Apply Kimi model patches by hijacking Transformers' dynamic module loading.
-    This intercepts the remote code loading and replaces it with our local patches.
+    Core patch function that hijacks Transformers' dynamic module loading.
+    This is shared between tokenizer and model patching.
""" from transformers.dynamic_module_utils import get_class_in_module KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear" + # Check if already patched to avoid double-patching + if hasattr(get_class_in_module, "_axolotl_patched"): + return + # Store original function original_get_class_in_module = get_class_in_module @@ -43,7 +49,6 @@ def patch_kimi_model(): """Patched version that returns our local modules instead of remote ones.""" # Check if this is a Kimi model module if "modeling_kimi" in module_path: - print("Intercepting remote Kimi modeling module, using local patch instead") # Load our local modeling_kimi.py instead patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "modeling_kimi.py") if patch_path and patch_path.exists(): @@ -55,9 +60,6 @@ def patch_kimi_model(): return getattr(module, class_name) if "tokenization_kimi" in module_path: - print( - "Intercepting remote Kimi tokenizer module, using local patch instead" - ) # Load our local tokenization_kimi.py instead patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "tokenization_kimi.py") if patch_path and patch_path.exists(): @@ -75,17 +77,27 @@ def patch_kimi_model(): import transformers.dynamic_module_utils transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module + # Mark as patched to avoid double-patching + patched_get_class_in_module._axolotl_patched = True - # Also patch the resolve_trust_remote_code to handle auto_map + +def _patch_resolve_trust_remote_code(): + """ + Patch resolve_trust_remote_code to handle Kimi model auto_map. + This helps Transformers find our local modules instead of remote ones. + """ from transformers.dynamic_module_utils import resolve_trust_remote_code + # Check if already patched to avoid double-patching + if hasattr(resolve_trust_remote_code, "_axolotl_patched"): + return + original_resolve_trust_remote_code = resolve_trust_remote_code def patched_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs): """Patched version to handle Kimi model auto_map.""" # Check if this is a Kimi model if "kimi" in repo_id.lower() or "kimi" in model_id.lower(): - print(f"Resolving trust remote code for Kimi model: {repo_id}") # Get the original result result = original_resolve_trust_remote_code( repo_id, model_id, *args, **kwargs @@ -107,76 +119,31 @@ def patch_kimi_model(): return original_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs) + import transformers.dynamic_module_utils + transformers.dynamic_module_utils.resolve_trust_remote_code = ( patched_resolve_trust_remote_code ) - - print("Kimi model patches applied successfully!") + patched_resolve_trust_remote_code._axolotl_patched = True -# The context manager code from before remains the same -@contextmanager -def patch_hf_imports(): +def patch_kimi_tokenizer(): """ - A context manager to temporarily inject custom modules into sys.modules. - - Args: - patch_map (dict): A dictionary mapping the target module name - (e.g., "modeling_falcon") to the local path of the - custom Python file. + Apply Kimi tokenizer patches. + This must be called BEFORE tokenizer loading to intercept remote code. 
""" + _patch_get_class_in_module() + _patch_resolve_trust_remote_code() + LOG.info("Kimi tokenizer patches applied successfully!") - KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear" - patches_to_apply = { - "modeling_kimi": "modeling_kimi.py", - "tokenization_kimi": "tokenization_kimi.py", - } - - patch_map = {} - for target_module, filename in patches_to_apply.items(): - patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename) - if patch_path and patch_path.exists(): - print(f"Found patch for '{target_module}' at '{patch_path}'") - patch_map[target_module] = patch_path - else: - raise FileNotFoundError( - f"Could not find the patch file '{filename}' " - f"in package '{KIMI_PATCH_PACKAGE}'" - ) - - original_modules = {} - injected_modules = [] - - for target_module_name, patch_file_path in patch_map.items(): - if not Path(patch_file_path).exists(): - print(f"Warning: Patch file not found at {patch_file_path}. Skipping.") - continue - - # If the original module is already loaded, save it for restoration - if target_module_name in sys.modules: - original_modules[target_module_name] = sys.modules[target_module_name] - - # Use importlib to load our custom file as a module - spec = importlib.util.spec_from_file_location( - target_module_name, patch_file_path - ) - custom_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(custom_module) - - # Inject it into sys.modules - sys.modules[target_module_name] = custom_module - injected_modules.append(target_module_name) - - try: - # Yield control back to the 'with' block - yield - finally: - # Cleanup: restore original modules or remove injected ones - for module_name in injected_modules: - if module_name in original_modules: - # Restore the original module if it existed - sys.modules[module_name] = original_modules[module_name] - else: - # Otherwise, just remove our injected module - del sys.modules[module_name] +def patch_kimi_model(): + """ + Apply Kimi model patches. + This is called during model loading. + Note: The core interception is already done by patch_kimi_tokenizer, + but we keep this for any model-specific patches that might be needed. + """ + _patch_get_class_in_module() + _patch_resolve_trust_remote_code() + LOG.info("Kimi model patches applied successfully!")