Small updates

Dan Saunders
2025-02-18 19:08:27 +00:00
parent 945dcc5020
commit 479f5e18dd
2 changed files with 10 additions and 5 deletions


@@ -141,6 +141,9 @@ def patch_self_attn_lora(model: PreTrainedModel):
"""
Patches the attention classes in a transformer model with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed.
Args:
model: A HuggingFace transformers model.
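
The docstring above notes that the original attention implementation is preserved and can be restored if needed. Below is a minimal sketch of that preserve-and-restore monkeypatch pattern; the names (patch_attention, unpatch_attention, _ORIGINAL_FORWARD) and the choice of LlamaAttention are illustrative assumptions, not the actual code in axolotl.monkeypatch.lora_kernels.

    # Illustrative sketch only: keep a handle to the original attention forward
    # so an optimized replacement can be swapped in and later restored.
    # Names here are hypothetical and not taken from axolotl.
    from transformers.models.llama.modeling_llama import LlamaAttention

    _ORIGINAL_FORWARD = None

    def patch_attention(optimized_forward):
        """Swap in an optimized attention forward, remembering the original."""
        global _ORIGINAL_FORWARD
        if _ORIGINAL_FORWARD is None:
            _ORIGINAL_FORWARD = LlamaAttention.forward
        LlamaAttention.forward = optimized_forward

    def unpatch_attention():
        """Restore the attention forward that was saved before patching."""
        if _ORIGINAL_FORWARD is not None:
            LlamaAttention.forward = _ORIGINAL_FORWARD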


@@ -1023,6 +1023,12 @@ class ModelLoader:
        integrate_rope_embeddings()

    def apply_lora_patch(self) -> None:
        """Applies patching relevant to LoRA Triton kernels if enabled."""
        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
            patch_self_attn_lora(self.model)

        if (
            self.cfg.lora_mlp_kernel
            or self.cfg.lora_qkv_kernel
@@ -1176,11 +1182,7 @@ class ModelLoader:
        if self.cfg.adapter is not None:
            log_gpu_memory_usage(LOG, "after adapters", self.model.device)

        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
            patch_self_attn_lora(self.model)

        # TODO: Deprecate this.
        self.apply_unsloth_lora_patch()
        self.apply_lora_patch()
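
With this change, the QKV/O kernel patching that previously sat inline in the loader is gated behind the new apply_lora_patch() method. The snippet below only illustrates which config flags drive that gating; SimpleNamespace is a stand-in for axolotl's real cfg object, which normally comes from the YAML config.

    # Illustration of the flags that gate apply_lora_patch(); SimpleNamespace is
    # a stand-in for axolotl's cfg object, used here only for the sketch.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        lora_qkv_kernel=True,   # request the fused QKV projection patch
        lora_o_kernel=True,     # request the output projection patch
        lora_mlp_kernel=False,  # MLP kernel left disabled in this example
    )

    # Mirrors the gating in apply_lora_patch(): the self-attention patch is
    # applied when either the QKV or the O kernel is requested.
    if cfg.lora_qkv_kernel or cfg.lora_o_kernel:
        print("patch_self_attn_lora(model) would be called here")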