Activation function Triton kernels, LoRA custom autograd functions (#2324)
* LoRA + activation fn Triton kernels: initial commit * implementing optims * finalizing MLP LoRA kernels and progress on QKV / W kernels * updates * O projection optim * adding monkey patching logic * doc strings, typing, pre-commit fixes * updates * adding lora 8b kernels example * working on fsdp support * tests and fixes * small fixes, getting tests to pass, adding doc strings * integration tests for LoRA patching * config.qmd * remove unneeded pytest fixture * fix * review comments first pass * improving tests, attention class agnostic patching * adding support for more archs * wip SiLU / GELU impls * improved testing, small updates, etc. * slightly updating docs * rebase * fixing test_attention_patching_integration * additional review comments, fixing test in CI (hopefully) * isolating problematic patching test * relaxing allclose threshold to reduce flakiness * fixing accidental change * adding model arch agnostic attention class fetching * removing unused activations
This commit is contained in:
@@ -305,6 +305,13 @@ lora_modules_to_save:
|
||||
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# Apply custom LoRA autograd functions and activation function Triton kernels for
|
||||
# speed and memory savings
|
||||
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
# LoRA+ hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
||||
|
||||
127
docs/lora_optims.qmd
Normal file
127
docs/lora_optims.qmd
Normal file
@@ -0,0 +1,127 @@
|
||||
---
|
||||
title: "LoRA Optimizations"
|
||||
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
|
||||
LoRA fine-tuning"
|
||||
---
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
||||
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
|
||||
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
|
||||
to leverage operator fusion and tensor re-use in order to improve speed and reduce
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
- `llama`
|
||||
- `mistral`
|
||||
- `qwen2`
|
||||
- `gemma`
|
||||
- `gemma2`
|
||||
|
||||
<details>
|
||||
|
||||
The set of models we support is currently limited by our attention patching strategy,
|
||||
which assumes (and replaces) specific code blocks for query / key / value and output
|
||||
projections:
|
||||
|
||||
```python
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
attn_output = self.o_proj(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
```
|
||||
|
||||
Is replaced with:
|
||||
|
||||
```python
|
||||
PATCHED_QKV_CODE = """
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
||||
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_O_CODE = """
|
||||
attn_output = self.apply_o(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
```
|
||||
|
||||
Where `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
|
||||
|
||||
We welcome testing of other model architectures and / or PRs to expand our patching
|
||||
logic to be compatible with more of them.
|
||||
|
||||
</details>
|
||||
|
||||
## Usage
|
||||
|
||||
These optimizations can be enabled in your Axolotl config YAML file. The
|
||||
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
|
||||
`lora_o_kernel` enable the fused query-key-value projection and optimized output
|
||||
projection, respectively.
|
||||
|
||||
```yaml
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
- AMD can be used with experimental Triton support by setting the environment variable `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
|
||||
- Targeted LoRA adapters cannot use Dropout
|
||||
- This may limit model expressivity / cause overfitting
|
||||
- Targeted LoRA adapters cannot have bias terms
|
||||
- This may limit model expressivity
|
||||
|
||||
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
|
||||
be re-finetuned without these features in order to be useful.
|
||||
|
||||
## Implementation details
|
||||
|
||||
### Custom autograd functions
|
||||
|
||||
The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
|
||||
LoRA and base weight computations together and provides a single, efficient backward
|
||||
pass for the entire MLP block.
|
||||
|
||||
For attention components, similar optimizations are provided through a function that
|
||||
handles the query, key, and value projections, and a function that handles the output
|
||||
projection. They are designed to work with the existing `transformers` attention
|
||||
implementation via some monkey-patching logic.
|
||||
|
||||
### Triton kernels
|
||||
|
||||
Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
|
||||
improved speed and memory performance. These kernels handle both the forward and
|
||||
backward passes.
|
||||
|
||||
### Integration
|
||||
|
||||
The custom autograd functions and Triton kernels are designed to work together. The
|
||||
autograd function manages the high-level computation flow and gradient tracking, while
|
||||
calling the Triton kernels for the activation function computation. During the backward
|
||||
pass, the kernel computes both the activation output and the required gradients, which
|
||||
the autograd function then uses to compute the final gradients for the entire
|
||||
computation path.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Support for additional model architectures
|
||||
- Support for the FSDP setting
|
||||
- Support for dropout and bias
|
||||
- Additional operator fusions
|
||||
Reference in New Issue
Block a user