fix: attempt patch kimi remote

This commit is contained in:
NanoCode012
2025-12-22 18:11:00 +07:00
parent 2c0272fd55
commit 001f4205f3
2 changed files with 94 additions and 0 deletions

View File

@@ -190,6 +190,13 @@ class PatchManager:
apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
)
patch_kimi_model()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:

View File

@@ -27,6 +27,93 @@ def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
return None
def patch_kimi_model():
    """
    Apply Kimi model patches by hijacking Transformers' dynamic module loading.

    Two functions on ``transformers.dynamic_module_utils`` are replaced with
    wrappers:

    * ``get_class_in_module``: when the requested module path mentions
      ``modeling_kimi`` or ``tokenization_kimi`` and a local patch file of the
      same name is found via ``get_patch_file_path``, the class is taken from
      that local file instead of the remote code. Every other request (and any
      request whose local file is missing) is forwarded to the original.
    * ``resolve_trust_remote_code``: for calls whose model identifier
      mentions "kimi", a mapping-style result gets its ``auto_map`` entries
      pointed at the local module names and ``trust_remote_code`` set to True.
      Any other result is returned untouched.

    The wrappers are installed at most once; later calls return immediately.

    NOTE(review): only the attributes on ``transformers.dynamic_module_utils``
    are replaced. Modules that bound these functions with ``from ... import``
    before this ran keep the originals -- confirm that is acceptable.
    """
    import transformers.dynamic_module_utils as dynamic_module_utils

    KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
    # (module name to look for / load locally, label used in the log message)
    KIMI_LOCAL_MODULES = (
        ("modeling_kimi", "modeling"),
        ("tokenization_kimi", "tokenizer"),
    )
    PATCH_MARKER = "_axolotl_kimi_patch"

    # Guard against double patching: wrapping a wrapper would chain the
    # interception logic and repeat every log line.
    if getattr(dynamic_module_utils.get_class_in_module, PATCH_MARKER, False):
        return

    # Store original functions
    original_get_class_in_module = dynamic_module_utils.get_class_in_module
    original_resolve_trust_remote_code = (
        dynamic_module_utils.resolve_trust_remote_code
    )

    # Local patch modules that were already executed, keyed by module name, so
    # repeated lookups hand out classes from one module object instead of
    # re-executing the file and minting new class objects each time.
    local_module_cache = {}

    def load_local_module(module_name, force_reload=False):
        """Return the local patch module ``module_name``, or None if unavailable."""
        if not force_reload and module_name in local_module_cache:
            return local_module_cache[module_name]

        patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, f"{module_name}.py")
        if not patch_path or not patch_path.exists():
            return None

        spec = importlib.util.spec_from_file_location(module_name, patch_path)
        if spec is None or spec.loader is None:
            return None

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        local_module_cache[module_name] = module
        return module

    def patched_get_class_in_module(class_name, module_path, *args, **kwargs):
        """Patched version that returns our local modules instead of remote ones."""
        # module_path may be a str or an os.PathLike; substring checks need str.
        module_path_str = str(module_path)
        for module_name, label in KIMI_LOCAL_MODULES:
            if module_name not in module_path_str:
                continue
            print(
                f"Intercepting remote Kimi {label} module, using local patch instead"
            )
            module = load_local_module(
                module_name, force_reload=bool(kwargs.get("force_reload"))
            )
            if module is not None:
                return getattr(module, class_name)

        # For all other modules (or a missing local patch file), use the
        # original function, forwarding extras such as ``force_reload``.
        return original_get_class_in_module(
            class_name, module_path, *args, **kwargs
        )

    def patched_resolve_trust_remote_code(*args, **kwargs):
        """Patched version to handle Kimi model auto_map."""
        # Only string arguments can identify a model. Per the Transformers
        # reference the first positional argument is the trust flag
        # (bool/None), so calling ``.lower()`` on it unconditionally raises.
        candidates = (*args[:2], kwargs.get("model_name"))
        kimi_names = [
            value
            for value in candidates
            if isinstance(value, str) and "kimi" in value.lower()
        ]
        if not kimi_names:
            return original_resolve_trust_remote_code(*args, **kwargs)

        print(f"Resolving trust remote code for Kimi model: {kimi_names[0]}")
        result = original_resolve_trust_remote_code(*args, **kwargs)

        # If it contains auto_map for Kimi, replace with our local files.
        # NOTE(review): upstream documents a bool return value, in which case
        # this branch never runs -- confirm whether it is still needed.
        if hasattr(result, "get") and result.get("auto_map"):
            auto_map = result["auto_map"].copy()
            for key in auto_map:
                for module_name, _label in KIMI_LOCAL_MODULES:
                    if module_name in auto_map[key]:
                        auto_map[key] = module_name
            result["auto_map"] = auto_map
            result["trust_remote_code"] = True
        return result

    # Apply the monkey patches
    setattr(patched_get_class_in_module, PATCH_MARKER, True)
    dynamic_module_utils.get_class_in_module = patched_get_class_in_module
    dynamic_module_utils.resolve_trust_remote_code = (
        patched_resolve_trust_remote_code
    )
    print("Kimi model patches applied successfully!")
# The context manager below is unchanged from the previous version of this file.
@contextmanager
def patch_hf_imports():