diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 81e4dd786..569a071a4 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -190,6 +190,13 @@ class PatchManager: apply_mistral_tokenizer_image_patch() + if self.cfg.model_config_type == "kimi_linear": + from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import ( + patch_kimi_model, + ) + + patch_kimi_model() + def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: diff --git a/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py b/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py index d1b54af67..72a6702cd 100644 --- a/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py +++ b/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py @@ -27,6 +27,93 @@ def get_patch_file_path(package_dot_path: str, filename: str) -> Path: return None +def patch_kimi_model(): + """ + Apply Kimi model patches by hijacking Transformers' dynamic module loading. + This intercepts the remote code loading and replaces it with our local patches. + """ + from transformers.dynamic_module_utils import get_class_in_module + + KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear" + + # Store original function + original_get_class_in_module = get_class_in_module + + def patched_get_class_in_module(class_name, module_path): + """Patched version that returns our local modules instead of remote ones.""" + # Check if this is a Kimi model module + if "modeling_kimi" in module_path: + print("Intercepting remote Kimi modeling module, using local patch instead") + # Load our local modeling_kimi.py instead + patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "modeling_kimi.py") + if patch_path and patch_path.exists(): + spec = importlib.util.spec_from_file_location( + "modeling_kimi", patch_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + if "tokenization_kimi" in module_path: + print( + "Intercepting remote Kimi tokenizer module, using local patch instead" + ) + # Load our local tokenization_kimi.py instead + patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, "tokenization_kimi.py") + if patch_path and patch_path.exists(): + spec = importlib.util.spec_from_file_location( + "tokenization_kimi", patch_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + # For all other modules, use the original function + return original_get_class_in_module(class_name, module_path) + + # Apply the monkey patch + import transformers.dynamic_module_utils + + transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module + + # Also patch the resolve_trust_remote_code to handle auto_map + from transformers.dynamic_module_utils import resolve_trust_remote_code + + original_resolve_trust_remote_code = resolve_trust_remote_code + + def patched_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs): + """Patched version to handle Kimi model auto_map.""" + # Check if this is a Kimi model + if "kimi" in repo_id.lower() or "kimi" in model_id.lower(): + print(f"Resolving trust remote code for Kimi model: {repo_id}") + # Get the original result + result = original_resolve_trust_remote_code( + repo_id, model_id, *args, **kwargs + ) + + # If it contains auto_map for Kimi, replace with our local files + if hasattr(result, "get") and result.get("auto_map"): + auto_map = result["auto_map"].copy() + # Replace remote modules with our local ones + for key in auto_map: + if "modeling_kimi" in auto_map[key]: + auto_map[key] = "modeling_kimi" + if "tokenization_kimi" in auto_map[key]: + auto_map[key] = "tokenization_kimi" + result["auto_map"] = auto_map + result["trust_remote_code"] = True + + return result + + return original_resolve_trust_remote_code(repo_id, model_id, *args, **kwargs) + + transformers.dynamic_module_utils.resolve_trust_remote_code = ( + patched_resolve_trust_remote_code + ) + + print("Kimi model patches applied successfully!") + + # The context manager code from before remains the same @contextmanager def patch_hf_imports():