---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning"
---

Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single-GPU and multi-GPU
(DDP and DeepSpeed) training. These are (1) Triton kernels for the SwiGLU and GEGLU
activation functions, and (2) custom autograd functions for the LoRA MLP and attention
paths. Our goal was to leverage operator fusion and tensor re-use to improve speed and
reduce memory usage during the forward and backward passes of these computations.

We currently support several common model architectures, including (but not limited to):

- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`
- `gemma3`

<details>

The set of models we support is currently limited by our attention patching strategy,
which assumes (and replaces) specific code blocks for the query / key / value and output
projections:

```python
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
    "\n"
)
```

Is replaced with:

```python
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
    "\n"
)
```

Where `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.

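For intuition, here is a rough, unfused reference for what `apply_qkv` computes,
assuming PEFT-style LoRA layers (with `base_layer`, `lora_A`, `lora_B`, and `scaling`
attributes) and an adapter named `"default"`; the fused implementation in
`axolotl.kernels.lora` avoids the separate matmuls and intermediate tensors shown here:

```python
def apply_qkv_reference(hidden_states, q_proj, k_proj, v_proj, adapter="default"):
    """Unfused reference: base projection plus the scaled low-rank LoRA update."""

    def lora_linear(proj, x):
        base = proj.base_layer(x)            # frozen base weight
        lora_a = proj.lora_A[adapter]        # hidden -> r
        lora_b = proj.lora_B[adapter]        # r -> out
        return base + lora_b(lora_a(x)) * proj.scaling[adapter]

    return (
        lora_linear(q_proj, hidden_states),
        lora_linear(k_proj, hidden_states),
        lora_linear(v_proj, hidden_states),
    )
```

`apply_o` is the analogous single-projection case for `o_proj`. Note the absence of
dropout and bias above, which mirrors the requirements listed below.
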
We welcome testing of other model architectures and / or PRs to expand our patching
logic to be compatible with more of them.

</details>

::: {.callout-tip}
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
:::

## Usage

These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and optimized output
projection, respectively.

```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```

## Requirements

- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
  - Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters cannot use Dropout
  - This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
  - This may limit model expressivity

Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
be re-finetuned without these features before they can benefit from these optimizations.

## Implementation details

### Custom autograd functions

The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
LoRA and base weight computations together and provides a single, efficient backward
pass for the entire MLP block.

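As a rough sketch of that pattern (not Axolotl's actual implementation), here is a
single `torch.autograd.Function` covering a plain SwiGLU MLP end to end over 2D
`(num_tokens, hidden)` inputs, with one hand-written backward. The real version
additionally folds in the LoRA A/B matmuls and dispatches the activation math to the
Triton kernels described below:

```python
import torch
import torch.nn.functional as F


class SwiGLUMLPFunction(torch.autograd.Function):
    """One autograd node for the gate/up projections, SwiGLU, and down projection."""

    @staticmethod
    def forward(ctx, x, w_gate, w_up, w_down):
        gate = x @ w_gate.t()
        up = x @ w_up.t()
        out = (F.silu(gate) * up) @ w_down.t()
        # Save only the raw projections; the activation is recomputed in backward.
        ctx.save_for_backward(x, w_gate, w_up, w_down, gate, up)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w_gate, w_up, w_down, gate, up = ctx.saved_tensors
        sig = torch.sigmoid(gate)
        silu = gate * sig
        h = silu * up  # recomputed activation output

        grad_h = grad_out @ w_down
        grad_w_down = grad_out.t() @ h
        grad_up = grad_h * silu
        grad_gate = grad_h * up * sig * (1 + gate * (1 - sig))  # d/dg silu(g) = sig*(1 + g*(1 - sig))

        grad_x = grad_gate @ w_gate + grad_up @ w_up
        grad_w_gate = grad_gate.t() @ x
        grad_w_up = grad_up.t() @ x
        return grad_x, grad_w_gate, grad_w_up, grad_w_down
```

The whole block then runs as one node in the graph via
`SwiGLUMLPFunction.apply(x, w_gate, w_up, w_down)`, rather than as a chain of separately
tracked matmuls and activations.
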
For attention components, similar optimizations are provided through `apply_qkv`, which
handles the query, key, and value projections, and `apply_o`, which handles the output
projection. They are designed to work with the existing `transformers` attention
implementation via some monkey-patching logic.

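Very roughly, that patching rewrites the attention module's `forward` at the source
level. The sketch below is a simplified illustration of the string-replace-and-exec
idea, not Axolotl's actual patch code; the helper name is hypothetical and
`LlamaAttention` is used only as an example target:

```python
import inspect
import textwrap

from transformers.models.llama import modeling_llama


def patch_llama_attention_forward(original_code: str, patched_code: str) -> None:
    """Rewrite LlamaAttention.forward by swapping a known code block for the fused call."""
    source = textwrap.dedent(inspect.getsource(modeling_llama.LlamaAttention.forward))
    if original_code not in source:
        raise RuntimeError("unexpected attention source; refusing to patch")

    patched_source = source.replace(original_code, patched_code)

    # Compile the rewritten function against the module's globals so its references
    # resolve, then rebind it on the class.
    namespace = modeling_llama.__dict__.copy()
    exec(patched_source, namespace)
    modeling_llama.LlamaAttention.forward = namespace["forward"]
```

This is also why support is limited to architectures whose attention code contains
exactly the expected blocks: if the source does not match, no patch is applied.
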
### Triton kernels

Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
improved speed and memory performance. These kernels handle both the forward and
backward passes.

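For a flavor of what these look like, here is a minimal SwiGLU forward kernel in the
same spirit. It is illustrative only: the kernel name, block size, and dtype handling
are assumptions, and Axolotl's kernels also implement the backward pass:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def swiglu_fwd_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    gate = tl.load(gate_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask).to(tl.float32)

    # SwiGLU: silu(gate) * up, fused so silu(gate) is never written to global memory
    result = gate * tl.sigmoid(gate) * up
    tl.store(out_ptr + offsets, result.to(out_ptr.dtype.element_ty), mask=mask)


def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Assumes contiguous, same-shape GPU tensors
    out = torch.empty_like(gate)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    swiglu_fwd_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)
    return out
```
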
### Integration

The custom autograd functions and Triton kernels are designed to work together. The
autograd function manages the high-level computation flow and gradient tracking, while
calling the Triton kernels for the activation function computation. During the backward
pass, the kernel computes both the activation output and the required gradients, which
the autograd function then uses to compute the final gradients for the entire
computation path.

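A toy stand-in for that backward-pass contract, in plain PyTorch (the actual kernel is
Triton and its name here is illustrative): it returns the recomputed activation output
together with the gradients with respect to `gate` and `up`, so the surrounding autograd
function never has to save the activation during the forward pass:

```python
import torch
import torch.nn.functional as F


def swiglu_backward_kernel(grad_h, gate, up):
    """Plain-PyTorch stand-in for the fused Triton backward kernel."""
    with torch.enable_grad():
        gate_r = gate.detach().requires_grad_(True)
        up_r = up.detach().requires_grad_(True)
        h = F.silu(gate_r) * up_r  # recomputed activation output
        grad_gate, grad_up = torch.autograd.grad(h, (gate_r, up_r), grad_h)
    return h.detach(), grad_gate, grad_up


# Inside the MLP backward, all three outputs feed the remaining matmul gradients, e.g.:
#   grad_h = grad_out @ w_down
#   h, grad_gate, grad_up = swiglu_backward_kernel(grad_h, gate, up)
#   grad_w_down = grad_out.t() @ h
```
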
## Future Work

- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions