feat: cleaned up patches order

This commit is contained in:
NanoCode012
2025-12-23 12:01:32 +07:00
parent 8a3cb223e6
commit 81cc8368a3
3 changed files with 67 additions and 75 deletions

View File

@@ -26,6 +26,26 @@ PLUGIN_MANAGER = PluginManager.get_instance()
class PatchManager: class PatchManager:
"""Manages the application of patches during the model loading process.""" """Manages the application of patches during the model loading process."""
@staticmethod
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
"""
Apply patches that must be set up before tokenizer loading.
This is for patches that intercept remote code loading from HuggingFace,
which needs to be in place before AutoTokenizer.from_pretrained() is called.
Args:
cfg: Configuration dictionary with model and training settings.
"""
# Kimi-linear tokenizer patches need to be applied before tokenizer loading
# because the tokenizer uses remote code.
# Note: model_config_type is not set yet, so check base_model name
if hasattr(cfg, "base_model") and "kimi-linear" in cfg.base_model.lower():
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_tokenizer,
)
patch_kimi_tokenizer()
def __init__( def __init__(
self, self,
cfg: DictDefault, cfg: DictDefault,

View File

@@ -122,6 +122,11 @@ def modify_tokenizer_files(
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config.""" """Load and configure the tokenizer based on the provided config."""
# Apply patches that need to be in place before tokenizer loading
from axolotl.loaders.patch_manager import PatchManager
PatchManager.apply_pre_tokenizer_load_patches(cfg)
# if self.cfg.model_config_type == "kimi_linear": # if self.cfg.model_config_type == "kimi_linear":
tokenizer_for_class_loading = AutoTokenizer.from_pretrained( tokenizer_for_class_loading = AutoTokenizer.from_pretrained(
cfg.tokenizer_config, trust_remote_code=True cfg.tokenizer_config, trust_remote_code=True

View File

@@ -1,9 +1,11 @@
import importlib.resources import importlib.resources
import importlib.util import importlib.util
import sys
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_patch_file_path(package_dot_path: str, filename: str) -> Path: def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
""" """
@@ -27,15 +29,19 @@ def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
return None return None
def patch_kimi_model(): def _patch_get_class_in_module():
""" """
Apply Kimi model patches by hijacking Transformers' dynamic module loading. Core patch function that hijacks Transformers' dynamic module loading.
This intercepts the remote code loading and replaces it with our local patches. This is shared between tokenizer and model patching.
""" """
from transformers.dynamic_module_utils import get_class_in_module from transformers.dynamic_module_utils import get_class_in_module
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear" KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
# Check if already patched to avoid double-patching
if hasattr(get_class_in_module, "_axolotl_patched"):
return
# Store original function # Store original function
original_get_class_in_module = get_class_in_module original_get_class_in_module = get_class_in_module
@@ -43,7 +49,6 @@ def patch_kimi_model():
"""Patched version that returns our local modules instead of remote ones.""" """Patched version that returns our local modules instead of remote ones."""
# Check if this is a Kimi model module # Check if this is a Kimi model module
if "modeling_kimi" in module_path: if "modeling_kimi" in module_path:
print("Intercepting remote Kimi modeling module, using local patch instead")
# Load our local modeling_kimi.py instead # Load our local modeling_kimi.py instead
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "modeling_kimi.py") patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "modeling_kimi.py")
if patch_path and patch_path.exists(): if patch_path and patch_path.exists():
@@ -55,9 +60,6 @@ def patch_kimi_model():
return getattr(module, class_name) return getattr(module, class_name)
if "tokenization_kimi" in module_path: if "tokenization_kimi" in module_path:
print(
"Intercepting remote Kimi tokenizer module, using local patch instead"
)
# Load our local tokenization_kimi.py instead # Load our local tokenization_kimi.py instead
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "tokenization_kimi.py") patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "tokenization_kimi.py")
if patch_path and patch_path.exists(): if patch_path and patch_path.exists():
@@ -75,17 +77,27 @@ def patch_kimi_model():
import transformers.dynamic_module_utils import transformers.dynamic_module_utils
transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
# Mark as patched to avoid double-patching
patched_get_class_in_module._axolotl_patched = True
# Also patch the resolve_trust_remote_code to handle auto_map
def _patch_resolve_trust_remote_code():
"""
Patch resolve_trust_remote_code to handle Kimi model auto_map.
This helps Transformers find our local modules instead of remote ones.
"""
from transformers.dynamic_module_utils import resolve_trust_remote_code from transformers.dynamic_module_utils import resolve_trust_remote_code
# Check if already patched to avoid double-patching
if hasattr(resolve_trust_remote_code, "_axolotl_patched"):
return
original_resolve_trust_remote_code = resolve_trust_remote_code original_resolve_trust_remote_code = resolve_trust_remote_code
def patched_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs): def patched_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs):
"""Patched version to handle Kimi model auto_map.""" """Patched version to handle Kimi model auto_map."""
# Check if this is a Kimi model # Check if this is a Kimi model
if "kimi" in repo_id.lower() or "kimi" in model_id.lower(): if "kimi" in repo_id.lower() or "kimi" in model_id.lower():
print(f"Resolving trust remote code for Kimi model: {repo_id}")
# Get the original result # Get the original result
result = original_resolve_trust_remote_code( result = original_resolve_trust_remote_code(
repo_id, model_id, *args, **kwargs repo_id, model_id, *args, **kwargs
@@ -107,76 +119,31 @@ def patch_kimi_model():
return original_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs) return original_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs)
import transformers.dynamic_module_utils
transformers.dynamic_module_utils.resolve_trust_remote_code = ( transformers.dynamic_module_utils.resolve_trust_remote_code = (
patched_resolve_trust_remote_code patched_resolve_trust_remote_code
) )
patched_resolve_trust_remote_code._axolotl_patched = True
print("Kimi model patches applied successfully!")
# The context manager code from before remains the same def patch_kimi_tokenizer():
@contextmanager
def patch_hf_imports():
""" """
A context manager to temporarily inject custom modules into sys.modules. Apply Kimi tokenizer patches.
This must be called BEFORE tokenizer loading to intercept remote code.
Args:
patch_map (dict): A dictionary mapping the target module name
(e.g., "modeling_falcon") to the local path of the
custom Python file.
""" """
_patch_get_class_in_module()
_patch_resolve_trust_remote_code()
LOG.info("Kimi tokenizer patches applied successfully!")
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
patches_to_apply = { def patch_kimi_model():
"modeling_kimi": "modeling_kimi.py", """
"tokenization_kimi": "tokenization_kimi.py", Apply Kimi model patches.
} This is called during model loading.
Note: The core interception is already done by patch_kimi_tokenizer,
patch_map = {} but we keep this for any model-specific patches that might be needed.
for target_module, filename in patches_to_apply.items(): """
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename) _patch_get_class_in_module()
if patch_path and patch_path.exists(): _patch_resolve_trust_remote_code()
print(f"Found patch for '{target_module}' at '{patch_path}'") LOG.info("Kimi model patches applied successfully!")
patch_map[target_module] = patch_path
else:
raise FileNotFoundError(
f"Could not find the patch file '{filename}' "
f"in package '{KIMI_PATCH_PACKAGE}'"
)
original_modules = {}
injected_modules = []
for target_module_name, patch_file_path in patch_map.items():
if not Path(patch_file_path).exists():
print(f"Warning: Patch file not found at {patch_file_path}. Skipping.")
continue
# If the original module is already loaded, save it for restoration
if target_module_name in sys.modules:
original_modules[target_module_name] = sys.modules[target_module_name]
# Use importlib to load our custom file as a module
spec = importlib.util.spec_from_file_location(
target_module_name, patch_file_path
)
custom_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(custom_module)
# Inject it into sys.modules
sys.modules[target_module_name] = custom_module
injected_modules.append(target_module_name)
try:
# Yield control back to the 'with' block
yield
finally:
# Cleanup: restore original modules or remove injected ones
for module_name in injected_modules:
if module_name in original_modules:
# Restore the original module if it existed
sys.modules[module_name] = original_modules[module_name]
else:
# Otherwise, just remove our injected module
del sys.modules[module_name]