Compare commits
2 Commits
fix/diffus
...
lora-fsdp2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3299f182ba | ||
|
|
2fc430d365 |
@@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
|
|||||||
|
|
||||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||||
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
||||||
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
|
(including DDP, DeepSpeed, and FSDP2) training. These include (1) SwiGLU and GEGLU
|
||||||
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
|
activation function Triton kernels, and (2) LoRA MLP and attention custom autograd
|
||||||
to leverage operator fusion and tensor re-use in order to improve speed and reduce
|
functions. Our goal was to leverage operator fusion and tensor re-use in order to
|
||||||
memory usage during the forward and backward passes of these calculations.
|
improve speed and reduce memory usage during the forward and backward passes of these
|
||||||
|
calculations.
|
||||||
|
|
||||||
We currently support several common model architectures, including (but not limited to):
|
We currently support several common model architectures, including (but not limited to):
|
||||||
|
|
||||||
@@ -92,13 +93,12 @@ Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
|||||||
|
|
||||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||||
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
|
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
|
||||||
- Targeted LoRA adapters cannot use Dropout
|
- Targeted LoRA adapters must disable dropout (`lora_dropout: 0`)
|
||||||
- This may limit model expressivity / cause overfitting
|
|
||||||
- Targeted LoRA adapters cannot have bias terms
|
|
||||||
- This may limit model expressivity
|
- This may limit model expressivity
|
||||||
|
- Adapters that already include bias terms are supported.
|
||||||
|
|
||||||
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
|
Models with pre-existing LoRA adapters that use Dropout may need to be re-finetuned
|
||||||
be re-finetuned without these features in order to be useful.
|
without it in order to be as performant.
|
||||||
|
|
||||||
## Implementation details
|
## Implementation details
|
||||||
|
|
||||||
@@ -131,6 +131,5 @@ computation path.
|
|||||||
## Future Work
|
## Future Work
|
||||||
|
|
||||||
- Support for additional model architectures
|
- Support for additional model architectures
|
||||||
- Support for the FSDP setting
|
- Support for dropout
|
||||||
- Support for dropout and bias
|
|
||||||
- Additional operator fusions
|
- Additional operator fusions
|
||||||
|
|||||||
@@ -323,8 +323,8 @@ def apply_lora_kernel_patches(
|
|||||||
AssertionError: If multiple adapters are active (currently unsupported).
|
AssertionError: If multiple adapters are active (currently unsupported).
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
The optimizations require LoRA adapters with no dropout and no bias terms. The
|
The optimizations require LoRA adapters with no dropout. The function will skip
|
||||||
function will skip patching if these conditions aren't met.
|
patching if that condition isn't met.
|
||||||
"""
|
"""
|
||||||
if not isinstance(model, PeftModelForCausalLM):
|
if not isinstance(model, PeftModelForCausalLM):
|
||||||
raise TypeError("Model must be a PeftModelForCausalLM")
|
raise TypeError("Model must be a PeftModelForCausalLM")
|
||||||
@@ -340,10 +340,10 @@ def apply_lora_kernel_patches(
|
|||||||
lora_config = model.model.peft_config[active_adapter]
|
lora_config = model.model.peft_config[active_adapter]
|
||||||
|
|
||||||
# Only patch if conditions are met
|
# Only patch if conditions are met
|
||||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
can_patch = lora_config.lora_dropout == 0
|
||||||
|
|
||||||
if not can_patch:
|
if not can_patch:
|
||||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
LOG.warning("Cannot patch layers - requires `lora_dropout: 0`")
|
||||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -221,44 +221,53 @@ def test_model_specific_activation(model_name, expected_activation):
|
|||||||
assert layer.mlp.forward.__func__ is expected_activation
|
assert layer.mlp.forward.__func__ is expected_activation
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_patch_conditions():
|
def test_kernel_patch_requires_zero_dropout():
|
||||||
"""Test various conditions that should prevent kernel patching."""
|
"""Kernel patching should be skipped when dropout is enabled."""
|
||||||
test_configs = [
|
config = {
|
||||||
# Dropout prevents patching
|
"peft_type": "LORA",
|
||||||
{
|
"task_type": "CAUSAL_LM",
|
||||||
"peft_type": "LORA",
|
"r": 8,
|
||||||
"task_type": "CAUSAL_LM",
|
"lora_alpha": 16,
|
||||||
"r": 8,
|
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
||||||
"lora_alpha": 16,
|
"lora_dropout": 0.1,
|
||||||
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
"bias": "none",
|
||||||
"lora_dropout": 0.1,
|
}
|
||||||
"bias": "none",
|
|
||||||
},
|
|
||||||
# Bias prevents patching
|
|
||||||
{
|
|
||||||
"peft_type": "LORA",
|
|
||||||
"task_type": "CAUSAL_LM",
|
|
||||||
"r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
|
||||||
"lora_dropout": 0,
|
|
||||||
"bias": "lora_only",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
for config in test_configs:
|
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||||
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
peft_config = get_peft_config(config)
|
||||||
peft_config = get_peft_config(config)
|
model = PeftModelForCausalLM(model, peft_config)
|
||||||
model = PeftModelForCausalLM(model, peft_config)
|
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
|
||||||
|
|
||||||
# Should not patch
|
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
layer = patched_model.model.model.layers[0].mlp
|
||||||
layer = patched_model.model.model.layers[0].mlp
|
|
||||||
|
|
||||||
# Verify no patches applied
|
# Verify no patches applied when dropout is non-zero
|
||||||
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
||||||
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_patch_with_bias_enabled():
|
||||||
|
"""Kernel patching should succeed when LoRA bias is enabled."""
|
||||||
|
config = {
|
||||||
|
"peft_type": "LORA",
|
||||||
|
"task_type": "CAUSAL_LM",
|
||||||
|
"r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"target_modules": ["gate_proj", "up_proj", "down_proj"],
|
||||||
|
"lora_dropout": 0,
|
||||||
|
"bias": "lora_only",
|
||||||
|
}
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||||
|
peft_config = get_peft_config(config)
|
||||||
|
model = PeftModelForCausalLM(model, peft_config)
|
||||||
|
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||||
|
|
||||||
|
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||||
|
layer = patched_model.model.model.layers[0].mlp
|
||||||
|
|
||||||
|
# Verify patches applied when bias support is enabled
|
||||||
|
assert layer.forward.__func__ is apply_lora_mlp_swiglu
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_config_options():
|
def test_kernel_config_options():
|
||||||
|
|||||||
Reference in New Issue
Block a user