fix: attempt patch kimi remote
This commit is contained in:
@@ -190,6 +190,13 @@ class PatchManager:
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
if self.cfg.model_config_type == "kimi_linear":
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_model,
|
||||
)
|
||||
|
||||
patch_kimi_model()
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
|
||||
@@ -27,6 +27,93 @@ def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
|
||||
return None
|
||||
|
||||
|
||||
def patch_kimi_model():
    """
    Redirect Transformers' dynamic (remote-code) loading to local Kimi patches.

    Two attributes of ``transformers.dynamic_module_utils`` are replaced:

    * ``get_class_in_module`` -- when the requested module path contains
      ``modeling_kimi`` or ``tokenization_kimi``, the class is read from the
      matching file that ``get_patch_file_path`` resolves inside
      ``axolotl.monkeypatch.models.kimi_linear``. Every other request, and
      any Kimi request whose local file cannot be found or loaded, is passed
      to the original function unchanged.
    * ``resolve_trust_remote_code`` -- wrapped so Kimi lookups are announced
      and, if the original ever returns a mapping with an ``auto_map`` entry,
      string entries naming the Kimi modules are rewritten to the local
      module names.

    Calling this function more than once is a no-op after the first call.
    """
    import importlib.util

    import transformers.dynamic_module_utils as dynamic_module_utils

    KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"

    # (marker searched for in the module path / local module name,
    #  local file name, word used in the progress message)
    local_overrides = (
        ("modeling_kimi", "modeling_kimi.py", "modeling"),
        ("tokenization_kimi", "tokenization_kimi.py", "tokenizer"),
    )

    # Store original functions
    original_get_class_in_module = dynamic_module_utils.get_class_in_module
    if getattr(original_get_class_in_module, "_axolotl_kimi_patch", False):
        # Already patched: wrapping again would only stack wrappers.
        return
    original_resolve_trust_remote_code = (
        dynamic_module_utils.resolve_trust_remote_code
    )

    # Local modules are executed once and reused, so repeated lookups return
    # the same class objects instead of a fresh copy per call.
    loaded_modules = {}

    def load_local_module(module_name, filename, force_reload):
        """Return the local patch module, or None if it cannot be loaded."""
        if not force_reload and module_name in loaded_modules:
            return loaded_modules[module_name]

        patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
        if not (patch_path and patch_path.exists()):
            return None

        spec = importlib.util.spec_from_file_location(module_name, patch_path)
        if spec is None or spec.loader is None:
            return None

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        loaded_modules[module_name] = module
        return module

    def patched_get_class_in_module(class_name, module_path, *args, **kwargs):
        """Patched version that returns our local modules instead of remote ones."""
        # module_path may be a path-like object; substring tests need a str.
        module_path_str = str(module_path)
        # Extra arguments (e.g. upstream's ``force_reload``) must be accepted
        # and forwarded, otherwise every remote-code load raises TypeError.
        force_reload = bool(kwargs.get("force_reload", False))

        for marker, filename, label in local_overrides:
            if marker not in module_path_str:
                continue
            print(
                f"Intercepting remote Kimi {label} module, using local patch instead"
            )
            module = load_local_module(marker, filename, force_reload)
            if module is not None:
                return getattr(module, class_name)

        # For all other modules, use the original function
        return original_get_class_in_module(
            class_name, module_path, *args, **kwargs
        )

    def patched_resolve_trust_remote_code(*args, **kwargs):
        """Patched version to handle Kimi model auto_map."""
        # NOTE(review): upstream's first parameter appears to be the
        # ``trust_remote_code`` flag (bool/None), not a repo id, so only
        # string arguments are inspected for the model name.
        candidates = [*args[:2], kwargs.get("model_name")]
        kimi_name = next(
            (
                value
                for value in candidates
                if isinstance(value, str) and "kimi" in value.lower()
            ),
            None,
        )
        if kimi_name is None:
            return original_resolve_trust_remote_code(*args, **kwargs)

        print(f"Resolving trust remote code for Kimi model: {kimi_name}")
        # Get the original result
        result = original_resolve_trust_remote_code(*args, **kwargs)

        # If it contains auto_map for Kimi, replace with our local files.
        # NOTE(review): upstream seems to return a plain bool here, in which
        # case this branch is skipped and the result passes through as-is.
        if isinstance(result, dict) and result.get("auto_map"):
            auto_map = result["auto_map"].copy()
            for key, target in auto_map.items():
                # Entries may be lists/None; only plain strings are rewritten.
                if not isinstance(target, str):
                    continue
                if "modeling_kimi" in target:
                    auto_map[key] = "modeling_kimi"
                elif "tokenization_kimi" in target:
                    auto_map[key] = "tokenization_kimi"
            result["auto_map"] = auto_map
            result["trust_remote_code"] = True

        return result

    # Marker read by the idempotence guard at the top of this function.
    patched_get_class_in_module._axolotl_kimi_patch = True

    # Apply the monkey patches
    dynamic_module_utils.get_class_in_module = patched_get_class_in_module
    dynamic_module_utils.resolve_trust_remote_code = (
        patched_resolve_trust_remote_code
    )

    print("Kimi model patches applied successfully!")
|
||||
|
||||
|
||||
# The context manager code from before remains the same
|
||||
@contextmanager
|
||||
def patch_hf_imports():
|
||||
|
||||
Reference in New Issue
Block a user