Compare commits: `08fc7de87e...lora-fsdp2` (2 commits)

- 3299f182ba
- 2fc430d365
```diff
@@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
 Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
 optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
-(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
-Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
-to leverage operator fusion and tensor re-use in order to improve speed and reduce
-memory usage during the forward and backward passes of these calculations.
+(including DDP, DeepSpeed, and FSDP2) training. These include (1) SwiGLU and GEGLU
+activation function Triton kernels, and (2) LoRA MLP and attention custom autograd
+functions. Our goal was to leverage operator fusion and tensor re-use in order to
+improve speed and reduce memory usage during the forward and backward passes of these
+calculations.

 We currently support several common model architectures, including (but not limited to):
```
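For intuition on what the fused kernels buy: in eager PyTorch, a SwiGLU MLP materializes the intermediate activation tensor before the elementwise product, while a fused Triton kernel computes everything in a single pass over memory. A minimal sketch of the unfused computation (plain PyTorch for reference; the actual Triton kernel is not reproduced here):

```python
# Unfused SwiGLU reference (a sketch): eager PyTorch writes the
# intermediate silu(gate_out) tensor to memory before multiplying;
# the fused Triton kernel avoids that extra memory round trip.
import torch
import torch.nn.functional as F


def swiglu_unfused(gate_out: torch.Tensor, up_out: torch.Tensor) -> torch.Tensor:
    # gate_out / up_out are the outputs of the MLP's gate_proj / up_proj
    # linear layers (the same modules targeted by the LoRA configs below).
    return F.silu(gate_out) * up_out
```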
```diff
@@ -92,13 +93,12 @@ Currently, LoRA kernels are not supported for RLHF training, only SFT.

 - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
   - Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
-- Targeted LoRA adapters cannot use Dropout
-  - This may limit model expressivity / cause overfitting
-- Targeted LoRA adapters cannot have bias terms
+- Targeted LoRA adapters must disable dropout (`lora_dropout: 0`)
+  - This may limit model expressivity
+- Adapters that already include bias terms are supported.

-Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
-be re-finetuned without these features in order to be useful.
+Models with pre-existing LoRA adapters that use Dropout may need to be re-finetuned
+without it in order to be as performant.

 ## Implementation details
```
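In practice, an adapter config now only needs `lora_dropout: 0` to qualify for patching; bias terms no longer disqualify it. A sketch of a compatible PEFT config, mirroring the test configs later in this diff:

```python
# A LoRA config that satisfies the kernel requirements (a sketch
# mirroring this PR's test configs): dropout disabled, bias allowed.
from peft import get_peft_config

peft_config = get_peft_config(
    {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["gate_proj", "up_proj", "down_proj"],
        "lora_dropout": 0,    # required: kernel patching is skipped otherwise
        "bias": "lora_only",  # no longer blocks patching after this change
    }
)
```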
```diff
@@ -131,6 +131,5 @@ computation path.
 ## Future Work

 - Support for additional model architectures
-- Support for the FSDP setting
-- Support for dropout and bias
+- Support for dropout
 - Additional operator fusions
```
```diff
@@ -323,8 +323,8 @@ def apply_lora_kernel_patches(
         AssertionError: If multiple adapters are active (currently unsupported).

     Note:
-        The optimizations require LoRA adapters with no dropout and no bias terms. The
-        function will skip patching if these conditions aren't met.
+        The optimizations require LoRA adapters with no dropout. The function will skip
+        patching if that condition isn't met.
     """
     if not isinstance(model, PeftModelForCausalLM):
         raise TypeError("Model must be a PeftModelForCausalLM")
```
```diff
@@ -340,10 +340,10 @@ def apply_lora_kernel_patches(
     lora_config = model.model.peft_config[active_adapter]

     # Only patch if conditions are met
-    can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
+    can_patch = lora_config.lora_dropout == 0

     if not can_patch:
-        LOG.warning("Cannot patch layers - requires no dropout and no bias")
+        LOG.warning("Cannot patch layers - requires `lora_dropout: 0`")
         LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
         return model
```
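The behavioral difference is easiest to see side by side. A minimal sketch (the `SimpleNamespace` stand-in for the real LoraConfig object is purely illustrative):

```python
# Old vs. new gating condition, side by side (illustrative only;
# SimpleNamespace stands in for the real PEFT LoraConfig).
from types import SimpleNamespace

lora_config = SimpleNamespace(lora_dropout=0, bias="lora_only")

old_can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
new_can_patch = lora_config.lora_dropout == 0

print(old_can_patch)  # False: bias != "none" used to block patching
print(new_can_patch)  # True: only dropout is checked now
```

The test changes below exercise both sides of the new condition: dropout still prevents patching, while bias no longer does.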
```diff
@@ -221,44 +221,53 @@ def test_model_specific_activation(model_name, expected_activation):
     assert layer.mlp.forward.__func__ is expected_activation


-def test_kernel_patch_conditions():
-    """Test various conditions that should prevent kernel patching."""
-    test_configs = [
-        # Dropout prevents patching
-        {
-            "peft_type": "LORA",
-            "task_type": "CAUSAL_LM",
-            "r": 8,
-            "lora_alpha": 16,
-            "target_modules": ["gate_proj", "up_proj", "down_proj"],
-            "lora_dropout": 0.1,
-            "bias": "none",
-        },
-        # Bias prevents patching
-        {
-            "peft_type": "LORA",
-            "task_type": "CAUSAL_LM",
-            "r": 8,
-            "lora_alpha": 16,
-            "target_modules": ["gate_proj", "up_proj", "down_proj"],
-            "lora_dropout": 0,
-            "bias": "lora_only",
-        },
-    ]
+def test_kernel_patch_requires_zero_dropout():
+    """Kernel patching should be skipped when dropout is enabled."""
+    config = {
+        "peft_type": "LORA",
+        "task_type": "CAUSAL_LM",
+        "r": 8,
+        "lora_alpha": 16,
+        "target_modules": ["gate_proj", "up_proj", "down_proj"],
+        "lora_dropout": 0.1,
+        "bias": "none",
+    }

-    for config in test_configs:
-        model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
-        peft_config = get_peft_config(config)
-        model = PeftModelForCausalLM(model, peft_config)
-        cfg = DictDefault({"lora_mlp_kernel": True})
+    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+    peft_config = get_peft_config(config)
+    model = PeftModelForCausalLM(model, peft_config)
+    cfg = DictDefault({"lora_mlp_kernel": True})

-        # Should not patch
-        patched_model = apply_lora_kernel_patches(model, cfg)
-        layer = patched_model.model.model.layers[0].mlp
+    patched_model = apply_lora_kernel_patches(model, cfg)
+    layer = patched_model.model.model.layers[0].mlp

-        # Verify no patches applied
-        assert layer.forward.__func__ is not apply_lora_mlp_swiglu
-        assert layer.forward.__func__ is not apply_lora_mlp_geglu
+    # Verify no patches applied when dropout is non-zero
+    assert layer.forward.__func__ is not apply_lora_mlp_swiglu
+    assert layer.forward.__func__ is not apply_lora_mlp_geglu
+
+
+def test_kernel_patch_with_bias_enabled():
+    """Kernel patching should succeed when LoRA bias is enabled."""
+    config = {
+        "peft_type": "LORA",
+        "task_type": "CAUSAL_LM",
+        "r": 8,
+        "lora_alpha": 16,
+        "target_modules": ["gate_proj", "up_proj", "down_proj"],
+        "lora_dropout": 0,
+        "bias": "lora_only",
+    }
+
+    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+    peft_config = get_peft_config(config)
+    model = PeftModelForCausalLM(model, peft_config)
+    cfg = DictDefault({"lora_mlp_kernel": True})
+
+    patched_model = apply_lora_kernel_patches(model, cfg)
+    layer = patched_model.model.model.layers[0].mlp
+
+    # Verify patches applied when bias support is enabled
+    assert layer.forward.__func__ is apply_lora_mlp_swiglu


 def test_kernel_config_options():
```
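The two new tests could equivalently be expressed as a single parametrized test with the pass/fail expectation explicit per case. A sketch (the axolotl import paths below are assumptions; everything else mirrors the tests above):

```python
# Parametrized equivalent of the two tests above (a sketch; the
# axolotl module paths are assumptions, not confirmed by this PR).
import pytest
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.lora_kernels import (
    apply_lora_kernel_patches,
    apply_lora_mlp_swiglu,
)
from axolotl.utils.dict import DictDefault


@pytest.mark.parametrize(
    "lora_dropout, bias, should_patch",
    [
        (0.1, "none", False),    # dropout still disables patching
        (0, "lora_only", True),  # bias alone no longer does
    ],
)
def test_kernel_patch_conditions(lora_dropout, bias, should_patch):
    config = {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["gate_proj", "up_proj", "down_proj"],
        "lora_dropout": lora_dropout,
        "bias": bias,
    }
    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    model = PeftModelForCausalLM(model, get_peft_config(config))
    cfg = DictDefault({"lora_mlp_kernel": True})

    patched_model = apply_lora_kernel_patches(model, cfg)
    layer = patched_model.model.model.layers[0].mlp

    # SmolLM2 uses a SwiGLU MLP, so patching (or not) is visible on
    # whether forward was swapped for the fused SwiGLU autograd function.
    assert (layer.forward.__func__ is apply_lora_mlp_swiglu) == should_patch
```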