diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 681e5d335..1d26a99dd 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -14,7 +14,13 @@ import torch
 import transformers
 import transformers.modeling_utils
 from accelerate import init_empty_weights
-from peft import PeftConfig, PeftMixedModel, PeftModel, prepare_model_for_kbit_training
+from peft import (
+    PeftConfig,
+    PeftMixedModel,
+    PeftModel,
+    PeftModelForCausalLM,
+    prepare_model_for_kbit_training,
+)
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForVision2Seq,
@@ -139,7 +145,7 @@ class ModelLoader:
         """Property that determines if FSDP with QLoRA is enabled."""
         return self.cfg.fsdp and self.cfg.adapter == "qlora"
 
-    def load(self) -> tuple[PreTrainedModel, PeftConfig | None]:
+    def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
         """Load and prepare the model with all configurations and patches.
 
         Returns:
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index ce1f5cf70..56888b607 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -59,9 +59,10 @@ class PatchManager:
         self._apply_gradient_checkpointing_patches()
         self._patch_attention()
         self._apply_multipack_patches()
+        self._patch_loss_llama()
         self._patch_llama_derived_model()
         self._apply_mistral_cross_entropy_patch()
-        self._apply_unsloth_self_attention_patch()
+        self._apply_self_attention_lora_patch()
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -169,9 +170,9 @@ class PatchManager:
 
             patch_mistral_cross_entropy()
 
-    def _apply_unsloth_self_attention_patch(self):
-        """Apply Unsloth self-attention patches if configured."""
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
+    def _apply_self_attention_lora_patch(self):
+        """Apply self-attention LoRA patches if configured."""
+        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
             from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
 
             patch_self_attn_lora(self.cfg)
@@ -206,9 +207,6 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )
 
-        if self.cfg.is_llama_derived_model:
-            self._patch_loss_llama()
-
     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
         if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
@@ -235,6 +233,9 @@ class PatchManager:
     def _patch_loss_llama(self):
         """Patch loss functions and other optimizations for LLaMA models."""
+        if not self.cfg.is_llama_derived_model:
+            return
+
         if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 patch_fa_llama_cross_entropy,
             )
@@ -314,8 +315,6 @@ class PatchManager:
             and (self.cfg.flash_attention or self.cfg.flex_attention)
             and self.cfg.sample_packing
         ):
-            self._patch_loss_llama()
-
             if self.cfg.flash_attention:
                 self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
             elif self.cfg.xformers_attention:
diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index f6b7ee9b9..76c383a92 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -21,8 +21,11 @@ from axolotl.kernels.lora import (
     apply_lora_o,
     apply_lora_qkv,
 )
+from axolotl.loaders.model import ModelLoader
+from axolotl.loaders.tokenizer import load_tokenizer
 from axolotl.monkeypatch.lora_kernels import (
     apply_lora_kernel_patches,
+    get_attention_cls_from_config,
     patch_self_attn_lora,
 )
 from axolotl.utils.dict import DictDefault
@@ -80,7 +83,7 @@ def small_llama_model():
 )
 def test_attention_patching_integration(model_name, attention_cls):
     """Test attention patching in integration context."""
-    cfg = {"base_model": model_name}
+    cfg = DictDefault({"base_model": model_name})
 
     # Store the original implementation
     original_forward = getattr(attention_cls, "forward")
@@ -466,3 +469,35 @@ def test_kernel_training_integration_auto_enable(temp_dir):
     assert cfg.lora_mlp_kernel is True
     assert cfg.lora_qkv_kernel is True
     assert cfg.lora_o_kernel is True
+
+    # Get the attention class before patching to check for side effects
+    attention_cls = get_attention_cls_from_config(cfg)
+
+    # Store original state before patching
+    original_forward_method = attention_cls.forward
+
+    # Load the model (this should trigger the patches)
+    tokenizer = load_tokenizer(cfg)
+    model, _ = ModelLoader(cfg, tokenizer).load()
+
+    # Test side effects of patch_self_attn_lora
+    assert hasattr(attention_cls, "_original_forward")
+    assert attention_cls.forward != original_forward_method
+
+    # Find at least one self-attention module and verify it has the patched methods
+    found_patched_attn = False
+    for layer in model.model.model.layers:
+        if hasattr(layer, "self_attn"):
+            self_attn = layer.self_attn
+            if all(
+                hasattr(self_attn, proj)
+                for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
+            ):
+                # These methods should be added by apply_lora_kernel_patches
+                assert hasattr(self_attn, "apply_qkv") and callable(self_attn.apply_qkv)
+                assert hasattr(self_attn, "apply_o") and callable(self_attn.apply_o)
+
+                found_patched_attn = True
+                break
+
+    assert found_patched_attn
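
Usage sketch (not part of the diff): the renamed `lora_qkv_kernel` / `lora_o_kernel` flags are what now gate `patch_self_attn_lora` via `PatchManager._apply_self_attention_lora_patch`, and `ModelLoader.load()` may now return a `PeftModelForCausalLM`. The snippet below mirrors the flow in the integration test above; the model name and LoRA settings other than the two kernel flags are illustrative assumptions, not a validated axolotl config.

    # Sketch only: mirrors test_lora_kernel_patching.py; values besides the
    # kernel flags are assumed for illustration.
    from axolotl.loaders.model import ModelLoader
    from axolotl.loaders.tokenizer import load_tokenizer
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",  # assumed small test model
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_target_linear": True,
            "sequence_len": 512,
            # Renamed from unsloth_lora_qkv / unsloth_lora_o; these gate
            # PatchManager._apply_self_attention_lora_patch -> patch_self_attn_lora(cfg).
            "lora_qkv_kernel": True,
            "lora_o_kernel": True,
        }
    )

    tokenizer = load_tokenizer(cfg)
    # With a LoRA adapter configured, load() returns a PEFT-wrapped model
    # (see the widened return annotation in ModelLoader.load).
    model, peft_config = ModelLoader(cfg, tokenizer).load()

    # As in the test, patched self-attention modules expose the fused LoRA helpers.
    attn = model.model.model.layers[0].self_attn
    assert callable(attn.apply_qkv) and callable(attn.apply_o)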