Lora kernels fix (#2732)

* fix lora kernel patching and improve test

* simplification
This commit is contained in:
Dan Saunders
2025-05-28 10:03:43 -04:00
committed by GitHub
parent 65c5481120
commit 2962a398b7
3 changed files with 52 additions and 12 deletions

View File

@@ -14,7 +14,13 @@ import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from peft import PeftConfig, PeftMixedModel, PeftModel, prepare_model_for_kbit_training
from peft import (
PeftConfig,
PeftMixedModel,
PeftModel,
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
@@ -139,7 +145,7 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.cfg.fsdp and self.cfg.adapter == "qlora"
def load(self) -> tuple[PreTrainedModel, PeftConfig | None]:
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
Returns:

View File

@@ -59,9 +59,10 @@ class PatchManager:
self._apply_gradient_checkpointing_patches()
self._patch_attention()
self._apply_multipack_patches()
self._patch_loss_llama()
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_unsloth_self_attention_patch()
self._apply_self_attention_lora_patch()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
@@ -169,9 +170,9 @@ class PatchManager:
patch_mistral_cross_entropy()
def _apply_unsloth_self_attention_patch(self):
"""Apply Unsloth self-attention patches if configured."""
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
def _apply_self_attention_lora_patch(self):
"""Apply self-attention LoRA patches if configured."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
@@ -206,9 +207,6 @@ class PatchManager:
has_remote_code=has_remote_code,
)
if self.cfg.is_llama_derived_model:
self._patch_loss_llama()
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
@@ -235,6 +233,9 @@ class PatchManager:
def _patch_loss_llama(self):
"""Patch loss functions and other optimizations for LLaMA models."""
if not self.cfg.is_llama_derived_model:
return
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_fa_llama_cross_entropy,
@@ -314,8 +315,6 @@ class PatchManager:
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing
):
self._patch_loss_llama()
if self.cfg.flash_attention:
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
elif self.cfg.xformers_attention: