chore: cleanup post release v0.16 (#3577)
* fix: remove unneeded debug log
* fix: cleanup
* feat: add dense gemma config and cleanup
* feat: add cce support
* update notes and set torch compile
* fix patch for new number of return vals
* fixes for gemma4
* fix packing bug
* use updated cce for mm
* fix: pass in kv cache func when avail for transformers 5.5
* feat: update examples with flex variant and readme
* gemma4 lora attention kernels

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
},
{

@@ -1,19 +1,12 @@
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
#
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
#
# Key notes:
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
#   Use sdp_attention instead.
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
#   LoRA to the text backbone via lora_target_linear_modules regex.
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
#   directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
#   via the transformers ExpertsInterface.
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
#   (no `mlp.` prefix, unlike Qwen/Mixtral).
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
#   Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
#   4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix.
#   Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.

base_model: google/gemma-4-26B-A4B

@@ -24,7 +17,7 @@ plugins:
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
torch_compile: false
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true

@@ -54,12 +47,9 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0

# Restrict LoRA to text backbone only (skip vision/audio encoders).
# lora_target_modules is intentionally empty — all module targeting is done
# via regex in lora_target_linear_modules below.
lora_target_modules: []
lora_target_linear_modules:
  - language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using a regex to match only the text decoder attention and MLP projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
lora_target_parameters:

@@ -73,7 +63,7 @@ lora_o_kernel: false
bnb_config_kwargs:
  bnb_4bit_use_double_quant: true

wandb_project: gemma4-qlora
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:

@@ -93,8 +83,7 @@ gradient_checkpointing: true
activation_offloading: true
logging_steps: 1

# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
flash_attention: false
# FA2 not supported
sdp_attention: true

warmup_ratio: 0.1

examples/gemma4/31b-qlora-flex.yaml (new file, 71 lines)
@@ -0,0 +1,71 @@
base_model: google/gemma-4-31B

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.liger.LigerPlugin
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false

chat_template: gemma4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:10%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora-flex

sequence_len: 2048
sample_packing: true

load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0

# Restrict LoRA to text backbone only (skip vision/audio encoders)
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

bnb_config_kwargs:
  bnb_4bit_use_double_quant: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
activation_offloading: true
logging_steps: 1

# FA not supported
flex_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/gemma4/31b-qlora.yaml (new file, 69 lines)
@@ -0,0 +1,69 @@
base_model: google/gemma-4-31B

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.liger.LigerPlugin
torch_compile: false
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false

chat_template: gemma4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:10%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora

sequence_len: 2048
sample_packing: true

load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0

# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using a regex to match only the text decoder attention and MLP projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

bnb_config_kwargs:
  bnb_4bit_use_double_quant: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
activation_offloading: true
logging_steps: 1

# FA not supported
sdp_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/gemma4/README.md (new file, 60 lines)
@@ -0,0 +1,60 @@
# Finetune Google's Gemma 4 with Axolotl

[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

3. Run the finetuning example:

```bash
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml

# 31B Dense QLoRA (1x80GB @ ~44 GiB)
axolotl train examples/gemma4/31b-qlora.yaml

# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
axolotl train examples/gemma4/31b-qlora-flex.yaml
```

### MoE Expert Quantization & Expert LoRA (26B-A4B only)

The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
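
For reference, here is a minimal sketch of the expert-related settings (key names taken from `examples/gemma4/26b-a4b-moe-qlora.yaml` in this repo; not a complete config on its own):

```yaml
# ScatterMoE kernels are registered through the transformers ExpertsInterface
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe

# Expert weights are quantized on load
quantize_moe_experts: true

# Experts are 3D Parameter tensors rather than nn.Linear, so they are targeted
# via lora_target_parameters (note: no `mlp.` prefix, unlike Qwen/Mixtral)
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```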

## Flex Attention

You can reduce VRAM usage by roughly 40% (at the cost of up to half the throughput) by setting the options below (shown in `examples/gemma4/31b-qlora-flex.yaml`):

```yaml
torch_compile: true
flex_attention: true
```

This works for both the MoE and dense models.

## Limitations

- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
- **LoRA kernels**: Not supported due to KV-sharing layers.
- **lora_target_linear**: Incompatible with multimodal models. Use `lora_target_modules` with a regex to restrict LoRA to the text backbone, as in the sketch below.
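
A minimal sketch of that restriction, using the same regex as the example configs in this directory:

```yaml
# Match only the text decoder's attention and MLP projections; vision and audio
# encoder modules are left untouched.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```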

### Tips

- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy and has not been tested.

## Optimization Guides

Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

## Related Resources

- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

print(
    UNINSTALL_PREFIX
    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"'
)

@@ -100,6 +100,27 @@ class AxolotlTrainer(
        self._signature_columns = None  # workaround for pylint

        super().__init__(*_args, **kwargs)

        # Gemma4 (and similar multimodal models) declare **kwargs in forward() for
        # extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
        # "the model handles num_items_in_batch internally" and skips the loss ÷
        # gradient_accumulation_steps normalisation, which inflates the *logged* loss
        # (the gradient itself is still correct). Override to False when the model
        # doesn't actually consume num_items_in_batch.
        if self.model_accepts_loss_kwargs:
            model_to_check = self.accelerator.unwrap_model(self.model)
            if hasattr(model_to_check, "base_model"):  # PEFT wrapper
                model_to_check = model_to_check.base_model
            if hasattr(model_to_check, "model"):
                model_to_check = model_to_check.model
            fwd = getattr(model_to_check, "forward", None)
            if fwd is not None:
                import inspect

                params = inspect.signature(fwd).parameters
                if "num_items_in_batch" not in params:
                    self.model_accepts_loss_kwargs = False

        self.train_data_collator = self.data_collator
        self._stored_metrics = defaultdict(
            lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})

@@ -383,13 +404,27 @@ class AxolotlTrainer(

        # Gemma4 requires mm_token_type_ids during training (even for text-only).
        # Inject zeros (= text token type) when not provided by the data collator.
        _model_type = getattr(getattr(model, "config", None), "model_type", None)
        if (
            "mm_token_type_ids" not in inputs
            and "input_ids" in inputs
            and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
            and _model_type == "gemma4"
        ):
            inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])

        # Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
        # from position_ids, but only when attention_mask is None. When sample
        # packing is active the collator provides an all-ones attention_mask that
        # prevents this detection — remove it so the model builds the correct
        # per-sequence causal masks.
        if (
            self.args.sample_packing
            and _model_type in ("gemma4", "gemma3")
            and "attention_mask" in inputs
            and "position_ids" in inputs
        ):
            del inputs["attention_mask"]

        if self.args.orpo_alpha:
            return self.orpo_compute_loss(
                model,

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

- If you are installing from pip
  ```bash
  pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
  pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
  ```

## Usage

@@ -44,6 +44,7 @@ plugins:
- gemma3_text
- gemma3n
- gemma3n_text
- gemma4
- glm
- glm4
- glm4_moe

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)

_CCE_INSTALL_MESSAGE = (
    "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
)

@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:

Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.

**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).

## Limitations

- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).

@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):

        return data

    @model_validator(mode="before")
    @classmethod
    def warn_sonicmoe_lora_overhead(cls, data):
        if data.get("use_sonicmoe") is True and data.get("adapter") in (
            "lora",
            "qlora",
        ):
            lora_target = data.get("lora_target_modules") or []
            lora_linear = data.get("lora_target_linear_modules") or []
            targets = (
                lora_target if isinstance(lora_target, list) else [lora_target]
            ) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
            expert_keywords = ("gate_up_proj", "down_proj", "experts")
            if any(kw in t for t in targets for kw in expert_keywords):
                LOG.info(
                    "SonicMoE + LoRA on expert modules uses runtime weight materialization "
                    "(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
                    "than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def disable_mlp_kernel(cls, data):

@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
    return Q, K, V


class LoRA_QK(torch.autograd.Function):
    """Optimized LoRA QK implementation for models where v_proj is None.

    Used by models like Gemma4 with attention_k_eq_v=True, where key states are
    reused as value states. Only Q and K projections are fused; the caller
    returns K a second time as V so that autograd accumulates key+value gradients
    into a single dK.

    Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
    """

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        X: torch.Tensor,
        X_drop: torch.Tensor | None,
        # Q params
        q_weight: torch.Tensor,
        q_bias: torch.Tensor | None,
        q_quant: QuantState | None,
        q_A: torch.Tensor | None,
        q_B: torch.Tensor | None,
        q_scale: float,
        q_lora_bias: torch.Tensor | None,
        q_magnitude: torch.Tensor | None,
        # K params
        k_weight: torch.Tensor,
        k_bias: torch.Tensor | None,
        k_quant: QuantState | None,
        k_A: torch.Tensor | None,
        k_B: torch.Tensor | None,
        k_scale: float,
        k_lora_bias: torch.Tensor | None,
        k_magnitude: torch.Tensor | None,
        # Flags
        inplace: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        has_dropout = X_drop is not None
        has_dora = q_magnitude is not None

        if has_dora:
            dtype = X.dtype
            X_lora = X_drop if has_dropout else X

            # Compute Q with DoRA
            Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
            Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
            q_mag_scale = _compute_dora_scale(
                q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
            )
            Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
            if q_bias is not None:
                Q = Q + q_bias

            # Compute K with DoRA
            K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
            K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
            k_mag_scale = _compute_dora_scale(
                k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
            )
            K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
            if k_bias is not None:
                K = K + k_bias

            Q_combined = Q_base + Q_lora
            K_combined = K_base + K_lora

            ctx.save_for_backward(
                X,
                X_drop if has_dropout else X,
                q_A.to(dtype) if q_A is not None else q_A,
                q_B.to(dtype) if q_B is not None else q_B,
                k_A.to(dtype) if k_A is not None else k_A,
                k_B.to(dtype) if k_B is not None else k_B,
                q_magnitude,
                k_magnitude,
                q_mag_scale,
                k_mag_scale,
                Q_combined,
                K_combined,
                q_lora_bias,
                k_lora_bias,
            )
        else:
            # Standard LoRA (with optional dropout and bias)
            Q = matmul_lora(
                X,
                q_weight,
                q_bias,
                q_quant,
                q_A,
                q_B,
                q_scale,
                X_drop=X_drop,
                lora_bias=q_lora_bias,
            )
            K = matmul_lora(
                X,
                k_weight,
                k_bias,
                k_quant,
                k_A,
                k_B,
                k_scale,
                X_drop=X_drop,
                lora_bias=k_lora_bias,
            )

            dtype = X.dtype
            ctx.save_for_backward(
                X,
                X_drop if has_dropout else X,
                q_A.to(dtype) if q_A is not None else q_A,
                q_B.to(dtype) if q_B is not None else q_B,
                k_A.to(dtype) if k_A is not None else k_A,
                k_B.to(dtype) if k_B is not None else k_B,
                q_lora_bias,
                k_lora_bias,
            )

        ctx.scales = (q_scale, k_scale)
        ctx.quants = (q_quant, k_quant)
        ctx.weights = (q_weight, k_weight)
        ctx.inplace = inplace
        ctx.has_dropout = has_dropout
        ctx.has_dora = has_dora

        return Q, K

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        q_grad: torch.Tensor,
        k_grad: torch.Tensor,
    ):
        q_weight, k_weight = ctx.weights
        q_quant, k_quant = ctx.quants
        q_scale, k_scale = ctx.scales
        has_dropout = ctx.has_dropout
        has_dora = ctx.has_dora

        if has_dora:
            (
                X,
                X_lora,
                A_q,
                B_q,
                A_k,
                B_k,
                q_magnitude,
                k_magnitude,
                q_mag_scale,
                k_mag_scale,
                Q_combined,
                K_combined,
                q_lora_bias,
                k_lora_bias,
            ) = ctx.saved_tensors
        else:
            (
                X,
                X_lora,
                A_q,
                B_q,
                A_k,
                B_k,
                q_lora_bias,
                k_lora_bias,
            ) = ctx.saved_tensors
            q_magnitude = k_magnitude = None
            q_mag_scale = k_mag_scale = None
            Q_combined = K_combined = None

        batch, seq_len = X.shape[:2]
        q_grad = q_grad.view(-1, q_grad.shape[-1])
        k_grad = k_grad.reshape(-1, k_grad.shape[-1])
        X = X.view(-1, X.shape[-1])
        X_lora = X_lora.view(-1, X_lora.shape[-1])

        d_q_mag = d_k_mag = None
        d_q_lora_bias = d_k_lora_bias = None

        if has_dora:
            Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
            K_combined = K_combined.view(-1, K_combined.shape[-1])

            d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
            d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude

            q_grad = q_grad * q_mag_scale.unsqueeze(0)
            k_grad = k_grad * k_mag_scale.unsqueeze(0)

        # LoRA bias gradients
        if q_lora_bias is not None:
            d_q_lora_bias = q_scale * q_grad.sum(dim=0)
        if k_lora_bias is not None:
            d_k_lora_bias = k_scale * k_grad.sum(dim=0)

        X_lora_t = X_lora.t()

        d_A_q = d_B_q = d_A_k = d_B_k = None
        grad_B_q = grad_B_k = None

        if A_q is not None and B_q is not None:
            grad_B_q = q_grad @ B_q
            d_A_q = torch.empty_like(A_q.t())
            d_B_q = torch.empty_like(B_q.t())
            d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
            d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)

        if A_k is not None and B_k is not None:
            grad_B_k = k_grad @ B_k
            d_A_k = torch.empty_like(A_k.t())
            d_B_k = torch.empty_like(B_k.t())
            d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
            d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)

        # Base path input gradient
        out_buffer = X if ctx.inplace else None

        q_weight_t = dequantize(q_weight, q_quant)
        grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
        del q_weight_t

        k_weight_t = dequantize(k_weight, k_quant)
        grad_X.addmm_(k_grad, k_weight_t)
        del k_weight_t

        # LoRA path input gradient
        if has_dropout:
            grad_X_drop = torch.zeros_like(X_lora)
            if grad_B_q is not None:
                grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
            if grad_B_k is not None:
                grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
        else:
            grad_X_drop = None
            if grad_B_q is not None:
                grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
            if grad_B_k is not None:
                grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)

        if d_A_q is not None:
            d_A_q = d_A_q.t()
            d_B_q = d_B_q.t()  # type: ignore[union-attr]
        if d_A_k is not None:
            d_A_k = d_A_k.t()
            d_B_k = d_B_k.t()  # type: ignore[union-attr]

        grad_X = grad_X.view(batch, seq_len, -1)
        if grad_X_drop is not None:
            grad_X_drop = grad_X_drop.view(batch, seq_len, -1)

        # Return gradients for all forward inputs:
        # X, X_drop,
        # q: weight, bias, quant, A, B, scale, lora_bias, magnitude
        # k: weight, bias, quant, A, B, scale, lora_bias, magnitude
        # inplace
        return (
            grad_X,
            grad_X_drop,
            # Q
            None,
            None,
            None,
            d_A_q,
            d_B_q,
            None,
            d_q_lora_bias,
            d_q_mag,
            # K
            None,
            None,
            None,
            d_A_k,
            d_B_k,
            None,
            d_k_lora_bias,
            d_k_mag,
            # inplace
            None,
        )


def apply_lora_qk(
    self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies LoRA to compute Query and Key projections for models where v_proj is None.

    When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
    value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
    Because K is returned twice, autograd accumulates gradients from both the key and
    value paths into dK before calling LoRA_QK.backward.

    Supports bias, dropout, and DoRA.
    """
    QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
    KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)

    # Apply dropout outside autograd.Function (shared mask for Q, K)
    X_drop = _apply_dropout(Qdrop, X, self.training)

    Q, K = LoRA_QK.apply(
        X,
        X_drop,
        # Q
        QW,
        Qb,
        QW_quant,
        QA,
        QB,
        QS,
        Qlb,
        Qmag,
        # K
        KW,
        Kb,
        KW_quant,
        KA,
        KB,
        KS,
        Klb,
        Kmag,
        # Flags
        inplace,
    )

    return Q, K, K


class LoRA_O(torch.autograd.Function):
    """Optimized LoRA implementation for output projection.

@@ -67,12 +67,70 @@ def find_all_linear_names(model):
    return list(lora_module_names)


def _patch_peft_clippable_linear():
    """Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.

    Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
    that clips activations). PEFT doesn't recognise it as a supported layer type,
    so we redirect LoRA injection to the inner ``.linear`` child instead.
    """
    try:
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4ClippableLinear as _cls,
        )
    except ImportError:
        return

    from peft.tuners.lora.model import LoraModel

    if getattr(LoraModel, "_axolotl_clippable_patched", False):
        return
    _orig = LoraModel._create_and_replace

    def _patched(
        self,
        peft_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key=None,
        **kw,
    ):
        if isinstance(target, _cls):
            # Redirect to the inner nn.Linear so PEFT can wrap it normally.
            return _orig(
                self,
                peft_config,
                adapter_name,
                target.linear,
                "linear",
                target,
                current_key=current_key,
                **kw,
            )
        return _orig(
            self,
            peft_config,
            adapter_name,
            target,
            target_name,
            parent,
            current_key=current_key,
            **kw,
        )

    LoraModel._create_and_replace = _patched
    LoraModel._axolotl_clippable_patched = True


def load_lora(
    model: PreTrainedModel,
    cfg: DictDefault,
    inference: bool = False,
    config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
    _patch_peft_clippable_linear()
    lora_target_modules = cfg.lora_target_modules or []
    lora_target_parameters = cfg.lora_target_parameters or []

@@ -124,6 +182,7 @@ def load_lora(
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
        exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
        bias="none",
        task_type=task_type,
        **lora_config_kwargs,

@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
    if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
        return

    try:
        # flash-attn-4>=4.0.0b7
        from flash_attn.cute import flash_attn_with_kvcache
    except ImportError:
        flash_attn_with_kvcache = None

    def _patched_lazy_imports(
        implementation, attention_wrapper=None, allow_all_kernels=False
    ):
        return (
            flash_attn_func,
            flash_attn_varlen_func,
            flash_attn_with_kvcache,
            fa_utils._pad_input,
            fa_utils._unpad_input,
        )

@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
    apply_lora_mlp_geglu,
    apply_lora_mlp_swiglu,
    apply_lora_o,
    apply_lora_qk,
    apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code

@@ -483,18 +484,24 @@ def apply_lora_kernel_patches(
        if cfg.lora_qkv_kernel:
            # Query, key, value patching
            # Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
            proj_names = ["q_proj", "k_proj", "v_proj"]
            layer_modules = [
                getattr(self_attn, name)
                for name in proj_names
                if getattr(self_attn, name, None) is not None
            ]
            has_v_proj = getattr(self_attn, "v_proj", None) is not None
            proj_names = (
                ["q_proj", "k_proj", "v_proj"]
                if has_v_proj
                else ["q_proj", "k_proj"]
            )
            layer_modules = [getattr(self_attn, name) for name in proj_names]
            can_patch_qkv = all(
                hasattr(module, "lora_A") for module in layer_modules
            )

            if can_patch_qkv:
                self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
                if has_v_proj:
                    self_attn.apply_qkv = types.MethodType(
                        apply_lora_qkv, self_attn
                    )
                else:
                    self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
            else:
                LOG.warning_once(
                    "Cannot patch some attention QKV projections - requires LoRA adapters"