From ecac731922a43825158512286ac21f85ca9f953a Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 29 Apr 2025 16:18:49 -0400
Subject: [PATCH] auto-enable lora kernels where possible (#2589)

* auto-enable lora kernels where possible

* test

* revert change to example yaml

* naming

* remove print

* slight logic change
---
 src/axolotl/utils/schemas/config.py           | 48 +++++++++++++++++++
 .../lora_kernels/test_lora_kernel_patching.py | 43 +++++++++++++++++
 2 files changed, 91 insertions(+)

diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 36c18fd3c..bc25b1ab0 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1315,6 +1315,54 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_auto_enable_lora_kernels(cls, data):
+        # Only proceed if using LoRA or QLoRA adapter
+        if data.get("adapter") in ["lora", "qlora"]:
+            # Skip if already set, using unsloth optimizations, or using 8-bit
+            unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
+            kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
+            if (
+                any(data.get(k) is not None for k in kernel_fields)
+                or any(data.get(k) for k in unsloth_fields)
+                or data.get("adapter") == "lora"
+                and data.get("load_in_8bit")
+            ):
+                return data
+
+            # Check multi-GPU compatibility
+            capabilities = data.get("capabilities")
+            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+            is_fsdp = data.get("fsdp") is not None
+            is_fsdp2 = (
+                data.get("fsdp_config") is not None
+                and str(data.get("fsdp_config").get("fsdp_version")) == "2"
+            )
+
+            if (
+                not is_multi_gpu
+                or (is_multi_gpu and not is_fsdp)
+                or (is_multi_gpu and is_fsdp2)
+            ):
+                # Auto-enable kernels if not explicitly set by user
+                if data.get("lora_mlp_kernel") is None:
+                    data["lora_mlp_kernel"] = True
+
+                if data.get("lora_qkv_kernel") is None:
+                    data["lora_qkv_kernel"] = True
+
+                if data.get("lora_o_kernel") is None:
+                    data["lora_o_kernel"] = True
+
+                LOG.warning(
+                    "Auto-enabling LoRA kernel optimizations for faster training. "
+                    "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
+                    "See https://docs.axolotl.ai/docs/lora_optims.html for more info."
+                )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_adopt_torch_version(cls, data):
diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index eb0c73225..f3e59b373 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -2,8 +2,11 @@
 
 # pylint: disable=redefined-outer-name
 
+from pathlib import Path
+
 import pytest
 import torch
+import yaml
 from accelerate.state import PartialState
 from peft import PeftModelForCausalLM, get_peft_config
 from transformers import AutoModelForCausalLM, LlamaForCausalLM
@@ -11,6 +14,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
 
+from axolotl.cli.config import load_cfg
 from axolotl.kernels.lora import (
     apply_lora_mlp_geglu,
     apply_lora_mlp_swiglu,
@@ -421,3 +425,42 @@ def test_kernel_training_integration():
     # Verify correct activation function
     layer = model.model.model.layers[0]
     assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
+
+
+def test_kernel_training_integration_auto_enable(temp_dir):
+    """Test model loading with auto-enabled kernel patches."""
+    # Create minimal config without explicitly setting kernel options
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
+            "learning_rate": 0.000001,
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                }
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "adapter": "lora",
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.0,
+            "lora_target_linear": True,
+            "sequence_len": 1024,
+        }
+    )
+
+    # Write cfg to yaml file
+    path = Path(temp_dir) / "config.yaml"
+    with open(path, "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    # Load config
+    cfg = load_cfg(str(path))
+
+    # Verify kernel options were auto-enabled in the config
+    assert cfg.lora_mlp_kernel is True
+    assert cfg.lora_qkv_kernel is True
+    assert cfg.lora_o_kernel is True