move patching to post-model load to improve applicability

This commit is contained in:
Dan Saunders
2025-02-18 19:00:12 +00:00
parent c3d4f6e295
commit 945dcc5020
3 changed files with 108 additions and 92 deletions

View File

@@ -9,16 +9,14 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
@@ -65,15 +63,44 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration():
"""Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
@pytest.fixture
def minimal_config():
    """Return a minimal LoRA training config for a real HuggingFace Hub model.

    The config points at ``HuggingFaceTB/SmolLM2-135M`` and enables the MLP,
    QKV and O LoRA kernel flags.
    """
    return DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
            "learning_rate": 0.000001,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                }
            ],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.0,
            "lora_target_linear": True,
            "sequence_len": 1024,
            "lora_mlp_kernel": True,
            "lora_qkv_kernel": True,
            "lora_o_kernel": True,
        }
    )


@pytest.fixture
def minimal_cfg(minimal_config):  # pylint: disable=redefined-outer-name
    """Alias of ``minimal_config`` under the name the tests request.

    The integration tests in this file take a ``minimal_cfg`` argument; without
    this alias pytest cannot resolve that fixture name.
    """
    return minimal_config
def test_attention_patching_integration(minimal_cfg):
"""Test attention patching in integration context."""
# Store the original implementation
original_forward = getattr(LlamaAttention, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Load model
_, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Get the new forward method
patched_forward = LlamaAttention.forward
@@ -376,38 +403,10 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code
def test_kernel_training_integration():
def test_kernel_training_integration(minimal_cfg):
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model
model, _ = load_model_and_tokenizer(cfg=cfg)
model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Verify correct activation function
layer = model.model.model.layers[0]