diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 7a9feaa03..c7b2b8e5b 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\"" ] }, { diff --git a/examples/gemma4/26b-a4b-moe-qlora.yaml b/examples/gemma4/26b-a4b-moe-qlora.yaml index 0972b93f6..e7bdb6f46 100644 --- a/examples/gemma4/26b-a4b-moe-qlora.yaml +++ b/examples/gemma4/26b-a4b-moe-qlora.yaml @@ -1,19 +1,12 @@ # Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels # -# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB) +# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB) +# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total # # Key notes: -# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256). -# Use sdp_attention instead. -# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict -# LoRA to the text backbone via lora_target_linear_modules regex. -# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE -# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE -# via the transformers ExpertsInterface. -# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj` -# (no `mlp.` prefix, unlike Qwen/Mixtral). -# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention. -# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs. +# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention). +# 4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix. +# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP. base_model: google/gemma-4-26B-A4B @@ -24,7 +17,7 @@ plugins: use_kernels: true use_scattermoe: true experts_implementation: scattermoe -torch_compile: false +torch_compile: true liger_layer_norm: true liger_rope: true liger_rms_norm: true @@ -54,12 +47,9 @@ lora_r: 16 lora_alpha: 32 lora_dropout: 0 -# Restrict LoRA to text backbone only (skip vision/audio encoders). -# lora_target_modules is intentionally empty — all module targeting is done -# via regex in lora_target_linear_modules below. -lora_target_modules: [] -lora_target_linear_modules: - - language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj +# Restrict LoRA to text backbone only (skip vision/audio encoders) +# using regex to match only the text decoder attention projections. 
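+# (For illustration, this pattern should match module names like
+# model.language_model.layers.0.self_attn.q_proj, while the vision and audio
+# tower modules fall outside it.)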
+lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' # MoE expert LoRA (3D Parameter tensors, not nn.Linear) lora_target_parameters: @@ -73,7 +63,7 @@ lora_o_kernel: false bnb_config_kwargs: bnb_4bit_use_double_quant: true -wandb_project: gemma4-qlora +wandb_project: wandb_entity: wandb_watch: wandb_name: @@ -93,8 +83,7 @@ gradient_checkpointing: true activation_offloading: true logging_steps: 1 -# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256 -flash_attention: false +# FA2 not supported sdp_attention: true warmup_ratio: 0.1 diff --git a/examples/gemma4/31b-qlora-flex.yaml b/examples/gemma4/31b-qlora-flex.yaml new file mode 100644 index 000000000..8456c9c13 --- /dev/null +++ b/examples/gemma4/31b-qlora-flex.yaml @@ -0,0 +1,71 @@ +base_model: google/gemma-4-31B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + - axolotl.integrations.liger.LigerPlugin +torch_compile: true +liger_layer_norm: true +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_rms_norm_gated: true +strict: false + +chat_template: gemma4 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:10%] + field_messages: conversations + message_property_mappings: + role: from + content: value +val_set_size: 0.05 +output_dir: ./outputs/gemma4-31b-qlora-flex + +sequence_len: 2048 +sample_packing: true + +load_in_4bit: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0 + +# Restrict LoRA to text backbone only (skip vision/audio encoders) +lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +lora_mlp_kernel: false +lora_qkv_kernel: false +lora_o_kernel: false + +bnb_config_kwargs: + bnb_4bit_use_double_quant: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +activation_offloading: true +logging_steps: 1 + +# FA not supported +flex_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/gemma4/31b-qlora.yaml b/examples/gemma4/31b-qlora.yaml new file mode 100644 index 000000000..42086a43c --- /dev/null +++ b/examples/gemma4/31b-qlora.yaml @@ -0,0 +1,69 @@ +base_model: google/gemma-4-31B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + - axolotl.integrations.liger.LigerPlugin +torch_compile: false +liger_layer_norm: true +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_rms_norm_gated: true +strict: false + +chat_template: gemma4 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:10%] + field_messages: conversations + message_property_mappings: + role: from + content: value +val_set_size: 0.05 +output_dir: ./outputs/gemma4-31b-qlora + +sequence_len: 2048 +sample_packing: true + +load_in_4bit: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0 + +# Restrict LoRA to text backbone only (skip vision/audio encoders) +# using regex to match only the text decoder attention projections. 
+lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+bnb_config_kwargs:
+  bnb_4bit_use_double_quant: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+activation_offloading: true
+logging_steps: 1
+
+# FA not supported
+sdp_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 4
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
diff --git a/examples/gemma4/README.md b/examples/gemma4/README.md
new file mode 100644
index 000000000..68274ee68
--- /dev/null
+++ b/examples/gemma4/README.md
@@ -0,0 +1,60 @@
+# Finetune Google's Gemma 4 with Axolotl
+
+[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+
+3. Run one of the finetuning examples:
+
+```bash
+# 26B MoE QLoRA (1x80GB @ ~50 GiB)
+axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
+
+# 31B Dense QLoRA (1x80GB @ ~44 GiB)
+axolotl train examples/gemma4/31b-qlora.yaml
+
+# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
+axolotl train examples/gemma4/31b-qlora-flex.yaml
+```
+
+### MoE Expert Quantization & Expert LoRA (26B-A4B only)
+
+The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
+
+## Flex Attention
+
+Reduce VRAM usage by ~40% (at the cost of up to 50% lower throughput) by setting the options below, as done in `examples/gemma4/31b-qlora-flex.yaml`:
+
+```yaml
+torch_compile: true
+flex_attention: true
+```
+
+This works for both the MoE and dense models.
+
+## Limitations
+
+- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
+- **LoRA kernels**: Not supported due to KV-sharing layers.
+- **lora_target_linear**: Incompatible with multimodal models — use `lora_target_modules` with a regex to restrict LoRA to the text backbone.
+
+### Tips
+
+- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config (a sketch follows the Optimization Guides section below). This is much more resource-intensive and has not been tested.
+
+## Optimization Guides
+
+Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
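+
+### Full finetuning (untested sketch)
+
+A minimal, hypothetical sketch of the change described in the Tips section: start from one of the QLoRA configs above and delete the adapter and quantization keys. Expect substantially higher VRAM requirements; this path has not been validated.
+
+```yaml
+# Keys to delete from e.g. examples/gemma4/26b-a4b-moe-qlora.yaml for full finetuning
+adapter: qlora              # delete
+load_in_4bit: true          # delete
+quantize_moe_experts: true  # delete
+```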
+ +## Related Resources + +- [Gemma 4 Blog](https://huggingface.co/blog/gemma4) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index bd92a3630..5f716b779 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"' ) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 6beff8055..650a238ec 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -100,6 +100,27 @@ class AxolotlTrainer( self._signature_columns = None # workaround for pylint super().__init__(*_args, **kwargs) + + # Gemma4 (and similar multimodal models) declare **kwargs in forward() for + # extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as + # "the model handles num_items_in_batch internally" and skips the loss ÷ + # gradient_accumulation_steps normalisation, which inflates the *logged* loss + # (the gradient itself is still correct). Override to False when the model + # doesn't actually consume num_items_in_batch. + if self.model_accepts_loss_kwargs: + model_to_check = self.accelerator.unwrap_model(self.model) + if hasattr(model_to_check, "base_model"): # PEFT wrapper + model_to_check = model_to_check.base_model + if hasattr(model_to_check, "model"): + model_to_check = model_to_check.model + fwd = getattr(model_to_check, "forward", None) + if fwd is not None: + import inspect + + params = inspect.signature(fwd).parameters + if "num_items_in_batch" not in params: + self.model_accepts_loss_kwargs = False + self.train_data_collator = self.data_collator self._stored_metrics = defaultdict( lambda: defaultdict(lambda: {"values": [], "reduction": "mean"}) @@ -383,13 +404,27 @@ class AxolotlTrainer( # Gemma4 requires mm_token_type_ids during training (even for text-only). # Inject zeros (= text token type) when not provided by the data collator. + _model_type = getattr(getattr(model, "config", None), "model_type", None) if ( "mm_token_type_ids" not in inputs and "input_ids" in inputs - and getattr(getattr(model, "config", None), "model_type", None) == "gemma4" + and _model_type == "gemma4" ): inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"]) + # Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences + # from position_ids, but only when attention_mask is None. When sample + # packing is active the collator provides an all-ones attention_mask that + # prevents this detection — remove it so the model builds the correct + # per-sequence causal masks. 
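+        # (Illustrative example: packed position_ids such as [0, 1, 2, 0, 1, 0, 1, 2]
+        # restart at 0 for every sub-sequence; with attention_mask=None the model can
+        # rebuild block-diagonal, per-sequence causal masks from those resets.)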
+ if ( + self.args.sample_packing + and _model_type in ("gemma4", "gemma3") + and "attention_mask" in inputs + and "position_ids" in inputs + ): + del inputs["attention_mask"] + if self.args.orpo_alpha: return self.orpo_compute_loss( model, diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 220fb4d2b..2ccf11f18 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88" ``` ## Usage @@ -44,6 +44,7 @@ plugins: - gemma3_text - gemma3n - gemma3n_text +- gemma4 - glm - glm4 - glm4_moe diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 758c5406c..b9a59aea9 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`' ) diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md index 9293c1727..32d236da4 100644 --- a/src/axolotl/integrations/kernels/README.md +++ b/src/axolotl/integrations/kernels/README.md @@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture: Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is. -**Important limitations:** -- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead. -- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`). - ## Limitations - **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`). 
diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py index 3afeb79c3..f532fde41 100644 --- a/src/axolotl/integrations/kernels/args.py +++ b/src/axolotl/integrations/kernels/args.py @@ -53,28 +53,6 @@ class KernelsArgs(BaseModel): return data - @model_validator(mode="before") - @classmethod - def warn_sonicmoe_lora_overhead(cls, data): - if data.get("use_sonicmoe") is True and data.get("adapter") in ( - "lora", - "qlora", - ): - lora_target = data.get("lora_target_modules") or [] - lora_linear = data.get("lora_target_linear_modules") or [] - targets = ( - lora_target if isinstance(lora_target, list) else [lora_target] - ) + (lora_linear if isinstance(lora_linear, list) else [lora_linear]) - expert_keywords = ("gate_up_proj", "down_proj", "experts") - if any(kw in t for t in targets for kw in expert_keywords): - LOG.info( - "SonicMoE + LoRA on expert modules uses runtime weight materialization " - "(W_eff = W + scaling*B@A per forward). This has slightly higher overhead " - "than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel." - ) - - return data - @model_validator(mode="before") @classmethod def disable_mlp_kernel(cls, data): diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 1576a10cd..79ac993c1 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -1297,6 +1297,339 @@ def apply_lora_qkv( return Q, K, V +class LoRA_QK(torch.autograd.Function): + """Optimized LoRA QK implementation for models where v_proj is None. + + Used by models like Gemma4 with attention_k_eq_v=True, where key states are + reused as value states. Only Q and K projections are fused; the caller + returns K a second time as V so that autograd accumulates key+value gradients + into a single dK. + + Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation). 
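+
+    For DoRA, each projection follows the standard DoRA formulation (a conceptual
+    sketch, not the literal kernel steps):
+    ``y = (m / ||W + s * B @ A||_col) * (x @ (W + s * B @ A).T) + bias``,
+    where the magnitude/norm factor is the ``mag_scale`` computed in ``forward``.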
+ """ + + @staticmethod + @torch_amp_custom_fwd + def forward( + ctx: torch.autograd.function.FunctionCtx, + X: torch.Tensor, + X_drop: torch.Tensor | None, + # Q params + q_weight: torch.Tensor, + q_bias: torch.Tensor | None, + q_quant: QuantState | None, + q_A: torch.Tensor | None, + q_B: torch.Tensor | None, + q_scale: float, + q_lora_bias: torch.Tensor | None, + q_magnitude: torch.Tensor | None, + # K params + k_weight: torch.Tensor, + k_bias: torch.Tensor | None, + k_quant: QuantState | None, + k_A: torch.Tensor | None, + k_B: torch.Tensor | None, + k_scale: float, + k_lora_bias: torch.Tensor | None, + k_magnitude: torch.Tensor | None, + # Flags + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + has_dropout = X_drop is not None + has_dora = q_magnitude is not None + + if has_dora: + dtype = X.dtype + X_lora = X_drop if has_dropout else X + + # Compute Q with DoRA + Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None) + Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype) + q_mag_scale = _compute_dora_scale( + q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype + ) + Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora) + if q_bias is not None: + Q = Q + q_bias + + # Compute K with DoRA + K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None) + K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype) + k_mag_scale = _compute_dora_scale( + k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype + ) + K = k_mag_scale.unsqueeze(0) * (K_base + K_lora) + if k_bias is not None: + K = K + k_bias + + Q_combined = Q_base + Q_lora + K_combined = K_base + K_lora + + ctx.save_for_backward( + X, + X_drop if has_dropout else X, + q_A.to(dtype) if q_A is not None else q_A, + q_B.to(dtype) if q_B is not None else q_B, + k_A.to(dtype) if k_A is not None else k_A, + k_B.to(dtype) if k_B is not None else k_B, + q_magnitude, + k_magnitude, + q_mag_scale, + k_mag_scale, + Q_combined, + K_combined, + q_lora_bias, + k_lora_bias, + ) + else: + # Standard LoRA (with optional dropout and bias) + Q = matmul_lora( + X, + q_weight, + q_bias, + q_quant, + q_A, + q_B, + q_scale, + X_drop=X_drop, + lora_bias=q_lora_bias, + ) + K = matmul_lora( + X, + k_weight, + k_bias, + k_quant, + k_A, + k_B, + k_scale, + X_drop=X_drop, + lora_bias=k_lora_bias, + ) + + dtype = X.dtype + ctx.save_for_backward( + X, + X_drop if has_dropout else X, + q_A.to(dtype) if q_A is not None else q_A, + q_B.to(dtype) if q_B is not None else q_B, + k_A.to(dtype) if k_A is not None else k_A, + k_B.to(dtype) if k_B is not None else k_B, + q_lora_bias, + k_lora_bias, + ) + + ctx.scales = (q_scale, k_scale) + ctx.quants = (q_quant, k_quant) + ctx.weights = (q_weight, k_weight) + ctx.inplace = inplace + ctx.has_dropout = has_dropout + ctx.has_dora = has_dora + + return Q, K + + @staticmethod + @torch_amp_custom_bwd + def backward( + ctx: torch.autograd.function.FunctionCtx, + q_grad: torch.Tensor, + k_grad: torch.Tensor, + ): + q_weight, k_weight = ctx.weights + q_quant, k_quant = ctx.quants + q_scale, k_scale = ctx.scales + has_dropout = ctx.has_dropout + has_dora = ctx.has_dora + + if has_dora: + ( + X, + X_lora, + A_q, + B_q, + A_k, + B_k, + q_magnitude, + k_magnitude, + q_mag_scale, + k_mag_scale, + Q_combined, + K_combined, + q_lora_bias, + k_lora_bias, + ) = ctx.saved_tensors + else: + ( + X, + X_lora, + A_q, + B_q, + A_k, + B_k, + q_lora_bias, + k_lora_bias, + ) = ctx.saved_tensors + q_magnitude = k_magnitude = None + q_mag_scale = k_mag_scale = None + Q_combined = K_combined 
= None + + batch, seq_len = X.shape[:2] + q_grad = q_grad.view(-1, q_grad.shape[-1]) + k_grad = k_grad.reshape(-1, k_grad.shape[-1]) + X = X.view(-1, X.shape[-1]) + X_lora = X_lora.view(-1, X_lora.shape[-1]) + + d_q_mag = d_k_mag = None + d_q_lora_bias = d_k_lora_bias = None + + if has_dora: + Q_combined = Q_combined.view(-1, Q_combined.shape[-1]) + K_combined = K_combined.view(-1, K_combined.shape[-1]) + + d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude + d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude + + q_grad = q_grad * q_mag_scale.unsqueeze(0) + k_grad = k_grad * k_mag_scale.unsqueeze(0) + + # LoRA bias gradients + if q_lora_bias is not None: + d_q_lora_bias = q_scale * q_grad.sum(dim=0) + if k_lora_bias is not None: + d_k_lora_bias = k_scale * k_grad.sum(dim=0) + + X_lora_t = X_lora.t() + + d_A_q = d_B_q = d_A_k = d_B_k = None + grad_B_q = grad_B_k = None + + if A_q is not None and B_q is not None: + grad_B_q = q_grad @ B_q + d_A_q = torch.empty_like(A_q.t()) + d_B_q = torch.empty_like(B_q.t()) + d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0) + d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0) + + if A_k is not None and B_k is not None: + grad_B_k = k_grad @ B_k + d_A_k = torch.empty_like(A_k.t()) + d_B_k = torch.empty_like(B_k.t()) + d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0) + d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0) + + # Base path input gradient + out_buffer = X if ctx.inplace else None + + q_weight_t = dequantize(q_weight, q_quant) + grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer) + del q_weight_t + + k_weight_t = dequantize(k_weight, k_quant) + grad_X.addmm_(k_grad, k_weight_t) + del k_weight_t + + # LoRA path input gradient + if has_dropout: + grad_X_drop = torch.zeros_like(X_lora) + if grad_B_q is not None: + grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale) + if grad_B_k is not None: + grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale) + else: + grad_X_drop = None + if grad_B_q is not None: + grad_X.addmm_(grad_B_q, A_q, alpha=q_scale) + if grad_B_k is not None: + grad_X.addmm_(grad_B_k, A_k, alpha=k_scale) + + if d_A_q is not None: + d_A_q = d_A_q.t() + d_B_q = d_B_q.t() # type: ignore[union-attr] + if d_A_k is not None: + d_A_k = d_A_k.t() + d_B_k = d_B_k.t() # type: ignore[union-attr] + + grad_X = grad_X.view(batch, seq_len, -1) + if grad_X_drop is not None: + grad_X_drop = grad_X_drop.view(batch, seq_len, -1) + + # Return gradients for all forward inputs: + # X, X_drop, + # q: weight, bias, quant, A, B, scale, lora_bias, magnitude + # k: weight, bias, quant, A, B, scale, lora_bias, magnitude + # inplace + return ( + grad_X, + grad_X_drop, + # Q + None, + None, + None, + d_A_q, + d_B_q, + None, + d_q_lora_bias, + d_q_mag, + # K + None, + None, + None, + d_A_k, + d_B_k, + None, + d_k_lora_bias, + d_k_mag, + # inplace + None, + ) + + +def apply_lora_qk( + self, X: torch.Tensor, inplace: bool = True +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Applies LoRA to compute Query and Key projections for models where v_proj is None. + + When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as + value states. Returns (Q, K, K) — the caller's patched forward will use K as V. + Because K is returned twice, autograd accumulates gradients from both the key and + value paths into dK before calling LoRA_QK.backward. + + Supports bias, dropout, and DoRA. 
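+
+    Example (sketch, mirroring how the QKV patching in
+    ``monkeypatch/lora_kernels.py`` wires this in when ``v_proj`` is ``None``)::
+
+        self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
+        Q, K, V = self_attn.apply_qkv(hidden_states)  # V is the same tensor as K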
+ """ + QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj) + KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj) + + # Apply dropout outside autograd.Function (shared mask for Q, K) + X_drop = _apply_dropout(Qdrop, X, self.training) + + Q, K = LoRA_QK.apply( + X, + X_drop, + # Q + QW, + Qb, + QW_quant, + QA, + QB, + QS, + Qlb, + Qmag, + # K + KW, + Kb, + KW_quant, + KA, + KB, + KS, + Klb, + Kmag, + # Flags + inplace, + ) + + return Q, K, K + + class LoRA_O(torch.autograd.Function): """Optimized LoRA implementation for output projection. diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 2b53b7b2c..6d0bd0fe1 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -67,12 +67,70 @@ def find_all_linear_names(model): return list(lora_module_names) +def _patch_peft_clippable_linear(): + """Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear. + + Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear + that clips activations). PEFT doesn't recognise it as a supported layer type, + so we redirect LoRA injection to the inner ``.linear`` child instead. + """ + try: + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4ClippableLinear as _cls, + ) + except ImportError: + return + + from peft.tuners.lora.model import LoraModel + + if getattr(LoraModel, "_axolotl_clippable_patched", False): + return + _orig = LoraModel._create_and_replace + + def _patched( + self, + peft_config, + adapter_name, + target, + target_name, + parent, + current_key=None, + **kw, + ): + if isinstance(target, _cls): + # Redirect to the inner nn.Linear so PEFT can wrap it normally. + return _orig( + self, + peft_config, + adapter_name, + target.linear, + "linear", + target, + current_key=current_key, + **kw, + ) + return _orig( + self, + peft_config, + adapter_name, + target, + target_name, + parent, + current_key=current_key, + **kw, + ) + + LoraModel._create_and_replace = _patched + LoraModel._axolotl_clippable_patched = True + + def load_lora( model: PreTrainedModel, cfg: DictDefault, inference: bool = False, config_only: bool = False, ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]: + _patch_peft_clippable_linear() lora_target_modules = cfg.lora_target_modules or [] lora_target_parameters = cfg.lora_target_parameters or [] @@ -124,6 +182,7 @@ def load_lora( lora_dropout=cfg.lora_dropout, fan_in_fan_out=cfg.lora_fan_in_fan_out, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, + exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None, bias="none", task_type=task_type, **lora_config_kwargs, diff --git a/src/axolotl/monkeypatch/attention/flash_attn_4.py b/src/axolotl/monkeypatch/attention/flash_attn_4.py index 5ebc93670..b3bde00c1 100644 --- a/src/axolotl/monkeypatch/attention/flash_attn_4.py +++ b/src/axolotl/monkeypatch/attention/flash_attn_4.py @@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None): if getattr(fa_utils._lazy_imports, "_axolotl_patched", False): return + try: + # flash-attn-4>=4.0.0b7 + from flash_attn.cute import flash_attn_with_kvcache + except ImportError: + flash_attn_with_kvcache = None + def _patched_lazy_imports( implementation, attention_wrapper=None, allow_all_kernels=False ): return ( flash_attn_func, flash_attn_varlen_func, + flash_attn_with_kvcache, fa_utils._pad_input, fa_utils._unpad_input, ) diff --git a/src/axolotl/monkeypatch/lora_kernels.py 
b/src/axolotl/monkeypatch/lora_kernels.py index d569d5925..9cf65286a 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -16,6 +16,7 @@ from axolotl.kernels.lora import ( apply_lora_mlp_geglu, apply_lora_mlp_swiglu, apply_lora_o, + apply_lora_qk, apply_lora_qkv, ) from axolotl.monkeypatch.utils import detab_code @@ -483,18 +484,24 @@ def apply_lora_kernel_patches( if cfg.lora_qkv_kernel: # Query, key, value patching # Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True) - proj_names = ["q_proj", "k_proj", "v_proj"] - layer_modules = [ - getattr(self_attn, name) - for name in proj_names - if getattr(self_attn, name, None) is not None - ] + has_v_proj = getattr(self_attn, "v_proj", None) is not None + proj_names = ( + ["q_proj", "k_proj", "v_proj"] + if has_v_proj + else ["q_proj", "k_proj"] + ) + layer_modules = [getattr(self_attn, name) for name in proj_names] can_patch_qkv = all( hasattr(module, "lora_A") for module in layer_modules ) if can_patch_qkv: - self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) + if has_v_proj: + self_attn.apply_qkv = types.MethodType( + apply_lora_qkv, self_attn + ) + else: + self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn) else: LOG.warning_once( "Cannot patch some attention QKV projections - requires LoRA adapters"