Gemma4 fixes and profiler (#3591)
30
README.md
@@ -86,7 +86,7 @@ Features:
|
|||||||
**Requirements**:
|
**Requirements**:
|
||||||
|
|
||||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python 3.11
|
- Python >=3.11 (3.12 recommended)
|
||||||
- PyTorch ≥2.9.1
|
- PyTorch ≥2.9.1
|
||||||
|
|
||||||
### Google Colab
|
### Google Colab
|
||||||
@@ -95,6 +95,34 @@ Features:
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
|
#### Using uv (recommended)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# install uv if you don't already have it installed
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
source $HOME/.local/bin/env
|
||||||
|
|
||||||
|
# CUDA 12.8.1 tends to have better package compatibility
|
||||||
|
export UV_TORCH_BACKEND=cu128
|
||||||
|
|
||||||
|
# create a new virtual environment
|
||||||
|
uv venv --python 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
uv pip install torch==2.10.0 torchvision
|
||||||
|
uv pip install --no-build-isolation "axolotl[deepspeed]"
|
||||||
|
|
||||||
|
# recommended - install cut-cross-entropy
|
||||||
|
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
|
||||||
|
|
||||||
|
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
|
||||||
|
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
|
||||||
|
|
||||||
|
# Download example axolotl configs, deepspeed configs
|
||||||
|
axolotl fetch examples
|
||||||
|
axolotl fetch deepspeed_configs # OPTIONAL
|
||||||
|
```
|
||||||
|
|
||||||
#### Using pip
|
#### Using pip
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -2,6 +2,64 @@
|
|||||||
|
|
||||||
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
|
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
|
||||||
|
|
||||||
|
## VLM (Vision Language Model) Quick Start
|
||||||
|
|
||||||
|
All VLM configs require these four lines:
|
||||||
|
```yaml
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
```
|
||||||
|
|
||||||
|
Decision tree for VLM config:
|
||||||
|
```text
|
||||||
|
Is the model multimodal (has vision/audio encoder)?
|
||||||
|
├─ YES: Add `freeze_mm_modules: true` if training text only
|
||||||
|
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
|
||||||
|
│ LoRA: use regex `lora_target_modules` to restrict to language model
|
||||||
|
└─ NO: Train as a regular text model
|
||||||
|
|
||||||
|
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
|
||||||
|
├─ YES: Add `lora_target_parameters` for expert LoRA
|
||||||
|
│ Consider ScatterMoE kernels (see Plugins section)
|
||||||
|
└─ NO: Standard LoRA config
|
||||||
|
```
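The decision tree above can be read as a small function. The following is a minimal sketch, not part of axolotl: the function name and its arguments are hypothetical, and the `lora_target_parameters` values are the expert targets listed in the ScatterMoE section of this document.

```python
def vlm_config_overrides(is_multimodal, text_only, is_moe, chat_template=None):
    """Return the config keys suggested by the decision tree (hypothetical helper)."""
    cfg = {}
    if is_multimodal:
        if text_only:
            # Freeze the vision/audio encoders when only text is trained.
            cfg["freeze_mm_modules"] = True
        if chat_template:
            # e.g. "gemma4", "qwen3_5", "gemma3"
            cfg["chat_template"] = chat_template
    if is_moe:
        # Expert weights are 3D parameter tensors, so they are targeted
        # by parameter name rather than through lora_target_modules.
        cfg["lora_target_parameters"] = [
            "experts.gate_up_proj",
            "experts.down_proj",
        ]
    return cfg


# Example: text-only LoRA on the Gemma4 26B-A4B MoE checkpoint.
overrides = vlm_config_overrides(True, True, True, chat_template="gemma4")
```

The multimodal branch also calls for a regex `lora_target_modules` restricted to the language model; that value is model-specific, so the sketch leaves it out.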
|
||||||
|
|
||||||
|
## Plugins & Optimizations
|
||||||
|
|
||||||
|
### Cut Cross Entropy (CCE)
|
||||||
|
|
||||||
|
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
```
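To illustrate why skipping the full logits tensor saves memory, here is a pure-Python sketch of the underlying idea: the loss is accumulated over vocabulary chunks with a running log-sum-exp, so only one chunk of logits exists at a time. This is an illustration only. The real plugin uses fused GPU kernels, and the function names, the chunk size, and the toy data below are assumptions made for the example.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_loss(hidden, weight, targets):
    """Reference: build every full logit row, then take the mean cross-entropy."""
    total = 0.0
    for h, t in zip(hidden, targets):
        logits = [dot(h, w) for w in weight]
        m = max(logits)
        lse = m + math.log(sum(math.exp(x - m) for x in logits))
        total += lse - logits[t]
    return total / len(targets)


def chunked_loss(hidden, weight, targets, chunk=2):
    """Same loss, visiting the vocabulary `chunk` rows of lm_head at a time."""
    total = 0.0
    for h, t in zip(hidden, targets):
        run_max, run_sum = -math.inf, 0.0
        # The target logit is computed directly from one row of the weight.
        target_logit = dot(h, weight[t])
        for start in range(0, len(weight), chunk):
            part = [dot(h, w) for w in weight[start:start + chunk]]
            new_max = max(run_max, max(part))
            # Rescale the running sum to the new maximum (online log-sum-exp).
            run_sum = run_sum * math.exp(run_max - new_max)
            run_sum += sum(math.exp(x - new_max) for x in part)
            run_max = new_max
        total += run_max + math.log(run_sum) - target_logit
    return total / len(targets)


# Toy data: 2 tokens, hidden size 2, vocabulary of 5.
hidden = [[1.0, 0.5], [-0.3, 2.0]]
weight = [[0.2, -0.1], [1.5, 0.3], [-0.7, 0.9], [0.0, 1.1], [0.4, 0.4]]
targets = [0, 3]
```

Both functions return the same value up to floating-point error; the chunked version never holds more than `chunk` logits per token.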
|
||||||
|
|
||||||
|
### ScatterMoE Kernels
|
||||||
|
|
||||||
|
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
|
||||||
|
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
|
||||||
|
|
||||||
## Gemma 4
|
## Gemma 4
|
||||||
|
|
||||||
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
|
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
|
||||||
@@ -66,6 +124,36 @@ fsdp_config:
|
|||||||
experts_implementation: scattermoe
|
experts_implementation: scattermoe
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### VLM (Vision) Training
|
||||||
|
|
||||||
|
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
freeze_mm_modules: true
|
||||||
|
chat_template: gemma4
|
||||||
|
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
```
|
||||||
|
|
||||||
|
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
|
||||||
|
|
||||||
|
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
### Common issues
|
### Common issues
|
||||||
|
|
||||||
| Symptom | Cause | Fix |
|
| Symptom | Cause | Fix |
|
||||||
|
|||||||
@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
|||||||
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
|
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
|
||||||
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
||||||
|
|
||||||
|
## Profiling
|
||||||
|
|
||||||
|
To profile training and identify optimization opportunities:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Profile steps 3-7 (after warmup/autotuning settles)
|
||||||
|
profiler_steps_start: 3
|
||||||
|
profiler_steps: 5
|
||||||
|
```
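The comment in the config above implies a simple relationship between the two settings. A minimal sketch, assuming the window begins at `profiler_steps_start` and includes that step (the helper name is hypothetical):

```python
def profiled_steps(profiler_steps_start, profiler_steps):
    """List the training steps covered by the profiling window."""
    return list(range(profiler_steps_start, profiler_steps_start + profiler_steps))


window = profiled_steps(3, 5)  # start at step 3, profile 5 steps
```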
|
||||||
|
|
||||||
|
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
|
||||||
|
View the Chrome trace at `chrome://tracing`.
|
||||||
|
|
||||||
|
To programmatically inspect the trace:
|
||||||
|
```bash
|
||||||
|
python scripts/analyze_profile.py output_dir/
|
||||||
|
```
|
||||||
|
|
||||||
|
The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
|
||||||
|
- **Large matmul kernels**: candidates for fusion or quantization
|
||||||
|
- **Memory copies (H2D/D2H)**: unnecessary data movement
|
||||||
|
- **Small frequent kernels**: candidates for kernel fusion
|
||||||
|
- **Gaps between kernels**: pipeline bubbles from CPU overhead
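The checks in the list above can be approximated with a few lines of standard-library Python. This is a sketch only and is not the contents of `scripts/analyze_profile.py`. It assumes the trace follows the Chrome Trace Event format (a `traceEvents` list whose complete events have `ph == "X"`, with `ts` and `dur` in microseconds) and that GPU kernels carry the category `"kernel"`; adjust `category` if your trace labels them differently.

```python
import json
from collections import defaultdict


def load_trace(path):
    """Read a Chrome trace file such as <output_dir>/profiler_trace.json."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def summarize_trace(trace, category="kernel", top=5):
    """Return (top events by total duration, total idle gap) for one category."""
    events = [
        e for e in trace.get("traceEvents", [])
        if e.get("ph") == "X" and e.get("cat") == category and "dur" in e
    ]

    # Total time per event name, largest first.
    totals = defaultdict(float)
    for e in events:
        totals[e["name"]] += e["dur"]
    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)[:top]

    # Idle time: gaps between consecutive events, treating them as one timeline.
    gap, end = 0.0, None
    for e in sorted(events, key=lambda e: e["ts"]):
        if end is not None and e["ts"] > end:
            gap += e["ts"] - end
        finish = e["ts"] + e["dur"]
        end = finish if end is None else max(end, finish)
    return ranked, gap


# Synthetic trace used only to show the output shape.
example = {
    "traceEvents": [
        {"ph": "X", "cat": "kernel", "name": "gemm", "ts": 0, "dur": 10},
        {"ph": "X", "cat": "kernel", "name": "copy", "ts": 15, "dur": 5},
        {"ph": "X", "cat": "kernel", "name": "gemm", "ts": 20, "dur": 10},
        {"ph": "X", "cat": "cpu_op", "name": "aten::mm", "ts": 0, "dur": 100},
        {"ph": "M", "name": "process_name"},
    ]
}
ranked, gap = summarize_trace(example)
```

A large `gap` relative to the summed kernel time suggests the GPU is waiting on the CPU. Treating all kernels as a single timeline ignores multiple streams, so read the number as a rough indicator.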
|
||||||
|
|
||||||
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
|
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
|
||||||
|
|
||||||
## File Map
|
## File Map
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ format:
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
|
- [Gemma-4](#sec-gemma-4) *(NEW)*
|
||||||
- [Mllama](#sec-mllama)
|
- [Mllama](#sec-mllama)
|
||||||
- [Llama4](#sec-llama4)
|
- [Llama4](#sec-llama4)
|
||||||
- [Pixtral](#sec-pixtral)
|
- [Pixtral](#sec-pixtral)
|
||||||
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
|
|||||||
processor_type: VoxtralProcessor
|
processor_type: VoxtralProcessor
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Gemma-4 {#sec-gemma-4}
|
||||||
|
|
||||||
|
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
|
||||||
|
|
||||||
|
chat_template: gemma4
|
||||||
|
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
|
||||||
|
|
||||||
|
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
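The `lora_target_modules` regex in the config above can be checked against module names before a run. The sketch below assumes the pattern is applied with full-match semantics; the module names in the list are hypothetical examples, not names read from a real checkpoint.

```python
import re

# Pattern copied from the config above.
PATTERN = (
    r"model.language_model.layers.[\d]+."
    r"(_checkpoint_wrapped_module.)?(mlp|self_attn)."
    r"(up|down|gate|q|k|v|o)_proj"
)


def lora_targets(module_names, pattern=PATTERN):
    """Return the module names the pattern would select for LoRA."""
    return [name for name in module_names if re.fullmatch(pattern, name)]


# Hypothetical module names for illustration.
names = [
    "model.language_model.layers.0.self_attn.q_proj",
    "model.language_model.layers.12._checkpoint_wrapped_module.mlp.down_proj",
    "model.language_model.layers.0.self_attn.q_norm",
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",
]
selected = lora_targets(names)
```

Only the two language-model projection layers are selected. The norm layer and the vision-tower projection are left out, which is the point of restricting LoRA to the language model.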
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. When `activation_offloading: true` is set, however, `ddp_find_unused_parameters` is skipped because checkpoint wrappers conflict with it; use `freeze_mm_modules: true` instead to handle the unused vision/audio parameters. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
|
||||||
|
:::
|
||||||
|
|
||||||
### Gemma-3 {#sec-gemma-3}
|
### Gemma-3 {#sec-gemma-3}
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|||||||
62
examples/gemma4/e2b-vision-lora.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Gemma 4 E2B Vision LoRA
|
||||||
|
#
|
||||||
|
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
|
||||||
|
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
|
||||||
|
|
||||||
|
base_model: google/gemma-4-E2B-it
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
freeze_mm_modules: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Required for vision/multimodal training
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: gemma4
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:100]
|
||||||
|
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/gemma4-e2b-vision-lora
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
# Target language model only — vision encoder is frozen via freeze_mm_modules
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 10
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
logging_steps: 1
|
||||||
|
sdp_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
62
examples/qwen3.5/35b-a3b-moe-vision-lora.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Qwen 3.5 35B-A3B MoE Vision LoRA
|
||||||
|
#
|
||||||
|
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
|
||||||
|
# 256 experts, 8 active per token, with early-fusion vision support.
|
||||||
|
|
||||||
|
base_model: Qwen/Qwen3.5-35B-A3B
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
# Required for vision/multimodal training
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:100]
|
||||||
|
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/qwen35-35b-a3b-vision-lora
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
sequence_len: 4096
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 10
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
1518
scripts/analyze_profile.py
Normal file
File diff suppressed because it is too large
@@ -435,6 +435,23 @@ class AxolotlTrainer(
|
|||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Gemma4ForConditionalGeneration computes loss with a manual
|
||||||
|
# nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
|
||||||
|
# normalization and does redundant attention_mask filtering.
|
||||||
|
# Compute loss externally using the standard loss_function instead.
|
||||||
|
if _model_type == "gemma4" and "labels" in inputs:
|
||||||
|
labels = inputs.pop("labels")
|
||||||
|
outputs = model(**inputs)
|
||||||
|
logits = outputs.logits
|
||||||
|
unwrapped = self.accelerator.unwrap_model(model)
|
||||||
|
vocab_size = unwrapped.config.get_text_config().vocab_size
|
||||||
|
loss = unwrapped.loss_function(
|
||||||
|
logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
|
||||||
|
)
|
||||||
|
if return_outputs:
|
||||||
|
return loss, outputs
|
||||||
|
return loss
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
|
|||||||
@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
swiglu=cfg.liger_glu_activation,
|
swiglu=cfg.liger_glu_activation,
|
||||||
)
|
)
|
||||||
|
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||||
|
# Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
|
||||||
|
# gradient checkpointing compatibility, RoPE incompatible (separate q/k).
|
||||||
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
|
from transformers.models.gemma4 import modeling_gemma4
|
||||||
|
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
_OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm
|
||||||
|
|
||||||
|
class _LigerGemma4RMSNorm(LigerRMSNorm):
|
||||||
|
"""LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""
|
||||||
|
|
||||||
|
def __new__(cls, dim, eps=1e-6, with_scale=True):
|
||||||
|
if not with_scale:
|
||||||
|
return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, dim, eps=1e-6, with_scale=True):
|
||||||
|
if not with_scale:
|
||||||
|
return
|
||||||
|
# offset=0.0 (standard), in_place=False (gradient checkpointing safe)
|
||||||
|
super().__init__(
|
||||||
|
dim, eps, offset=0.0, casting_mode="llama", in_place=False
|
||||||
|
)
|
||||||
|
|
||||||
|
modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
|
||||||
|
class _LigerGemma4MLP(LigerGEGLUMLP):
|
||||||
|
def __init__(self, config, layer_idx=None):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
|
||||||
|
if cfg.liger_rope:
|
||||||
|
LOG.warning(
|
||||||
|
"Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
|
||||||
|
)
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_gemma4.nn.LayerNorm = LigerLayerNorm
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
LOG.warning(
|
||||||
|
"Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
"Applied Liger kernels for gemma4: "
|
||||||
|
f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
|
||||||
|
f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
|
||||||
|
)
|
||||||
elif cfg.liger_fused_linear_cross_entropy:
|
elif cfg.liger_fused_linear_cross_entropy:
|
||||||
try:
|
try:
|
||||||
from .models.base import patch_lce_forward
|
from .models.base import patch_lce_forward
|
||||||
|
|||||||
@@ -112,6 +112,47 @@ QKV_PATCHES = [
|
|||||||
else:
|
else:
|
||||||
key_states = key_states.view(hidden_shape)
|
key_states = key_states.view(hidden_shape)
|
||||||
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
||||||
|
""".lstrip("\n"),
|
||||||
|
),
|
||||||
|
# Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
|
||||||
|
# past_key_values.shared_layers, and v_norm added after k_norm.
|
||||||
|
(
|
||||||
|
"""
|
||||||
|
query_states = self.q_proj(hidden_states).view(hidden_shape)
|
||||||
|
query_states = self.q_norm(query_states)
|
||||||
|
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
|
||||||
|
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
|
||||||
|
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
|
||||||
|
if self.is_kv_shared_layer:
|
||||||
|
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||||
|
# Device of past layer may be different from current one
|
||||||
|
key_states = key_states.to(query_states.device)
|
||||||
|
value_states = value_states.to(query_states.device)
|
||||||
|
else:
|
||||||
|
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||||
|
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
|
||||||
|
""".lstrip("\n"),
|
||||||
|
"""
|
||||||
|
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||||
|
query_states = query_states.view(hidden_shape)
|
||||||
|
query_states = self.q_norm(query_states)
|
||||||
|
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
|
||||||
|
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
|
||||||
|
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
|
||||||
|
if self.is_kv_shared_layer:
|
||||||
|
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||||
|
# Device of past layer may be different from current one
|
||||||
|
key_states = key_states.to(query_states.device)
|
||||||
|
value_states = value_states.to(query_states.device)
|
||||||
|
else:
|
||||||
|
key_states = key_states.view(hidden_shape)
|
||||||
|
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
||||||
""".lstrip("\n"),
|
""".lstrip("\n"),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|||||||