diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 23f79d368..ca8fd1258 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -166,6 +166,17 @@ class PatchManager:
     def _apply_self_attention_lora_patch(self):
         """Apply self-attention LoRA patches if configured."""
         if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
+            # Only patch if lora_dropout is unset or zero
+            can_patch = (
+                self.cfg.lora_dropout == 0
+                if hasattr(self.cfg, "lora_dropout")
+                else True
+            )  # default to True if lora_dropout is not set
+
+            if not can_patch:
+                LOG.warning("Cannot patch self-attention - requires no dropout")
+                return
+
             from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
 
             patch_self_attn_lora(self.cfg)
diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 63fbfa359..586412dd7 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -274,6 +274,29 @@ def find_mlp_in_layer(
     )
 
 
+def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
+    """
+    Get the layers of the model. Handles text-only and multimodal models.
+
+    Args:
+        model: A PEFT model.
+
+    Returns:
+        A list of layers.
+    """
+    pretrained_model = model.model
+
+    # check for multimodal models first
+    if hasattr(pretrained_model, "language_model"):
+        return pretrained_model.language_model.layers
+    if hasattr(pretrained_model, "model"):
+        return pretrained_model.model.layers
+
+    raise NotImplementedError(
+        f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
+    )
+
+
 def apply_lora_kernel_patches(
     model: PeftModelForCausalLM, cfg: DictDefault
 ) -> PeftModelForCausalLM:
@@ -345,17 +368,7 @@ def apply_lora_kernel_patches(
     if activation not in SUPPORTED_ACTIVATIONS:
         raise NotImplementedError(f"Activation {activation} is not supported")
 
-    layers = []
-    # check for multimodal models first
-    pretrained_model = model.model
-    if hasattr(pretrained_model, "language_model"):
-        layers = pretrained_model.language_model.layers
-    elif hasattr(pretrained_model, "model"):
-        layers = pretrained_model.model.layers
-    else:
-        raise NotImplementedError(
-            f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
-        )
+    layers = get_layers(model)
 
     # Patch each layer
     for layer in layers:
diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index 76c383a92..56ce5a8b9 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -25,7 +25,9 @@ from axolotl.loaders.model import ModelLoader
 from axolotl.loaders.tokenizer import load_tokenizer
 from axolotl.monkeypatch.lora_kernels import (
     apply_lora_kernel_patches,
+    find_self_attn_in_layer,
     get_attention_cls_from_config,
+    get_layers,
     patch_self_attn_lora,
 )
 from axolotl.utils.dict import DictDefault
@@ -501,3 +503,63 @@ def test_kernel_training_integration_auto_enable(temp_dir):
             break
 
     assert found_patched_attn
+
+
+def test_kernel_training_integration_dropout_non_zero():
+    """Test that model loading with non-zero lora_dropout does not apply patches."""
+
+    from axolotl.cli.utils import load_model_and_tokenizer
+
+    # Create minimal config
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
+            "learning_rate": 0.000001,
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                }
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "adapter": "lora",
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.1,
+            "lora_target_linear": True,
+            "sequence_len": 1024,
+        }
+    )
+
+    # Get original attention class
+    attention_cls = get_attention_cls_from_config(cfg)
+
+    # Store original state before patching
+    original_forward_method = attention_cls.forward
+
+    # Load model
+    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
+
+    # We instantiate ModelLoader since that's where the patches are applied,
+    # even though we're not using it to load the model
+    model_loader = ModelLoader(cfg, tokenizer)
+
+    # Apply patch
+    model_loader.patch_manager._apply_self_attention_lora_patch()  # pylint: disable=protected-access
+
+    # Verify patch was not applied
+    assert attention_cls.forward == original_forward_method
+
+    # Apply LoRA kernel patches
+    model_loader.patch_manager._apply_lora_kernel_patch(  # pylint: disable=protected-access
+        model
+    )
+
+    # Verify kernels were not attached to any attention layer
+    layers = get_layers(model)
+    for layer in layers:
+        for self_attn in find_self_attn_in_layer(layer):
+            assert not hasattr(self_attn, "apply_qkv")
+            assert not hasattr(self_attn, "apply_o")
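Reviewer note: a minimal standalone sketch of the can_patch gate above, for anyone skimming the logic. SimpleNamespace stands in for axolotl's config object, and the helper name can_patch_self_attn is invented for illustration; if the real config type returns None for missing keys rather than raising, the hasattr check always succeeds, whereas the plain namespace used here can genuinely lack the attribute.

    from types import SimpleNamespace

    def can_patch_self_attn(cfg) -> bool:
        # Mirror of the gate in _apply_self_attention_lora_patch: a configured
        # non-zero lora_dropout blocks the kernel patch; an absent attribute does not.
        if hasattr(cfg, "lora_dropout"):
            return cfg.lora_dropout == 0
        return True

    assert can_patch_self_attn(SimpleNamespace(lora_dropout=0)) is True
    assert can_patch_self_attn(SimpleNamespace(lora_dropout=0.1)) is False
    assert can_patch_self_attn(SimpleNamespace()) is True  # unset: default to patchable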
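Reviewer note: the branch order in get_layers follows the "check for multimodal models first" comment, presumably because a multimodal wrapper may expose a .model attribute alongside its .language_model tower. An illustrative sketch of the two layouts under that assumption; TextOnlyModel and MultimodalModel are invented dummy modules, not real axolotl or transformers classes.

    import torch.nn as nn

    def resolve_layers(pretrained_model):
        # Same branch order as get_layers(): language tower first, then text-only.
        if hasattr(pretrained_model, "language_model"):
            return pretrained_model.language_model.layers
        if hasattr(pretrained_model, "model"):
            return pretrained_model.model.layers
        raise NotImplementedError("unsupported model layout")

    class TextOnlyModel(nn.Module):
        # stand-in for e.g. LlamaForCausalLM, where decoder blocks live at .model.layers
        def __init__(self):
            super().__init__()
            self.model = nn.Module()
            self.model.layers = nn.ModuleList([nn.Identity()])

    class MultimodalModel(nn.Module):
        # stand-in for a VLM that exposes its text tower as .language_model
        def __init__(self):
            super().__init__()
            self.language_model = nn.Module()
            self.language_model.layers = nn.ModuleList([nn.Identity()])

    assert len(resolve_layers(TextOnlyModel())) == 1
    assert len(resolve_layers(MultimodalModel())) == 1

The new test can be run on its own with pytest's -k filter, e.g.:
pytest tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py -k dropout_non_zero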