From 479f5e18dd2136a2e2719aebdb9d383272a7d590 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 18 Feb 2025 19:08:27 +0000
Subject: [PATCH] Small updates

---
 src/axolotl/monkeypatch/lora_kernels.py |  3 +++
 src/axolotl/utils/models.py             | 12 +++++++-----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index bc9d62ed1..0daf16a29 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -141,6 +141,9 @@ def patch_self_attn_lora(model: PreTrainedModel):
     """
     Patches the attention classes in a transformer model with optimized LoRA
     implementations.
 
+    It modifies the attention class to use optimized QKV and output projections. The
+    original implementation is preserved and can be restored if needed.
+
     Args:
         model: A HuggingFace transformers model.
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index cb73b5ff4..64ed5b600 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1023,6 +1023,12 @@ class ModelLoader:
         integrate_rope_embeddings()
 
     def apply_lora_patch(self) -> None:
+        """Applies patching relevant to LoRA Triton kernels if enabled."""
+        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
+            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
+
+            patch_self_attn_lora(self.model)
+
         if (
             self.cfg.lora_mlp_kernel
             or self.cfg.lora_qkv_kernel
@@ -1176,11 +1182,7 @@ class ModelLoader:
         if self.cfg.adapter is not None:
            log_gpu_memory_usage(LOG, "after adapters", self.model.device)
 
-        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
-            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
-
-            patch_self_attn_lora(self.model)
-
+        # TODO: Deprecate this.
         self.apply_unsloth_lora_patch()
         self.apply_lora_patch()
 
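
Reviewer note: a minimal sketch (not part of the patch) of the control flow this commit produces. The free-standing function and the bare cfg/model names are stand-ins for ModelLoader's self.cfg and self.model; only the config flags and the patch_self_attn_lora import come from the diff itself.

    from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora

    def apply_lora_patch(cfg, model):
        # Sketch of ModelLoader.apply_lora_patch() after this commit: the
        # self-attention patch is now gated here instead of inline in the
        # loader's load path.
        if cfg.lora_qkv_kernel or cfg.lora_o_kernel:
            # Swaps the attention QKV / output projections for the optimized
            # LoRA implementations; the originals are preserved so they can
            # be restored.
            patch_self_attn_lora(model)
        # The pre-existing MLP/QKV/O kernel integration that already lived in
        # this method runs after this check (unchanged by the commit).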