LoRA kernels fix (#2732)
* Fix LoRA kernel patching and improve test
* Simplification
This commit is contained in:
@@ -14,7 +14,13 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from peft import PeftConfig, PeftMixedModel, PeftModel, prepare_model_for_kbit_training
|
from peft import (
|
||||||
|
PeftConfig,
|
||||||
|
PeftMixedModel,
|
||||||
|
PeftModel,
|
||||||
|
PeftModelForCausalLM,
|
||||||
|
prepare_model_for_kbit_training,
|
||||||
|
)
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
@@ -139,7 +145,7 @@ class ModelLoader:
|
|||||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||||
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
||||||
|
|
||||||
def load(self) -> tuple[PreTrainedModel, PeftConfig | None]:
|
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||||
"""Load and prepare the model with all configurations and patches.
|
"""Load and prepare the model with all configurations and patches.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -59,9 +59,10 @@ class PatchManager:
|
|||||||
self._apply_gradient_checkpointing_patches()
|
self._apply_gradient_checkpointing_patches()
|
||||||
self._patch_attention()
|
self._patch_attention()
|
||||||
self._apply_multipack_patches()
|
self._apply_multipack_patches()
|
||||||
|
self._patch_loss_llama()
|
||||||
self._patch_llama_derived_model()
|
self._patch_llama_derived_model()
|
||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_unsloth_self_attention_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
@@ -169,9 +170,9 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_mistral_cross_entropy()
|
patch_mistral_cross_entropy()
|
||||||
|
|
||||||
def _apply_unsloth_self_attention_patch(self):
|
def _apply_self_attention_lora_patch(self):
|
||||||
"""Apply Unsloth self-attention patches if configured."""
|
"""Apply self-attention LoRA patches if configured."""
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
@@ -206,9 +207,6 @@ class PatchManager:
|
|||||||
has_remote_code=has_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model:
|
|
||||||
self._patch_loss_llama()
|
|
||||||
|
|
||||||
def _patch_attention(self):
|
def _patch_attention(self):
|
||||||
"""Apply attention-specific patches based on model type."""
|
"""Apply attention-specific patches based on model type."""
|
||||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||||
@@ -235,6 +233,9 @@ class PatchManager:
|
|||||||
|
|
||||||
def _patch_loss_llama(self):
|
def _patch_loss_llama(self):
|
||||||
"""Patch loss functions and other optimizations for LLaMA models."""
|
"""Patch loss functions and other optimizations for LLaMA models."""
|
||||||
|
if not self.cfg.is_llama_derived_model:
|
||||||
|
return
|
||||||
|
|
||||||
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
|
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
patch_fa_llama_cross_entropy,
|
patch_fa_llama_cross_entropy,
|
||||||
@@ -314,8 +315,6 @@ class PatchManager:
|
|||||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
self._patch_loss_llama()
|
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
|
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
|
||||||
elif self.cfg.xformers_attention:
|
elif self.cfg.xformers_attention:
|
||||||
|
|||||||
@@ -21,8 +21,11 @@ from axolotl.kernels.lora import (
|
|||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
apply_lora_qkv,
|
apply_lora_qkv,
|
||||||
)
|
)
|
||||||
|
from axolotl.loaders.model import ModelLoader
|
||||||
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
from axolotl.monkeypatch.lora_kernels import (
|
from axolotl.monkeypatch.lora_kernels import (
|
||||||
apply_lora_kernel_patches,
|
apply_lora_kernel_patches,
|
||||||
|
get_attention_cls_from_config,
|
||||||
patch_self_attn_lora,
|
patch_self_attn_lora,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -80,7 +83,7 @@ def small_llama_model():
|
|||||||
)
|
)
|
||||||
def test_attention_patching_integration(model_name, attention_cls):
|
def test_attention_patching_integration(model_name, attention_cls):
|
||||||
"""Test attention patching in integration context."""
|
"""Test attention patching in integration context."""
|
||||||
cfg = {"base_model": model_name}
|
cfg = DictDefault({"base_model": model_name})
|
||||||
|
|
||||||
# Store the original implementation
|
# Store the original implementation
|
||||||
original_forward = getattr(attention_cls, "forward")
|
original_forward = getattr(attention_cls, "forward")
|
||||||
@@ -466,3 +469,35 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
assert cfg.lora_mlp_kernel is True
|
assert cfg.lora_mlp_kernel is True
|
||||||
assert cfg.lora_qkv_kernel is True
|
assert cfg.lora_qkv_kernel is True
|
||||||
assert cfg.lora_o_kernel is True
|
assert cfg.lora_o_kernel is True
|
||||||
|
|
||||||
|
# Get the attention class before patching to check for side effects
|
||||||
|
attention_cls = get_attention_cls_from_config(cfg)
|
||||||
|
|
||||||
|
# Store original state before patching
|
||||||
|
original_forward_method = attention_cls.forward
|
||||||
|
|
||||||
|
# Load the model (this should trigger the patches)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
model, _ = ModelLoader(cfg, tokenizer).load()
|
||||||
|
|
||||||
|
# Test side effects of patch_self_attn_lora
|
||||||
|
assert hasattr(attention_cls, "_original_forward")
|
||||||
|
assert attention_cls.forward != original_forward_method
|
||||||
|
|
||||||
|
# Find at least one self-attention module and verify it has the patched methods
|
||||||
|
found_patched_attn = False
|
||||||
|
for layer in model.model.model.layers:
|
||||||
|
if hasattr(layer, "self_attn"):
|
||||||
|
self_attn = layer.self_attn
|
||||||
|
if all(
|
||||||
|
hasattr(self_attn, proj)
|
||||||
|
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||||
|
):
|
||||||
|
# These methods should be added by apply_lora_kernel_patches
|
||||||
|
assert hasattr(self_attn, "apply_qkv") and callable(self_attn.apply_qkv)
|
||||||
|
assert hasattr(self_attn, "apply_o") and callable(self_attn.apply_o)
|
||||||
|
|
||||||
|
found_patched_attn = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert found_patched_attn
|
||||||
|
|||||||
Reference in New Issue
Block a user