move patching to post-model load to improve applicability

This commit is contained in:
Dan Saunders
2025-02-18 19:00:12 +00:00
parent c3d4f6e295
commit 945dcc5020
3 changed files with 108 additions and 92 deletions

View File

@@ -11,6 +11,7 @@ from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers import AutoConfig from transformers import AutoConfig
from transformers.modeling_utils import PreTrainedModel
from axolotl.kernels.lora import ( from axolotl.kernels.lora import (
apply_lora_mlp_geglu, apply_lora_mlp_geglu,
@@ -119,12 +120,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
model_config = AutoConfig.from_pretrained(cfg["base_model"]) model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try: try:
# Dynamically import the module and attention class # Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}" module_path = f"transformers.models.{model_type}.modeling_{model_type}"
@@ -142,62 +137,86 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
# pylint: disable=protected-access # pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault): def patch_self_attn_lora(model: PreTrainedModel):
""" """
Given an `axolotl` config, this method patches the inferred attention class forward Patches the attention classes in a transformer model with optimized LoRA implementations.
pass with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. model: A HuggingFace transformers model.
Raises: Raises:
AssertionError: If the required code blocks are not found in the attention AssertionError: If the required code blocks are not found in the attention
implementation. implementation.
""" """
attention_cls = get_attention_cls_from_config(cfg) # Find all attention modules in the model
attention_modules = [
module
for module in model.modules()
if "attention" in module.__class__.__name__.lower()
and hasattr(module, "forward")
]
# Check if already patched if not attention_modules:
if hasattr(attention_cls, "_original_forward"): LOG.warning("No attention modules found in model")
LOG.info(f"{attention_cls.__name__} already patched")
return return
self_attn_forward = inspect.getsource(attention_cls.forward) attention_classes = {type(module) for module in attention_modules}
attention_cls._original_forward = self_attn_forward LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found" for attention_cls in attention_classes:
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" # Skip if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
continue
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE) # Get and store original forward implementation
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) self_attn_forward = inspect.getsource(attention_cls.forward)
self_attn_forward = self_attn_forward.replace( attention_cls._original_forward = self_attn_forward
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Load necessary imports # Remove indentation
module_name = attention_cls.__module__ self_attn_forward, _ = detab_code(self_attn_forward)
module = importlib.import_module(module_name)
items_to_import = [] # Verify required code blocks exist
for item in dir(module): assert (
if item in self_attn_forward: ORIGINAL_QKV_CODE in self_attn_forward
items_to_import.append(item) ), f"Original QKV code not found in {attention_cls.__name__}"
assert (
ORIGINAL_O_CODE in self_attn_forward
), f"Original O code not found in {attention_cls.__name__}"
exec( # pylint: disable=exec-used # nosec B102 # Replace code blocks
f"from {module_name} import ({', '.join(items_to_import)})", self_attn_forward = self_attn_forward.replace(
globals(), ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
) )
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}") # Import necessary symbols from the attention module
attention_cls.forward = ( module_name = attention_cls.__module__
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821 module = importlib.import_module(module_name)
)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
if items_to_import:
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
# Execute the new implementation
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def apply_lora_kernel_patches( def apply_lora_kernel_patches(

View File

@@ -439,11 +439,6 @@ class ModelLoader:
patch_mistral_cross_entropy() patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention: if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1181,6 +1176,11 @@ class ModelLoader:
if self.cfg.adapter is not None: if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device) log_gpu_memory_usage(LOG, "after adapters", self.model.device)
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.model)
self.apply_unsloth_lora_patch() self.apply_unsloth_lora_patch()
self.apply_lora_patch() self.apply_lora_patch()
@@ -1201,9 +1201,7 @@ def load_model(
reference_model: bool = False, reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
""" """Load a model for a given configuration and tokenizer."""
Load a model for a given configuration and tokenizer.
"""
loader = ModelLoader( loader = ModelLoader(
cfg, cfg,
tokenizer, tokenizer,

View File

@@ -9,16 +9,14 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.kernels.lora import ( from axolotl.kernels.lora import (
apply_lora_mlp_geglu, apply_lora_mlp_geglu,
apply_lora_mlp_swiglu, apply_lora_mlp_swiglu,
apply_lora_o, apply_lora_o,
apply_lora_qkv, apply_lora_qkv,
) )
from axolotl.monkeypatch.lora_kernels import ( from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [ MODEL_CONFIGS = [
@@ -65,15 +63,44 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config)) return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration(): @pytest.fixture
"""Test attention patching in integration context.""" def minimal_config():
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"} "Config of real HuggingFace Hub model"
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
return cfg
def test_attention_patching_integration(minimal_cfg):
"""Test attention patching in integration context."""
# Store the original implementation # Store the original implementation
original_forward = getattr(LlamaAttention, "forward") original_forward = getattr(LlamaAttention, "forward")
# Apply patch # Load model
patch_self_attn_lora(cfg) _, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Get the new forward method # Get the new forward method
patched_forward = LlamaAttention.forward patched_forward = LlamaAttention.forward
@@ -376,38 +403,10 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
def test_kernel_training_integration(): def test_kernel_training_integration(minimal_cfg):
"""Test model loading with kernel patches enabled.""" """Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model # Load model
model, _ = load_model_and_tokenizer(cfg=cfg) model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Verify correct activation function # Verify correct activation function
layer = model.model.model.layers[0] layer = model.model.model.layers[0]