chore: cleanup post release v0.16 (#3577)

* fix: remove unneeded debug log

* fix: cleanup

* feat: add dense gemma config and cleanup

* feat: add cce support

* update notes and set torch compile

* fix patch for new number of return vals

* fixes for gemma4

* fix packing bug

* use updated cce for mm

* fix: pass in kv cache func when avail for transformers 5.5

* feat: update examples with flex variant and readme

* gemma4 lora attention kernels

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2026-04-07 00:10:52 +07:00
committed by GitHub
parent dc638e723f
commit 149178ddb7
15 changed files with 665 additions and 60 deletions

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
},
{

View File

@@ -1,19 +1,12 @@
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
#
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
#
# Key notes:
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
# Use sdp_attention instead.
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
# LoRA to the text backbone via lora_target_linear_modules regex.
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
# via the transformers ExpertsInterface.
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
# (no `mlp.` prefix, unlike Qwen/Mixtral).
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
# 4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix.
# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.
base_model: google/gemma-4-26B-A4B
@@ -24,7 +17,7 @@ plugins:
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
torch_compile: false
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
@@ -54,12 +47,9 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders).
# lora_target_modules is intentionally empty — all module targeting is done
# via regex in lora_target_linear_modules below.
lora_target_modules: []
lora_target_linear_modules:
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using a regex that matches the text decoder attention and MLP projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
lora_target_parameters:
@@ -73,7 +63,7 @@ lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project: gemma4-qlora
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
@@ -93,8 +83,7 @@ gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
flash_attention: false
# FA2 not supported
sdp_attention: true
warmup_ratio: 0.1

View File

@@ -0,0 +1,71 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora-flex
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
flex_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,69 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: false
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using a regex that matches the text decoder attention and MLP projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/gemma4/README.md (new file, 60 lines)
View File

@@ -0,0 +1,60 @@
# Finetune Google's Gemma 4 with Axolotl
[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage (the pinned install command is shown below the examples).
3. Run the finetuning example:
```bash
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
# 31B Dense QLoRA (1x80GB @ ~44 GiB)
axolotl train examples/gemma4/31b-qlora.yaml
# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
axolotl train examples/gemma4/31b-qlora-flex.yaml
```
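Step 2 refers to the Cut Cross Entropy integration; for convenience, the pinned command from the CCE integration docs updated in this commit is:
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
```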
### MoE Expert Quantization & Expert LoRA (26B-A4B only)
The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
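For quick reference, a sketch of the expert-related keys from the 26B-A4B example above. The exact `lora_target_parameters` values are truncated in this diff, so the two target strings below are assumptions based on the config comments:
```yaml
# ScatterMoE kernels for the Gemma 4 experts (registered via the transformers ExpertsInterface)
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
quantize_moe_experts: true   # quantize expert weights on load

# Expert LoRA targets are 3D Parameter tensors, not nn.Linear modules,
# so they go under lora_target_parameters (no `mlp.` prefix, unlike Qwen/Mixtral)
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```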
## Flex Attention
Reduce VRAM usage by ~40% (at the cost of up to half the throughput) by setting the options below, as shown in `examples/gemma4/31b-qlora-flex.yaml`:
```yaml
torch_compile: true
flex_attention: true
```
This works for both the MoE and dense models.
## Limitations
- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
- **LoRA kernels**: Not supported due to KV-sharing layers.
- **lora_target_linear**: Not supported for multimodal models; use `lora_target_modules` with a regex to restrict LoRA to the text backbone (see the snippet below).
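For quick reference, the settings the example configs in this directory use to work around these limitations:
```yaml
# SDP (or flex) attention instead of Flash Attention
flash_attention: false
sdp_attention: true        # or: flex_attention: true (with torch_compile: true)

# Regex targeting that restricts LoRA to the text decoder projections
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```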
### Tips
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. Note that full finetuning requires far more VRAM and has not been tested.
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"'
)

View File

@@ -100,6 +100,27 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
# Gemma4 (and similar multimodal models) declare **kwargs in forward() for
# extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
# "the model handles num_items_in_batch internally" and skips the loss ÷
# gradient_accumulation_steps normalisation, which inflates the *logged* loss
# (the gradient itself is still correct). Override to False when the model
# doesn't actually consume num_items_in_batch.
if self.model_accepts_loss_kwargs:
model_to_check = self.accelerator.unwrap_model(self.model)
if hasattr(model_to_check, "base_model"): # PEFT wrapper
model_to_check = model_to_check.base_model
if hasattr(model_to_check, "model"):
model_to_check = model_to_check.model
fwd = getattr(model_to_check, "forward", None)
if fwd is not None:
import inspect
params = inspect.signature(fwd).parameters
if "num_items_in_batch" not in params:
self.model_accepts_loss_kwargs = False
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
@@ -383,13 +404,27 @@ class AxolotlTrainer(
# Gemma4 requires mm_token_type_ids during training (even for text-only).
# Inject zeros (= text token type) when not provided by the data collator.
_model_type = getattr(getattr(model, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
# Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
# from position_ids, but only when attention_mask is None. When sample
# packing is active the collator provides an all-ones attention_mask that
# prevents this detection — remove it so the model builds the correct
# per-sequence causal masks.
if (
self.args.sample_packing
and _model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,

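For context on the packing change above: with sample packing, `position_ids` restart at 0 at every packed-sequence boundary, and transformers' masking_utils only inspects them when `attention_mask` is None. A small illustrative sketch (tensor values are made up, not taken from the collator):
```python
import torch

# Two sequences of lengths 3 and 2 packed into one row. position_ids restarting
# at 0 mark the boundary between them.
inputs = {
    "input_ids": torch.tensor([[11, 12, 13, 21, 22]]),
    "position_ids": torch.tensor([[0, 1, 2, 0, 1]]),
    "attention_mask": torch.ones(1, 5, dtype=torch.long),  # all-ones mask hides the packing
}

# The patched compute_loss drops the all-ones mask for gemma3/gemma4, letting the
# model rebuild per-sequence causal masks from position_ids instead.
inputs.pop("attention_mask")
```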
View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
```
## Usage
@@ -44,6 +44,7 @@ plugins:
- gemma3_text
- gemma3n
- gemma3n_text
- gemma4
- glm
- glm4
- glm4_moe

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
)

View File

@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).

View File

@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def warn_sonicmoe_lora_overhead(cls, data):
if data.get("use_sonicmoe") is True and data.get("adapter") in (
"lora",
"qlora",
):
lora_target = data.get("lora_target_modules") or []
lora_linear = data.get("lora_target_linear_modules") or []
targets = (
lora_target if isinstance(lora_target, list) else [lora_target]
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
expert_keywords = ("gate_up_proj", "down_proj", "experts")
if any(kw in t for t in targets for kw in expert_keywords):
LOG.info(
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
)
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):

View File

@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
return Q, K, V
class LoRA_QK(torch.autograd.Function):
"""Optimized LoRA QK implementation for models where v_proj is None.
Used by models like Gemma4 with attention_k_eq_v=True, where key states are
reused as value states. Only Q and K projections are fused; the caller
returns K a second time as V so that autograd accumulates key+value gradients
into a single dK.
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
X_drop: torch.Tensor | None,
# Q params
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
q_lora_bias: torch.Tensor | None,
q_magnitude: torch.Tensor | None,
# K params
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
k_lora_bias: torch.Tensor | None,
k_magnitude: torch.Tensor | None,
# Flags
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
has_dropout = X_drop is not None
has_dora = q_magnitude is not None
if has_dora:
dtype = X.dtype
X_lora = X_drop if has_dropout else X
# Compute Q with DoRA
Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
q_mag_scale = _compute_dora_scale(
q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
)
Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
if q_bias is not None:
Q = Q + q_bias
# Compute K with DoRA
K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
k_mag_scale = _compute_dora_scale(
k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
)
K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
if k_bias is not None:
K = K + k_bias
Q_combined = Q_base + Q_lora
K_combined = K_base + K_lora
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
)
else:
# Standard LoRA (with optional dropout and bias)
Q = matmul_lora(
X,
q_weight,
q_bias,
q_quant,
q_A,
q_B,
q_scale,
X_drop=X_drop,
lora_bias=q_lora_bias,
)
K = matmul_lora(
X,
k_weight,
k_bias,
k_quant,
k_A,
k_B,
k_scale,
X_drop=X_drop,
lora_bias=k_lora_bias,
)
dtype = X.dtype
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_lora_bias,
k_lora_bias,
)
ctx.scales = (q_scale, k_scale)
ctx.quants = (q_quant, k_quant)
ctx.weights = (q_weight, k_weight)
ctx.inplace = inplace
ctx.has_dropout = has_dropout
ctx.has_dora = has_dora
return Q, K
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
):
q_weight, k_weight = ctx.weights
q_quant, k_quant = ctx.quants
q_scale, k_scale = ctx.scales
has_dropout = ctx.has_dropout
has_dora = ctx.has_dora
if has_dora:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
else:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
q_magnitude = k_magnitude = None
q_mag_scale = k_mag_scale = None
Q_combined = K_combined = None
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
X = X.view(-1, X.shape[-1])
X_lora = X_lora.view(-1, X_lora.shape[-1])
d_q_mag = d_k_mag = None
d_q_lora_bias = d_k_lora_bias = None
if has_dora:
Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
K_combined = K_combined.view(-1, K_combined.shape[-1])
d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude
q_grad = q_grad * q_mag_scale.unsqueeze(0)
k_grad = k_grad * k_mag_scale.unsqueeze(0)
# LoRA bias gradients
if q_lora_bias is not None:
d_q_lora_bias = q_scale * q_grad.sum(dim=0)
if k_lora_bias is not None:
d_k_lora_bias = k_scale * k_grad.sum(dim=0)
X_lora_t = X_lora.t()
d_A_q = d_B_q = d_A_k = d_B_k = None
grad_B_q = grad_B_k = None
if A_q is not None and B_q is not None:
grad_B_q = q_grad @ B_q
d_A_q = torch.empty_like(A_q.t())
d_B_q = torch.empty_like(B_q.t())
d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)
if A_k is not None and B_k is not None:
grad_B_k = k_grad @ B_k
d_A_k = torch.empty_like(A_k.t())
d_B_k = torch.empty_like(B_k.t())
d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)
# Base path input gradient
out_buffer = X if ctx.inplace else None
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight_t
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight_t
# LoRA path input gradient
if has_dropout:
grad_X_drop = torch.zeros_like(X_lora)
if grad_B_q is not None:
grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
else:
grad_X_drop = None
if grad_B_q is not None:
grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)
if d_A_q is not None:
d_A_q = d_A_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_A_k is not None:
d_A_k = d_A_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
grad_X = grad_X.view(batch, seq_len, -1)
if grad_X_drop is not None:
grad_X_drop = grad_X_drop.view(batch, seq_len, -1)
# Return gradients for all forward inputs:
# X, X_drop,
# q: weight, bias, quant, A, B, scale, lora_bias, magnitude
# k: weight, bias, quant, A, B, scale, lora_bias, magnitude
# inplace
return (
grad_X,
grad_X_drop,
# Q
None,
None,
None,
d_A_q,
d_B_q,
None,
d_q_lora_bias,
d_q_mag,
# K
None,
None,
None,
d_A_k,
d_B_k,
None,
d_k_lora_bias,
d_k_mag,
# inplace
None,
)
def apply_lora_qk(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query and Key projections for models where v_proj is None.
When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
Because K is returned twice, autograd accumulates gradients from both the key and
value paths into dK before calling LoRA_QK.backward.
Supports bias, dropout, and DoRA.
"""
QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)
# Apply dropout outside autograd.Function (shared mask for Q, K)
X_drop = _apply_dropout(Qdrop, X, self.training)
Q, K = LoRA_QK.apply(
X,
X_drop,
# Q
QW,
Qb,
QW_quant,
QA,
QB,
QS,
Qlb,
Qmag,
# K
KW,
Kb,
KW_quant,
KA,
KB,
KS,
Klb,
Kmag,
# Flags
inplace,
)
return Q, K, K
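The double return of K is the key trick here: PyTorch autograd sums gradients across every use of a tensor, so the single `k_grad` that `LoRA_QK.backward` receives already combines the key-path and value-path gradients. A minimal standalone sketch of that accumulation behaviour (illustrative only, not part of the patch):
```python
import torch

x = torch.randn(3, requires_grad=True)
k = 2.0 * x                        # stands in for the shared K projection
loss = k.sum() + (3.0 * k).sum()   # "key" path + "value" path both consume k
loss.backward()
print(x.grad)                      # tensor([8., 8., 8.]): both uses accumulated into dK
```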
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection.

View File

@@ -67,12 +67,70 @@ def find_all_linear_names(model):
return list(lora_module_names)
def _patch_peft_clippable_linear():
"""Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.
Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
that clips activations). PEFT doesn't recognise it as a supported layer type,
so we redirect LoRA injection to the inner ``.linear`` child instead.
"""
try:
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4ClippableLinear as _cls,
)
except ImportError:
return
from peft.tuners.lora.model import LoraModel
if getattr(LoraModel, "_axolotl_clippable_patched", False):
return
_orig = LoraModel._create_and_replace
def _patched(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=None,
**kw,
):
if isinstance(target, _cls):
# Redirect to the inner nn.Linear so PEFT can wrap it normally.
return _orig(
self,
peft_config,
adapter_name,
target.linear,
"linear",
target,
current_key=current_key,
**kw,
)
return _orig(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=current_key,
**kw,
)
LoraModel._create_and_replace = _patched
LoraModel._axolotl_clippable_patched = True
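For context on what the patch above works around, a hedged stand-in for the wrapper shape the docstring describes (the real `Gemma4ClippableLinear` lives in transformers and is not reproduced here); PEFT matches on module types, so without the redirect it would skip the wrapper entirely:
```python
import torch.nn as nn

class ClippableLinearLike(nn.Module):
    """Hypothetical stand-in for Gemma4ClippableLinear: a thin wrapper that holds
    the real nn.Linear at `.linear` and clips its output activations."""

    def __init__(self, in_features: int, out_features: int, clip: float = 30.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.clip = clip

    def forward(self, x):
        # With the patch, PEFT wraps self.linear with LoRA; this clipping still
        # applies on top of the adapted projection.
        return self.linear(x).clamp(-self.clip, self.clip)
```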
def load_lora(
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
_patch_peft_clippable_linear()
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
@@ -124,6 +182,7 @@ def load_lora(
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=task_type,
**lora_config_kwargs,

View File

@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
try:
# flash-attn-4>=4.0.0b7
from flash_attn.cute import flash_attn_with_kvcache
except ImportError:
flash_attn_with_kvcache = None
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
fa_utils._pad_input,
fa_utils._unpad_input,
)

View File

@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qk,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
@@ -483,18 +484,24 @@ def apply_lora_kernel_patches(
if cfg.lora_qkv_kernel:
# Query, key, value patching
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
proj_names = ["q_proj", "k_proj", "v_proj"]
layer_modules = [
getattr(self_attn, name)
for name in proj_names
if getattr(self_attn, name, None) is not None
]
has_v_proj = getattr(self_attn, "v_proj", None) is not None
proj_names = (
["q_proj", "k_proj", "v_proj"]
if has_v_proj
else ["q_proj", "k_proj"]
)
layer_modules = [getattr(self_attn, name) for name in proj_names]
can_patch_qkv = all(
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
if has_v_proj:
self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, self_attn
)
else:
self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters"