chore: cleanup post release v0.16 (#3577)

* fix: remove unneeded debug log

* fix: cleanup

* feat: add dense gemma config and cleanup

* feat: add cce support

* update notes and set torch compile

* fix patch for new number of return vals

* fixes for gemma4

* fix packing bug

* use updated cce for mm

* fix: pass in kv cache func when avail for transformers 5.5

* feat: update examples with flex variant and readme

* gemma4 lora attention kernels

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2026-04-07 00:10:52 +07:00
committed by GitHub
parent dc638e723f
commit 149178ddb7
15 changed files with 665 additions and 60 deletions

View File

@@ -100,6 +100,27 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
# Gemma4 (and similar multimodal models) declare **kwargs in forward() for
# extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
# "the model handles num_items_in_batch internally" and skips the loss ÷
# gradient_accumulation_steps normalisation, which inflates the *logged* loss
# (the gradient itself is still correct). Override to False when the model
# doesn't actually consume num_items_in_batch.
if self.model_accepts_loss_kwargs:
model_to_check = self.accelerator.unwrap_model(self.model)
if hasattr(model_to_check, "base_model"): # PEFT wrapper
model_to_check = model_to_check.base_model
if hasattr(model_to_check, "model"):
model_to_check = model_to_check.model
fwd = getattr(model_to_check, "forward", None)
if fwd is not None:
import inspect
params = inspect.signature(fwd).parameters
if "num_items_in_batch" not in params:
self.model_accepts_loss_kwargs = False
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
@@ -383,13 +404,27 @@ class AxolotlTrainer(
# Gemma4 requires mm_token_type_ids during training (even for text-only).
# Inject zeros (= text token type) when not provided by the data collator.
_model_type = getattr(getattr(model, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
# Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
# from position_ids, but only when attention_mask is None. When sample
# packing is active the collator provides an all-ones attention_mask that
# prevents this detection — remove it so the model builds the correct
# per-sequence causal masks.
if (
self.args.sample_packing
and _model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
```
## Usage
@@ -44,6 +44,7 @@ plugins:
- gemma3_text
- gemma3n
- gemma3n_text
- gemma4
- glm
- glm4
- glm4_moe

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
)

View File

@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).

View File

@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def warn_sonicmoe_lora_overhead(cls, data):
if data.get("use_sonicmoe") is True and data.get("adapter") in (
"lora",
"qlora",
):
lora_target = data.get("lora_target_modules") or []
lora_linear = data.get("lora_target_linear_modules") or []
targets = (
lora_target if isinstance(lora_target, list) else [lora_target]
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
expert_keywords = ("gate_up_proj", "down_proj", "experts")
if any(kw in t for t in targets for kw in expert_keywords):
LOG.info(
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
)
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):

View File

@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
return Q, K, V
class LoRA_QK(torch.autograd.Function):
"""Optimized LoRA QK implementation for models where v_proj is None.
Used by models like Gemma4 with attention_k_eq_v=True, where key states are
reused as value states. Only Q and K projections are fused; the caller
returns K a second time as V so that autograd accumulates key+value gradients
into a single dK.
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
X_drop: torch.Tensor | None,
# Q params
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
q_lora_bias: torch.Tensor | None,
q_magnitude: torch.Tensor | None,
# K params
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
k_lora_bias: torch.Tensor | None,
k_magnitude: torch.Tensor | None,
# Flags
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
has_dropout = X_drop is not None
has_dora = q_magnitude is not None
if has_dora:
dtype = X.dtype
X_lora = X_drop if has_dropout else X
# Compute Q with DoRA
Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
q_mag_scale = _compute_dora_scale(
q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
)
Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
if q_bias is not None:
Q = Q + q_bias
# Compute K with DoRA
K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
k_mag_scale = _compute_dora_scale(
k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
)
K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
if k_bias is not None:
K = K + k_bias
Q_combined = Q_base + Q_lora
K_combined = K_base + K_lora
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
)
else:
# Standard LoRA (with optional dropout and bias)
Q = matmul_lora(
X,
q_weight,
q_bias,
q_quant,
q_A,
q_B,
q_scale,
X_drop=X_drop,
lora_bias=q_lora_bias,
)
K = matmul_lora(
X,
k_weight,
k_bias,
k_quant,
k_A,
k_B,
k_scale,
X_drop=X_drop,
lora_bias=k_lora_bias,
)
dtype = X.dtype
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_lora_bias,
k_lora_bias,
)
ctx.scales = (q_scale, k_scale)
ctx.quants = (q_quant, k_quant)
ctx.weights = (q_weight, k_weight)
ctx.inplace = inplace
ctx.has_dropout = has_dropout
ctx.has_dora = has_dora
return Q, K
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
):
q_weight, k_weight = ctx.weights
q_quant, k_quant = ctx.quants
q_scale, k_scale = ctx.scales
has_dropout = ctx.has_dropout
has_dora = ctx.has_dora
if has_dora:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
else:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
q_magnitude = k_magnitude = None
q_mag_scale = k_mag_scale = None
Q_combined = K_combined = None
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
X = X.view(-1, X.shape[-1])
X_lora = X_lora.view(-1, X_lora.shape[-1])
d_q_mag = d_k_mag = None
d_q_lora_bias = d_k_lora_bias = None
if has_dora:
Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
K_combined = K_combined.view(-1, K_combined.shape[-1])
d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude
q_grad = q_grad * q_mag_scale.unsqueeze(0)
k_grad = k_grad * k_mag_scale.unsqueeze(0)
# LoRA bias gradients
if q_lora_bias is not None:
d_q_lora_bias = q_scale * q_grad.sum(dim=0)
if k_lora_bias is not None:
d_k_lora_bias = k_scale * k_grad.sum(dim=0)
X_lora_t = X_lora.t()
d_A_q = d_B_q = d_A_k = d_B_k = None
grad_B_q = grad_B_k = None
if A_q is not None and B_q is not None:
grad_B_q = q_grad @ B_q
d_A_q = torch.empty_like(A_q.t())
d_B_q = torch.empty_like(B_q.t())
d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)
if A_k is not None and B_k is not None:
grad_B_k = k_grad @ B_k
d_A_k = torch.empty_like(A_k.t())
d_B_k = torch.empty_like(B_k.t())
d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)
# Base path input gradient
out_buffer = X if ctx.inplace else None
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight_t
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight_t
# LoRA path input gradient
if has_dropout:
grad_X_drop = torch.zeros_like(X_lora)
if grad_B_q is not None:
grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
else:
grad_X_drop = None
if grad_B_q is not None:
grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)
if d_A_q is not None:
d_A_q = d_A_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_A_k is not None:
d_A_k = d_A_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
grad_X = grad_X.view(batch, seq_len, -1)
if grad_X_drop is not None:
grad_X_drop = grad_X_drop.view(batch, seq_len, -1)
# Return gradients for all forward inputs:
# X, X_drop,
# q: weight, bias, quant, A, B, scale, lora_bias, magnitude
# k: weight, bias, quant, A, B, scale, lora_bias, magnitude
# inplace
return (
grad_X,
grad_X_drop,
# Q
None,
None,
None,
d_A_q,
d_B_q,
None,
d_q_lora_bias,
d_q_mag,
# K
None,
None,
None,
d_A_k,
d_B_k,
None,
d_k_lora_bias,
d_k_mag,
# inplace
None,
)
def apply_lora_qk(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query and Key projections for models where v_proj is None.
When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
Because K is returned twice, autograd accumulates gradients from both the key and
value paths into dK before calling LoRA_QK.backward.
Supports bias, dropout, and DoRA.
"""
QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)
# Apply dropout outside autograd.Function (shared mask for Q, K)
X_drop = _apply_dropout(Qdrop, X, self.training)
Q, K = LoRA_QK.apply(
X,
X_drop,
# Q
QW,
Qb,
QW_quant,
QA,
QB,
QS,
Qlb,
Qmag,
# K
KW,
Kb,
KW_quant,
KA,
KB,
KS,
Klb,
Kmag,
# Flags
inplace,
)
return Q, K, K
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection.

View File

@@ -67,12 +67,70 @@ def find_all_linear_names(model):
return list(lora_module_names)
def _patch_peft_clippable_linear():
"""Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.
Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
that clips activations). PEFT doesn't recognise it as a supported layer type,
so we redirect LoRA injection to the inner ``.linear`` child instead.
"""
try:
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4ClippableLinear as _cls,
)
except ImportError:
return
from peft.tuners.lora.model import LoraModel
if getattr(LoraModel, "_axolotl_clippable_patched", False):
return
_orig = LoraModel._create_and_replace
def _patched(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=None,
**kw,
):
if isinstance(target, _cls):
# Redirect to the inner nn.Linear so PEFT can wrap it normally.
return _orig(
self,
peft_config,
adapter_name,
target.linear,
"linear",
target,
current_key=current_key,
**kw,
)
return _orig(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=current_key,
**kw,
)
LoraModel._create_and_replace = _patched
LoraModel._axolotl_clippable_patched = True
def load_lora(
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
_patch_peft_clippable_linear()
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
@@ -124,6 +182,7 @@ def load_lora(
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=task_type,
**lora_config_kwargs,

View File

@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
try:
# flash-attn-4>=4.0.0b7
from flash_attn.cute import flash_attn_with_kvcache
except ImportError:
flash_attn_with_kvcache = None
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
fa_utils._pad_input,
fa_utils._unpad_input,
)

View File

@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qk,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
@@ -483,18 +484,24 @@ def apply_lora_kernel_patches(
if cfg.lora_qkv_kernel:
# Query, key, value patching
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
proj_names = ["q_proj", "k_proj", "v_proj"]
layer_modules = [
getattr(self_attn, name)
for name in proj_names
if getattr(self_attn, name, None) is not None
]
has_v_proj = getattr(self_attn, "v_proj", None) is not None
proj_names = (
["q_proj", "k_proj", "v_proj"]
if has_v_proj
else ["q_proj", "k_proj"]
)
layer_modules = [getattr(self_attn, name) for name in proj_names]
can_patch_qkv = all(
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
if has_v_proj:
self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, self_attn
)
else:
self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters"