diff --git a/docs/lora_optims.qmd b/docs/lora_optims.qmd index ea417a8e9..0c7ad7a54 100644 --- a/docs/lora_optims.qmd +++ b/docs/lora_optims.qmd @@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU -(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function -Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was -to leverage operator fusion and tensor re-use in order to improve speed and reduce -memory usage during the forward and backward passes of these calculations. +(including DDP, DeepSpeed, and FSDP2) training. These include (1) SwiGLU and GEGLU +activation function Triton kernels, and (2) LoRA MLP and attention custom autograd +functions. Our goal was to leverage operator fusion and tensor re-use in order to +improve speed and reduce memory usage during the forward and backward passes of these +calculations. We currently support several common model architectures, including (but not limited to): @@ -92,13 +93,12 @@ Currently, LoRA kernels are not supported for RLHF training, only SFT. - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels) - Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491) -- Targeted LoRA adapters cannot use Dropout - - This may limit model expressivity / cause overfitting -- Targeted LoRA adapters cannot have bias terms +- Targeted LoRA adapters must disable dropout (`lora_dropout: 0`) - This may limit model expressivity +- Adapters that already include bias terms are supported. -Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to -be re-finetuned without these features in order to be useful. +Models with pre-existing LoRA adapters that use Dropout may need to be re-finetuned +without it in order to be as performant. ## Implementation details diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index e845dc6ce..df1e8cced 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -323,8 +323,8 @@ def apply_lora_kernel_patches( AssertionError: If multiple adapters are active (currently unsupported). Note: - The optimizations require LoRA adapters with no dropout and no bias terms. The - function will skip patching if these conditions aren't met. + The optimizations require LoRA adapters with no dropout. The function will skip + patching if that condition isn't met. """ if not isinstance(model, PeftModelForCausalLM): raise TypeError("Model must be a PeftModelForCausalLM") @@ -340,10 +340,10 @@ def apply_lora_kernel_patches( lora_config = model.model.peft_config[active_adapter] # Only patch if conditions are met - can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none" + can_patch = lora_config.lora_dropout == 0 if not can_patch: - LOG.warning("Cannot patch layers - requires no dropout and no bias") + LOG.warning("Cannot patch layers - requires `lora_dropout: 0`") LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file") return model diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py index 2180eb99d..bbd2d4ffa 100644 --- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -221,44 +221,53 @@ def test_model_specific_activation(model_name, expected_activation): assert layer.mlp.forward.__func__ is expected_activation -def test_kernel_patch_conditions(): - """Test various conditions that should prevent kernel patching.""" - test_configs = [ - # Dropout prevents patching - { - "peft_type": "LORA", - "task_type": "CAUSAL_LM", - "r": 8, - "lora_alpha": 16, - "target_modules": ["gate_proj", "up_proj", "down_proj"], - "lora_dropout": 0.1, - "bias": "none", - }, - # Bias prevents patching - { - "peft_type": "LORA", - "task_type": "CAUSAL_LM", - "r": 8, - "lora_alpha": 16, - "target_modules": ["gate_proj", "up_proj", "down_proj"], - "lora_dropout": 0, - "bias": "lora_only", - }, - ] +def test_kernel_patch_requires_zero_dropout(): + """Kernel patching should be skipped when dropout is enabled.""" + config = { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0.1, + "bias": "none", + } - for config in test_configs: - model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M") - peft_config = get_peft_config(config) - model = PeftModelForCausalLM(model, peft_config) - cfg = DictDefault({"lora_mlp_kernel": True}) + model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M") + peft_config = get_peft_config(config) + model = PeftModelForCausalLM(model, peft_config) + cfg = DictDefault({"lora_mlp_kernel": True}) - # Should not patch - patched_model = apply_lora_kernel_patches(model, cfg) - layer = patched_model.model.model.layers[0].mlp + patched_model = apply_lora_kernel_patches(model, cfg) + layer = patched_model.model.model.layers[0].mlp - # Verify no patches applied - assert layer.forward.__func__ is not apply_lora_mlp_swiglu - assert layer.forward.__func__ is not apply_lora_mlp_geglu + # Verify no patches applied when dropout is non-zero + assert layer.forward.__func__ is not apply_lora_mlp_swiglu + assert layer.forward.__func__ is not apply_lora_mlp_geglu + + +def test_kernel_patch_with_bias_enabled(): + """Kernel patching should succeed when LoRA bias is enabled.""" + config = { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0, + "bias": "lora_only", + } + + model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M") + peft_config = get_peft_config(config) + model = PeftModelForCausalLM(model, peft_config) + cfg = DictDefault({"lora_mlp_kernel": True}) + + patched_model = apply_lora_kernel_patches(model, cfg) + layer = patched_model.model.model.layers[0].mlp + + # Verify patches applied when bias support is enabled + assert layer.forward.__func__ is apply_lora_mlp_swiglu def test_kernel_config_options():