From 945dcc50207d8dde46454d6b07c7a7aac5613da3 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 18 Feb 2025 19:00:12 +0000 Subject: [PATCH] move patching to post-model load to improve applicability --- src/axolotl/monkeypatch/lora_kernels.py | 109 ++++++++++-------- src/axolotl/utils/models.py | 14 +-- .../lora_kernels/test_lora_kernel_patching.py | 77 ++++++------- 3 files changed, 108 insertions(+), 92 deletions(-) diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index d59fd22c9..bc9d62ed1 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -11,6 +11,7 @@ from accelerate.logging import get_logger from peft import PeftModelForCausalLM from torch import nn from transformers import AutoConfig +from transformers.modeling_utils import PreTrainedModel from axolotl.kernels.lora import ( apply_lora_mlp_geglu, @@ -119,12 +120,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: model_config = AutoConfig.from_pretrained(cfg["base_model"]) model_type = model_config.model_type - # Special case for model_type = "qwen2" - if model_type == "qwen2": - from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention - - return Qwen2Attention - try: # Dynamically import the module and attention class module_path = f"transformers.models.{model_type}.modeling_{model_type}" @@ -142,62 +137,86 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: # pylint: disable=protected-access -def patch_self_attn_lora(cfg: DictDefault): +def patch_self_attn_lora(model: PreTrainedModel): """ - Given an `axolotl` config, this method patches the inferred attention class forward - pass with optimized LoRA implementations. - - It modifies the attention class to use optimized QKV and output projections. The - original implementation is preserved and can be restored if needed. 
+ Patches the attention classes in a transformer model with optimized LoRA implementations. Args: - cfg: Dictionary mapping `axolotl` config keys to values. + model: A HuggingFace transformers model. Raises: AssertionError: If the required code blocks are not found in the attention implementation. """ - attention_cls = get_attention_cls_from_config(cfg) + # Find all attention modules in the model + attention_modules = [ + module + for module in model.modules() + if "attention" in module.__class__.__name__.lower() + and hasattr(module, "forward") + ] - # Check if already patched - if hasattr(attention_cls, "_original_forward"): - LOG.info(f"{attention_cls.__name__} already patched") + if not attention_modules: + LOG.warning("No attention modules found in model") return - self_attn_forward = inspect.getsource(attention_cls.forward) - attention_cls._original_forward = self_attn_forward - self_attn_forward, _ = detab_code(self_attn_forward) + attention_classes = {type(module) for module in attention_modules} + LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}") - assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found" - assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" + for attention_cls in attention_classes: + # Skip if already patched + if hasattr(attention_cls, "_original_forward"): + LOG.info(f"{attention_cls.__name__} already patched") + continue - self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE) - self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) - self_attn_forward = self_attn_forward.replace( - "def forward(", - "def axolotl_attn_forward(", - 1, - ) + # Get and store original forward implementation + self_attn_forward = inspect.getsource(attention_cls.forward) + attention_cls._original_forward = self_attn_forward - # Load necessary imports - module_name = attention_cls.__module__ - module = 
importlib.import_module(module_name) + # Remove indentation + self_attn_forward, _ = detab_code(self_attn_forward) - items_to_import = [] - for item in dir(module): - if item in self_attn_forward: - items_to_import.append(item) + # Verify required code blocks exist + assert ( + ORIGINAL_QKV_CODE in self_attn_forward + ), f"Original QKV code not found in {attention_cls.__name__}" + assert ( + ORIGINAL_O_CODE in self_attn_forward + ), f"Original O code not found in {attention_cls.__name__}" - exec( # pylint: disable=exec-used # nosec B102 - f"from {module_name} import ({', '.join(items_to_import)})", - globals(), - ) - exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 + # Replace code blocks + self_attn_forward = self_attn_forward.replace( + ORIGINAL_QKV_CODE, PATCHED_QKV_CODE + ) + self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) + self_attn_forward = self_attn_forward.replace( + "def forward(", + "def axolotl_attn_forward(", + 1, + ) - LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}") - attention_cls.forward = ( - axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821 - ) + # Import necessary symbols from the attention module + module_name = attention_cls.__module__ + module = importlib.import_module(module_name) + + items_to_import = [] + for item in dir(module): + if item in self_attn_forward: + items_to_import.append(item) + + if items_to_import: + exec( # pylint: disable=exec-used # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + + # Execute the new implementation + exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 + + LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}") + attention_cls.forward = ( + axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821 + ) def apply_lora_kernel_patches( diff --git a/src/axolotl/utils/models.py 
b/src/axolotl/utils/models.py index 377f08605..cb73b5ff4 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -439,11 +439,6 @@ class ModelLoader: patch_mistral_cross_entropy() - if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: - from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora - - patch_self_attn_lora(self.cfg) - def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): if self.model_config.model_type == "mllama" and self.cfg.flash_attention: @@ -1181,6 +1176,11 @@ class ModelLoader: if self.cfg.adapter is not None: log_gpu_memory_usage(LOG, "after adapters", self.model.device) + if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel: + from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora + + patch_self_attn_lora(self.model) + self.apply_unsloth_lora_patch() self.apply_lora_patch() @@ -1201,9 +1201,7 @@ def load_model( reference_model: bool = False, **kwargs, # pylint: disable=unused-argument ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - """ - Load a model for a given configuration and tokenizer. 
- """ + """Load a model for a given configuration and tokenizer.""" loader = ModelLoader( cfg, tokenizer, diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py index 4e3373367..ee6f76ee6 100644 --- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -9,16 +9,14 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaAttention +from axolotl.cli.utils import load_model_and_tokenizer from axolotl.kernels.lora import ( apply_lora_mlp_geglu, apply_lora_mlp_swiglu, apply_lora_o, apply_lora_qkv, ) -from axolotl.monkeypatch.lora_kernels import ( - apply_lora_kernel_patches, - patch_self_attn_lora, -) +from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches from axolotl.utils.dict import DictDefault MODEL_CONFIGS = [ @@ -65,15 +63,44 @@ def small_llama_model(): return LlamaForCausalLM(LlamaConfig(**config)) -def test_attention_patching_integration(): - """Test attention patching in integration context.""" - cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"} +@pytest.fixture +def minimal_cfg(): + "Config of real HuggingFace Hub model" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_config": "HuggingFaceTB/SmolLM2-135M", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.0, + "lora_target_linear": True, + "sequence_len": 1024, + "lora_mlp_kernel": True, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + } + ) + return cfg + + +def test_attention_patching_integration(minimal_cfg): + """Test attention patching in 
integration context.""" # Store the original implementation original_forward = getattr(LlamaAttention, "forward") - # Apply patch - patch_self_attn_lora(cfg) + # Load model + _, _ = load_model_and_tokenizer(cfg=minimal_cfg) # Get the new forward method patched_forward = LlamaAttention.forward @@ -376,38 +403,10 @@ def test_model_architecture(model_config): # pylint: disable=duplicate-code -def test_kernel_training_integration(): +def test_kernel_training_integration(minimal_cfg): """Test model loading with kernel patches enabled.""" - from axolotl.cli.utils import load_model_and_tokenizer - - # Create minimal config - cfg = DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "tokenizer_config": "HuggingFaceTB/SmolLM2-135M", - "learning_rate": 0.000001, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - } - ], - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "adapter": "lora", - "lora_r": 8, - "lora_alpha": 16, - "lora_dropout": 0.0, - "lora_target_linear": True, - "sequence_len": 1024, - "lora_mlp_kernel": True, - "lora_qkv_kernel": True, - "lora_o_kernel": True, - } - ) - # Load model - model, _ = load_model_and_tokenizer(cfg=cfg) + model, _ = load_model_and_tokenizer(cfg=minimal_cfg) # Verify correct activation function layer = model.model.model.layers[0]