Compare commits
38 Commits
fix/issue-
...
swe-rebenc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d17ed89a3c | ||
|
|
02e4f2350d | ||
|
|
4195605ab2 | ||
|
|
37acb28d02 | ||
|
|
4a5281e61a | ||
|
|
a892d8cce1 | ||
|
|
78de2919a6 | ||
|
|
28283ff373 | ||
|
|
dc16859983 | ||
|
|
d4e9cf2eec | ||
|
|
53391a10d7 | ||
|
|
7617b951a8 | ||
|
|
e993ed5208 | ||
|
|
69f165b39b | ||
|
|
80a97f192b | ||
|
|
323da791eb | ||
|
|
6990478163 | ||
|
|
63a58cfec1 | ||
|
|
3985ec2f67 | ||
|
|
a44edda6d7 | ||
|
|
66c3e5a3fd | ||
|
|
b8358aa5ab | ||
|
|
e079cf16a2 | ||
|
|
e2f69828d2 | ||
|
|
122b50bad6 | ||
|
|
e77a185e86 | ||
|
|
29fa4dedbb | ||
|
|
315cdeede9 | ||
|
|
e7a6a5b529 | ||
|
|
bfb4da1d25 | ||
|
|
4dfa0a59b2 | ||
|
|
4ef608dda3 | ||
|
|
7daf7d96f1 | ||
|
|
7c56809c7f | ||
|
|
149178ddb7 | ||
|
|
dc638e723f | ||
|
|
6f15da4cac | ||
|
|
900eec7988 |
10
.github/workflows/tests.yml
vendored
10
.github/workflows/tests.yml
vendored
@@ -220,6 +220,16 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Verify agent docs are discoverable
|
||||||
|
run: |
|
||||||
|
# Agent docs live in docs/agents/ (source of truth) and are resolved
|
||||||
|
# at runtime from the repo checkout or via `axolotl fetch docs`
|
||||||
|
axolotl agent-docs --list
|
||||||
|
axolotl agent-docs | grep -q "Fine-tuning framework"
|
||||||
|
axolotl agent-docs grpo | grep -q "GRPO"
|
||||||
|
axolotl agent-docs sft | grep -q "SFT"
|
||||||
|
python -c "from axolotl.cli.agent_docs import get_doc, list_topics; assert len(list_topics()) >= 5; assert 'GRPO' in get_doc('grpo')"
|
||||||
|
|
||||||
- name: Show HF cache
|
- name: Show HF cache
|
||||||
run: hf cache ls
|
run: hf cache ls
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ axolotl inference config.yaml # Interactive inference
|
|||||||
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
|
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
|
||||||
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
|
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
|
||||||
axolotl fetch examples # Download example configs
|
axolotl fetch examples # Download example configs
|
||||||
|
axolotl agent-docs # Show agent-optimized docs (bundled with pip package)
|
||||||
|
axolotl agent-docs grpo # Topic-specific agent reference
|
||||||
|
axolotl config-schema # Dump config JSON schema
|
||||||
```
|
```
|
||||||
|
|
||||||
## Training Methods
|
## Training Methods
|
||||||
@@ -35,6 +38,8 @@ Agent-specific references:
|
|||||||
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
|
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
|
||||||
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
|
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
|
||||||
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
|
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
|
||||||
|
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
|
||||||
|
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
|
||||||
|
|
||||||
## Config Pattern
|
## Config Pattern
|
||||||
|
|
||||||
|
|||||||
@@ -3,4 +3,6 @@ include README.md
|
|||||||
include LICENSE
|
include LICENSE
|
||||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
include src/axolotl/utils/chat_templates/templates/*.jinja
|
||||||
|
include AGENTS.md
|
||||||
|
recursive-include docs/agents *.md
|
||||||
recursive-include axolotl *.py
|
recursive-include axolotl *.py
|
||||||
|
|||||||
53
README.md
53
README.md
@@ -86,7 +86,7 @@ Features:
|
|||||||
**Requirements**:
|
**Requirements**:
|
||||||
|
|
||||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python 3.11
|
- Python >=3.11 (3.12 recommended)
|
||||||
- PyTorch ≥2.9.1
|
- PyTorch ≥2.9.1
|
||||||
|
|
||||||
### Google Colab
|
### Google Colab
|
||||||
@@ -95,6 +95,34 @@ Features:
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
|
#### Using uv (recommended)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# install uv if you don't already have it installed
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
source $HOME/.local/bin/env
|
||||||
|
|
||||||
|
# CUDA 12.8.1 tends to have better package compatibility
|
||||||
|
export UV_TORCH_BACKEND=cu128
|
||||||
|
|
||||||
|
# create a new virtual environment
|
||||||
|
uv venv --python 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
uv pip install torch==2.10.0 torchvision
|
||||||
|
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||||
|
|
||||||
|
# recommended - install cut-cross-entropy
|
||||||
|
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
|
||||||
|
|
||||||
|
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
|
||||||
|
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
|
||||||
|
|
||||||
|
# Download example axolotl configs, deepspeed configs
|
||||||
|
axolotl fetch examples
|
||||||
|
axolotl fetch deepspeed_configs # OPTIONAL
|
||||||
|
```
|
||||||
|
|
||||||
#### Using pip
|
#### Using pip
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -157,6 +185,29 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
|
|||||||
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
|
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
|
||||||
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
|
## AI Agent Support
|
||||||
|
|
||||||
|
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Show overview and available training methods
|
||||||
|
axolotl agent-docs
|
||||||
|
|
||||||
|
# Topic-specific references
|
||||||
|
axolotl agent-docs sft # supervised fine-tuning
|
||||||
|
axolotl agent-docs grpo # GRPO online RL
|
||||||
|
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
|
||||||
|
axolotl agent-docs reward_modelling # outcome and process reward models
|
||||||
|
axolotl agent-docs pretraining # continual pretraining
|
||||||
|
axolotl agent-docs --list # list all topics
|
||||||
|
|
||||||
|
# Dump config schema for programmatic use
|
||||||
|
axolotl config-schema
|
||||||
|
axolotl config-schema --field adapter
|
||||||
|
```
|
||||||
|
|
||||||
|
If you're working with the source repo, agent docs are also available at `docs/agents/` and the project overview is in `AGENTS.md`.
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|
||||||
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
|
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
|
||||||
|
|||||||
198
docs/agents/model_architectures.md
Normal file
198
docs/agents/model_architectures.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Model Architectures — Agent Reference
|
||||||
|
|
||||||
|
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
|
||||||
|
|
||||||
|
## VLM (Vision Language Model) Quick Start
|
||||||
|
|
||||||
|
All VLM configs require these four lines:
|
||||||
|
```yaml
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
```
|
||||||
|
|
||||||
|
Decision tree for VLM config:
|
||||||
|
```text
|
||||||
|
Is the model multimodal (has vision/audio encoder)?
|
||||||
|
├─ YES: Add `freeze_mm_modules: true` if training text only
|
||||||
|
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
|
||||||
|
│ LoRA: use regex `lora_target_modules` to restrict to language model
|
||||||
|
└─ NO: Train as a regular text model
|
||||||
|
|
||||||
|
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
|
||||||
|
├─ YES: Add `lora_target_parameters` for expert LoRA
|
||||||
|
│ Consider ScatterMoE kernels (see Plugins section)
|
||||||
|
└─ NO: Standard LoRA config
|
||||||
|
```
|
||||||
|
|
||||||
|
## Plugins & Optimizations
|
||||||
|
|
||||||
|
### Cut Cross Entropy (CCE)
|
||||||
|
|
||||||
|
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
```
|
||||||
|
|
||||||
|
### ScatterMoE Kernels
|
||||||
|
|
||||||
|
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
|
||||||
|
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
|
||||||
|
|
||||||
|
## Gemma 4
|
||||||
|
|
||||||
|
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
|
||||||
|
|
||||||
|
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
|
||||||
|
|
||||||
|
### Required settings
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Always needed for Gemma4:
|
||||||
|
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant
|
||||||
|
|
||||||
|
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Auto-detection
|
||||||
|
|
||||||
|
Axolotl auto-detects Gemma4 and applies:
|
||||||
|
- `use_reentrant: false` for gradient checkpointing
|
||||||
|
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
|
||||||
|
|
||||||
|
### Multi-GPU
|
||||||
|
|
||||||
|
| Strategy | Works? | Notes |
|
||||||
|
|----------|--------|-------|
|
||||||
|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
|
||||||
|
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
|
||||||
|
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
|
||||||
|
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
|
||||||
|
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
|
||||||
|
|
||||||
|
FSDP2 config:
|
||||||
|
```yaml
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
|
||||||
|
```
|
||||||
|
|
||||||
|
### MoE (26B-A4B)
|
||||||
|
|
||||||
|
- `enable_moe_block: true`, 256 experts, top-k routing
|
||||||
|
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
|
||||||
|
- Expert LoRA targets 3D parameter tensors:
|
||||||
|
```yaml
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
- ScatterMoE kernel acceleration:
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
```
|
||||||
|
|
||||||
|
### VLM (Vision) Training
|
||||||
|
|
||||||
|
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
freeze_mm_modules: true
|
||||||
|
chat_template: gemma4
|
||||||
|
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
```
|
||||||
|
|
||||||
|
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
|
||||||
|
|
||||||
|
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common issues
|
||||||
|
|
||||||
|
| Symptom | Cause | Fix |
|
||||||
|
|---------|-------|-----|
|
||||||
|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
|
||||||
|
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
|
||||||
|
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
|
||||||
|
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
|
||||||
|
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
|
||||||
|
|
||||||
|
### E2B/E4B dense models
|
||||||
|
|
||||||
|
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
|
||||||
|
|
||||||
|
## Gemma 3
|
||||||
|
|
||||||
|
**Models**: `google/gemma-3-*`
|
||||||
|
|
||||||
|
- `ddp_find_unused_parameters: true` needed (multimodal unused params)
|
||||||
|
- `use_reentrant: false` recommended
|
||||||
|
- Attention mask must be dropped for sample packing (handled automatically)
|
||||||
|
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
|
||||||
|
|
||||||
|
## Qwen 3.5 MoE
|
||||||
|
|
||||||
|
**Models**: `Qwen/Qwen3.5-35B-A3B`
|
||||||
|
|
||||||
|
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
|
||||||
|
- 256 experts, 8 active per token
|
||||||
|
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
|
||||||
|
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
|
||||||
|
```yaml
|
||||||
|
normalize_weight_scales:
|
||||||
|
- name_pattern: 'linear_attn\.conv1d\.weight'
|
||||||
|
threshold: 1.3
|
||||||
|
```
|
||||||
|
|
||||||
|
## General MoE Notes
|
||||||
|
|
||||||
|
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
|
||||||
|
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
|
||||||
|
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
|
||||||
181
docs/agents/new_model_support.md
Normal file
181
docs/agents/new_model_support.md
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
# New Model Support — Agent Reference
|
||||||
|
|
||||||
|
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
|
||||||
|
|
||||||
|
## Quick Validation Checklist
|
||||||
|
|
||||||
|
When testing a new model, run through these checks in order:
|
||||||
|
|
||||||
|
1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
|
||||||
|
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
|
||||||
|
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5–2.0 for SFT
|
||||||
|
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
|
||||||
|
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
|
||||||
|
|
||||||
|
## Loss Debugging
|
||||||
|
|
||||||
|
### Expected initial loss
|
||||||
|
A pretrained model doing SFT should start with loss roughly in the 0.5–2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
|
||||||
|
|
||||||
|
### Direct comparison technique
|
||||||
|
The fastest way to isolate a loss issue — bypass the trainer entirely:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Load model via axolotl's pipeline (applies all patches)
|
||||||
|
from axolotl.cli.config import load_cfg
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
|
from axolotl.loaders.model import ModelLoader
|
||||||
|
|
||||||
|
cfg = load_cfg("your_config.yaml")
|
||||||
|
normalize_config(cfg)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
model, _ = ModelLoader(cfg, tokenizer).load()
|
||||||
|
|
||||||
|
# Forward pass on preprocessed data
|
||||||
|
model.train()
|
||||||
|
out = model(input_ids, labels=labels)
|
||||||
|
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
|
||||||
|
```
|
||||||
|
|
||||||
|
If direct loss is correct (~1.0) but trainer reports 3–4x higher, check `model_accepts_loss_kwargs` (see below).
|
||||||
|
|
||||||
|
### `model_accepts_loss_kwargs` inflation
|
||||||
|
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
|
||||||
|
|
||||||
|
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
|
||||||
|
|
||||||
|
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
|
||||||
|
|
||||||
|
**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
|
||||||
|
|
||||||
|
## Multimodal Models (ForConditionalGeneration)
|
||||||
|
|
||||||
|
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
|
||||||
|
- Gemma3 → `Gemma3ForConditionalGeneration`
|
||||||
|
- Gemma4 → `Gemma4ForConditionalGeneration`
|
||||||
|
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
|
||||||
|
- LLaVA → `LlavaForConditionalGeneration`
|
||||||
|
|
||||||
|
### Why this matters
|
||||||
|
|
||||||
|
| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|
||||||
|
|-----------|----------------------|--------------------------------|
|
||||||
|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
|
||||||
|
| PEFT LoRA | ✅ | May fail on custom layer types |
|
||||||
|
| HF Trainer label handling | ✅ | May need extra inputs |
|
||||||
|
|
||||||
|
### Required extra inputs
|
||||||
|
Multimodal models require special inputs during training even for text-only data:
|
||||||
|
|
||||||
|
| Model | Required Input | Value for Text-Only |
|
||||||
|
|-------|---------------|-------------------|
|
||||||
|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
|
||||||
|
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |
|
||||||
|
|
||||||
|
Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
|
||||||
|
|
||||||
|
### Custom layer types and PEFT
|
||||||
|
Vision towers often use custom module wrappers that PEFT doesn't support:
|
||||||
|
|
||||||
|
| Model | Custom Layer | Wraps | Fix |
|
||||||
|
|-------|-------------|-------|-----|
|
||||||
|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |
|
||||||
|
|
||||||
|
Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.
|
||||||
|
|
||||||
|
## Sample Packing
|
||||||
|
|
||||||
|
### How packed sequence detection works (transformers ≥ 5.x)
|
||||||
|
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# From masking_utils.py:
|
||||||
|
if position_ids is not None and attention_mask is None and past_key_values is None:
|
||||||
|
packed_sequence_mask = find_packed_sequence_indices(position_ids)
|
||||||
|
```
|
||||||
|
|
||||||
|
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
|
||||||
|
|
||||||
|
### Fix for models using `create_causal_mask_mapping`
|
||||||
|
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In compute_loss():
|
||||||
|
if (
|
||||||
|
self.args.sample_packing
|
||||||
|
and model_type in ("gemma4", "gemma3")
|
||||||
|
and "attention_mask" in inputs
|
||||||
|
and "position_ids" in inputs
|
||||||
|
):
|
||||||
|
del inputs["attention_mask"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.
|
||||||
|
|
||||||
|
### Models that DON'T need this fix
|
||||||
|
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
|
||||||
|
|
||||||
|
## Attention Backend Selection
|
||||||
|
|
||||||
|
| Backend | Config | head_dim limit | torch_compile | Notes |
|
||||||
|
|---------|--------|---------------|---------------|-------|
|
||||||
|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
|
||||||
|
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||||
|
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
|
||||||
|
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||||
|
| eager | neither set | None | ✅ | Slowest, always works |
|
||||||
|
|
||||||
|
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
||||||
|
|
||||||
|
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
|
||||||
|
|
||||||
|
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
|
||||||
|
|
||||||
|
## Cut Cross Entropy (CCE)
|
||||||
|
|
||||||
|
### How CCE patches work
|
||||||
|
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
|
||||||
|
|
||||||
|
### Adding CCE for a new model
|
||||||
|
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
|
||||||
|
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
|
||||||
|
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
|
||||||
|
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
|
||||||
|
|
||||||
|
### Common CCE pitfall
|
||||||
|
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
|
||||||
|
|
||||||
|
## MoE Models
|
||||||
|
|
||||||
|
### Dense MLP vs MoE experts
|
||||||
|
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
|
||||||
|
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
|
||||||
|
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
|
||||||
|
|
||||||
|
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
|
||||||
|
|
||||||
|
### ScatterMoE kernels
|
||||||
|
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
```
|
||||||
|
|
||||||
|
## Where to Add Model-Specific Fixes
|
||||||
|
|
||||||
|
| What | Where | Example |
|
||||||
|
|------|-------|---------|
|
||||||
|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
|
||||||
|
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
|
||||||
|
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
|
||||||
|
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
|
||||||
|
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
|
||||||
|
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
|
||||||
|
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
|
||||||
|
| Example configs | `examples/<model>/` | Validated YAML |
|
||||||
|
| Config validation | `utils/schemas/validation.py` | Compatibility checks |
|
||||||
@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
|||||||
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
|
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
|
||||||
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
||||||
|
|
||||||
|
## Profiling
|
||||||
|
|
||||||
|
To profile training and identify optimization opportunities:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Profile steps 3-7 (after warmup/autotuning settles)
|
||||||
|
profiler_steps_start: 3
|
||||||
|
profiler_steps: 5
|
||||||
|
```
|
||||||
|
|
||||||
|
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
|
||||||
|
View the Chrome trace at `chrome://tracing`.
|
||||||
|
|
||||||
|
To programmatically inspect the trace:
|
||||||
|
```bash
|
||||||
|
python scripts/analyze_profile.py output_dir/
|
||||||
|
```
|
||||||
|
|
||||||
|
The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
|
||||||
|
- **Large matmul kernels**: candidates for fusion or quantization
|
||||||
|
- **Memory copies (H2D/D2H)**: unnecessary data movement
|
||||||
|
- **Small frequent kernels**: candidates for kernel fusion
|
||||||
|
- **Gaps between kernels**: pipeline bubbles from CPU overhead
|
||||||
|
|
||||||
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
|
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
|
||||||
|
|
||||||
## File Map
|
## File Map
|
||||||
|
|||||||
@@ -108,6 +108,14 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
chat_template_jinja: ./path/to/my_template.jinja2
|
||||||
|
```
|
||||||
|
:::
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||||
:::
|
:::
|
||||||
@@ -294,6 +302,113 @@ datasets:
|
|||||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
#### Content parts with per-part training control
|
||||||
|
|
||||||
|
Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).
|
||||||
|
|
||||||
|
```{.json filename="data.jsonl"}
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Let me think step by step...", "train": false},
|
||||||
|
{"type": "text", "text": " The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The configuration is the same as standard `chat_template` — no extra fields needed:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
roles_to_train: ["assistant"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Each content part supports:
|
||||||
|
|
||||||
|
- `type`: `"text"` (required)
|
||||||
|
- `text`: the text value (also accepts `content` or `value` as the key)
|
||||||
|
- `train`: `true`/`false` (optional) — whether to train on this part
|
||||||
|
- `weight`: `0`/`1` (optional) — alternative to `train`
|
||||||
|
|
||||||
|
If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
|
||||||
|
|
||||||
|
::: {.callout-warning title="Whitespace at part boundaries"}
|
||||||
|
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:
|
||||||
|
|
||||||
|
**Split BEFORE spaces** (space goes with the next part):
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "Let me think...", "train": false},
|
||||||
|
{"type": "text", "text": " The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "Let me think... ", "train": false},
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
|
||||||
|
|
||||||
|
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "Thinking:\n", "train": false},
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
|
||||||
|
:::
|
||||||
|
|
||||||
|
##### Per-part training on reasoning_content
|
||||||
|
|
||||||
|
For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:
|
||||||
|
|
||||||
|
```{.json filename="data.jsonl"}
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_content": [
|
||||||
|
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
|
||||||
|
{"type": "text", "text": " Wait no, 2+2=4.", "train": true}
|
||||||
|
],
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
|
||||||
|
:::
|
||||||
|
|
||||||
|
The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
|
||||||
|
|
||||||
|
|
||||||
#### Reasoning split
|
#### Reasoning split
|
||||||
|
|
||||||
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ format:
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
|
- [Gemma-4](#sec-gemma-4) *(NEW)*
|
||||||
- [Mllama](#sec-mllama)
|
- [Mllama](#sec-mllama)
|
||||||
- [Llama4](#sec-llama4)
|
- [Llama4](#sec-llama4)
|
||||||
- [Pixtral](#sec-pixtral)
|
- [Pixtral](#sec-pixtral)
|
||||||
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
|
|||||||
processor_type: VoxtralProcessor
|
processor_type: VoxtralProcessor
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Gemma-4 {#sec-gemma-4}
|
||||||
|
|
||||||
|
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
|
||||||
|
|
||||||
|
chat_template: gemma4
|
||||||
|
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
|
||||||
|
|
||||||
|
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
|
||||||
|
:::
|
||||||
|
|
||||||
### Gemma-3 {#sec-gemma-3}
|
### Gemma-3 {#sec-gemma-3}
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ output_dir: ./outputs/out
|
|||||||
|
|
||||||
# Freeze vision tower
|
# Freeze vision tower
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
- ^model\.language_model\..*
|
- ^model.language_model.*
|
||||||
- ^lm_head\..*
|
- ^lm_head.*
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ output_dir: ./outputs/out
|
|||||||
|
|
||||||
# Freeze vision tower
|
# Freeze vision tower
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
- ^model\.language_model\..*
|
- ^model.language_model.*
|
||||||
- ^lm_head\..*
|
- ^lm_head.*
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ output_dir: ./outputs/out
|
|||||||
|
|
||||||
# Freeze vision tower
|
# Freeze vision tower
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
- ^model\.language_model\..*
|
- ^model.language_model.*
|
||||||
- ^lm_head\..*
|
- ^lm_head.*
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -1,19 +1,12 @@
|
|||||||
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
|
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
|
||||||
#
|
#
|
||||||
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
|
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
|
||||||
|
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
|
||||||
#
|
#
|
||||||
# Key notes:
|
# Key notes:
|
||||||
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
|
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
|
||||||
# Use sdp_attention instead.
|
# 4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix.
|
||||||
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
|
# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.
|
||||||
# LoRA to the text backbone via lora_target_linear_modules regex.
|
|
||||||
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
|
|
||||||
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
|
|
||||||
# via the transformers ExpertsInterface.
|
|
||||||
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
|
|
||||||
# (no `mlp.` prefix, unlike Qwen/Mixtral).
|
|
||||||
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
|
|
||||||
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
|
|
||||||
|
|
||||||
base_model: google/gemma-4-26B-A4B
|
base_model: google/gemma-4-26B-A4B
|
||||||
|
|
||||||
@@ -24,7 +17,7 @@ plugins:
|
|||||||
use_kernels: true
|
use_kernels: true
|
||||||
use_scattermoe: true
|
use_scattermoe: true
|
||||||
experts_implementation: scattermoe
|
experts_implementation: scattermoe
|
||||||
torch_compile: false
|
torch_compile: true
|
||||||
liger_layer_norm: true
|
liger_layer_norm: true
|
||||||
liger_rope: true
|
liger_rope: true
|
||||||
liger_rms_norm: true
|
liger_rms_norm: true
|
||||||
@@ -54,12 +47,9 @@ lora_r: 16
|
|||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
lora_dropout: 0
|
lora_dropout: 0
|
||||||
|
|
||||||
# Restrict LoRA to text backbone only (skip vision/audio encoders).
|
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||||
# lora_target_modules is intentionally empty — all module targeting is done
|
# using regex to match only the text decoder attention projections.
|
||||||
# via regex in lora_target_linear_modules below.
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
lora_target_modules: []
|
|
||||||
lora_target_linear_modules:
|
|
||||||
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
|
|
||||||
|
|
||||||
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
|
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
|
||||||
lora_target_parameters:
|
lora_target_parameters:
|
||||||
@@ -73,7 +63,7 @@ lora_o_kernel: false
|
|||||||
bnb_config_kwargs:
|
bnb_config_kwargs:
|
||||||
bnb_4bit_use_double_quant: true
|
bnb_4bit_use_double_quant: true
|
||||||
|
|
||||||
wandb_project: gemma4-qlora
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
@@ -93,8 +83,7 @@ gradient_checkpointing: true
|
|||||||
activation_offloading: true
|
activation_offloading: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|
||||||
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
|
# FA2 not supported
|
||||||
flash_attention: false
|
|
||||||
sdp_attention: true
|
sdp_attention: true
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
71
examples/gemma4/31b-qlora-flex.yaml
Normal file
71
examples/gemma4/31b-qlora-flex.yaml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: google/gemma-4-31B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
torch_compile: true
|
||||||
|
liger_layer_norm: true
|
||||||
|
liger_rope: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_rms_norm_gated: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: gemma4
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:10%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/gemma4-31b-qlora-flex
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
|
||||||
|
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
bnb_config_kwargs:
|
||||||
|
bnb_4bit_use_double_quant: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
activation_offloading: true
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
# FA not supported
|
||||||
|
flex_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
69
examples/gemma4/31b-qlora.yaml
Normal file
69
examples/gemma4/31b-qlora.yaml
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
base_model: google/gemma-4-31B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
torch_compile: false
|
||||||
|
liger_layer_norm: true
|
||||||
|
liger_rope: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_rms_norm_gated: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: gemma4
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:10%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/gemma4-31b-qlora
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
|
||||||
|
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||||
|
# using regex to match only the text decoder attention projections.
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
bnb_config_kwargs:
|
||||||
|
bnb_4bit_use_double_quant: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
activation_offloading: true
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
# FA not supported
|
||||||
|
sdp_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
60
examples/gemma4/README.md
Normal file
60
examples/gemma4/README.md
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# Finetune Google's Gemma 4 with Axolotl
|
||||||
|
|
||||||
|
[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
|
||||||
|
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
|
||||||
|
|
||||||
|
# 31B Dense QLoRA (1x80GB @ ~44 GiB)
|
||||||
|
axolotl train examples/gemma4/31b-qlora.yaml
|
||||||
|
|
||||||
|
# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
|
||||||
|
axolotl train examples/gemma4/31b-qlora-flex.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### MoE Expert Quantization & Expert LoRA (26B-A4B only)
|
||||||
|
|
||||||
|
The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
|
||||||
|
|
||||||
|
## Flex Attention
|
||||||
|
|
||||||
|
Reduce ~40% VRAM (at the cost of up to half throughput) by setting the below (shown in `examples/gemma4/31b-qlora-flex.yaml`):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
torch_compile: true
|
||||||
|
flex_attention: true
|
||||||
|
```
|
||||||
|
|
||||||
|
This works for both the MoE and Dense model.
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
|
||||||
|
- **LoRA kernels**: Not supported due to KV-sharing layers.
|
||||||
|
- **lora_target_linear**: Incompatible for multimodal models — use `lora_target_modules` with a regex to restrict LoRA to the text backbone.
|
||||||
|
|
||||||
|
### TIPS
|
||||||
|
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy and has not been tested.
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
62
examples/gemma4/e2b-vision-lora.yaml
Normal file
62
examples/gemma4/e2b-vision-lora.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Gemma 4 E2B Vision LoRA
|
||||||
|
#
|
||||||
|
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
|
||||||
|
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
|
||||||
|
|
||||||
|
base_model: google/gemma-4-E2B-it
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
freeze_mm_modules: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Required for vision/multimodal training
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: gemma4
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:100]
|
||||||
|
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/gemma4-e2b-vision-lora
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
# Target language model only — vision encoder is frozen via freeze_mm_modules
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 10
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
logging_steps: 1
|
||||||
|
sdp_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
@@ -1,5 +1,15 @@
|
|||||||
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
|
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
|
||||||
|
liger_layer_norm: true
|
||||||
|
liger_rope: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_rms_norm_gated: true
|
||||||
|
|
||||||
# LoRA kernel patches are incompatible with this architecture — see README.
|
# LoRA kernel patches are incompatible with this architecture — see README.
|
||||||
lora_mlp_kernel: false
|
lora_mlp_kernel: false
|
||||||
lora_qkv_kernel: false
|
lora_qkv_kernel: false
|
||||||
@@ -22,8 +32,6 @@ dataset_prepared_path: last_run_prepared
|
|||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
use_cut_cross_entropy: true
|
|
||||||
|
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
quantize_moe_experts: true
|
quantize_moe_experts: true
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -31,16 +39,16 @@ lora_r: 16
|
|||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
lora_dropout: 0.0
|
lora_dropout: 0.0
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
# Attention projection layers (present in ~12 attention layers out of 88)
|
|
||||||
- q_proj
|
- q_proj
|
||||||
- k_proj
|
- k_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
# To also train MoE expert weights, add them via lora_target_parameters
|
|
||||||
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
# To also train MoE expert weights, add them via lora_target_parameters
|
||||||
# lora_target_parameters:
|
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
||||||
# - up_proj
|
# lora_target_parameters:
|
||||||
# - down_proj
|
# - up_proj
|
||||||
|
# - down_proj
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -1,6 +1,16 @@
|
|||||||
# See examples/nemotron-h/README.md for architecture notes and requirements.
|
# See examples/nemotron-h/README.md for architecture notes and requirements.
|
||||||
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
|
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
|
||||||
|
liger_layer_norm: true
|
||||||
|
liger_rope: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_rms_norm_gated: true
|
||||||
|
|
||||||
# LoRA kernel patches are incompatible with this architecture — see README.
|
# LoRA kernel patches are incompatible with this architecture — see README.
|
||||||
lora_mlp_kernel: false
|
lora_mlp_kernel: false
|
||||||
lora_qkv_kernel: false
|
lora_qkv_kernel: false
|
||||||
@@ -23,8 +33,6 @@ dataset_prepared_path: last_run_prepared
|
|||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
use_cut_cross_entropy: true
|
|
||||||
|
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
quantize_moe_experts: true
|
quantize_moe_experts: true
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -36,11 +44,12 @@ lora_target_modules:
|
|||||||
- k_proj
|
- k_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
# To also train MoE expert weights, add them via lora_target_parameters
|
|
||||||
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
# To also train MoE expert weights, add them via lora_target_parameters
|
||||||
# lora_target_parameters:
|
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
||||||
# - up_proj
|
# lora_target_parameters:
|
||||||
# - down_proj
|
# - up_proj
|
||||||
|
# - down_proj
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ sample_packing: true
|
|||||||
|
|
||||||
# Freeze vision encoder
|
# Freeze vision encoder
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
- model\.language_model\..*
|
- model.language_model.*
|
||||||
- lm_head\..*
|
- lm_head.*
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
62
examples/qwen3.5/35b-a3b-moe-vision-lora.yaml
Normal file
62
examples/qwen3.5/35b-a3b-moe-vision-lora.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Qwen 3.5 35B-A3B MoE Vision LoRA
|
||||||
|
#
|
||||||
|
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
|
||||||
|
# 256 experts, 8 active per token, with early-fusion vision support.
|
||||||
|
|
||||||
|
base_model: Qwen/Qwen3.5-35B-A3B
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
# Required for vision/multimodal training
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:100]
|
||||||
|
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/qwen35-35b-a3b-vision-lora
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
sequence_len: 4096
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 10
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
@@ -10,15 +10,15 @@ liger-kernel==0.7.0
|
|||||||
|
|
||||||
packaging==26.0
|
packaging==26.0
|
||||||
huggingface_hub>=1.1.7
|
huggingface_hub>=1.1.7
|
||||||
peft>=0.18.1
|
peft>=0.19.0,<0.20.0
|
||||||
tokenizers>=0.22.1
|
tokenizers>=0.22.1
|
||||||
transformers==5.5.0
|
transformers==5.5.4
|
||||||
accelerate==1.13.0
|
accelerate==1.13.0
|
||||||
datasets==4.5.0
|
datasets>=4.8.4,<4.9.0
|
||||||
deepspeed>=0.18.6,<0.19.0
|
deepspeed>=0.18.6,<0.19.0
|
||||||
trl==0.29.0
|
trl==1.1.0
|
||||||
hf_xet==1.3.2
|
hf_xet==1.4.3
|
||||||
kernels==0.12.2
|
kernels==0.13.0
|
||||||
|
|
||||||
fla-core==0.4.1
|
fla-core==0.4.1
|
||||||
flash-linear-attention==0.4.1
|
flash-linear-attention==0.4.1
|
||||||
|
|||||||
1518
scripts/analyze_profile.py
Normal file
1518
scripts/analyze_profile.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"'
|
||||||
)
|
)
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -89,7 +89,7 @@ def parse_requirements(extras_require_map):
|
|||||||
]
|
]
|
||||||
if not install_xformers:
|
if not install_xformers:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
extras_require_map["vllm"] = ["vllm>=0.17.1"]
|
extras_require_map["vllm"] = ["vllm>=0.19.0"]
|
||||||
elif (major, minor) >= (2, 9):
|
elif (major, minor) >= (2, 9):
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
extras_require_map["fbgemm-gpu"] = [
|
||||||
|
|||||||
108
src/axolotl/cli/agent_docs/__init__.py
Normal file
108
src/axolotl/cli/agent_docs/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Bundled agent documentation for axolotl.
|
||||||
|
|
||||||
|
These docs are optimized for consumption by AI coding agents.
|
||||||
|
The source of truth is docs/agents/*.md and AGENTS.md in the repo root.
|
||||||
|
This module resolves those paths at runtime — no files are duplicated
|
||||||
|
into the package.
|
||||||
|
|
||||||
|
For pip-only installs (no repo checkout), run `axolotl fetch docs` first
|
||||||
|
to download the docs locally.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Topic name -> (filename in docs/agents/, fallback filename for AGENTS.md)
|
||||||
|
TOPICS = {
|
||||||
|
"overview": "AGENTS.md",
|
||||||
|
"sft": "docs/agents/sft.md",
|
||||||
|
"grpo": "docs/agents/grpo.md",
|
||||||
|
"preference_tuning": "docs/agents/preference_tuning.md",
|
||||||
|
"reward_modelling": "docs/agents/reward_modelling.md",
|
||||||
|
"pretraining": "docs/agents/pretraining.md",
|
||||||
|
"model_architectures": "docs/agents/model_architectures.md",
|
||||||
|
"new_model_support": "docs/agents/new_model_support.md",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path | None:
|
||||||
|
"""Walk up from this file to find the repo root (contains AGENTS.md)."""
|
||||||
|
# In an editable install or repo checkout, walk up from
|
||||||
|
# src/axolotl/cli/agent_docs/ to find the repo root
|
||||||
|
current = Path(__file__).resolve().parent
|
||||||
|
while current != current.parent:
|
||||||
|
if (current / "AGENTS.md").exists() and (current / "docs" / "agents").is_dir():
|
||||||
|
return current
|
||||||
|
current = current.parent
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_docs_dir() -> Path | None:
|
||||||
|
"""Find a fetched docs directory (from `axolotl fetch docs`)."""
|
||||||
|
# axolotl fetch docs --dest defaults to ./docs/ in cwd
|
||||||
|
cwd_docs = Path.cwd() / "docs" / "agents"
|
||||||
|
if cwd_docs.is_dir():
|
||||||
|
return Path.cwd()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(topic: str) -> Path:
|
||||||
|
"""Resolve a topic name to the actual file path."""
|
||||||
|
if topic not in TOPICS:
|
||||||
|
available = ", ".join(sorted(TOPICS.keys()))
|
||||||
|
raise FileNotFoundError(f"Unknown topic: {topic!r}. Available: {available}")
|
||||||
|
|
||||||
|
relative_path = TOPICS[topic]
|
||||||
|
|
||||||
|
# Try repo root first (editable install / repo checkout)
|
||||||
|
repo_root = _find_repo_root()
|
||||||
|
if repo_root:
|
||||||
|
candidate = repo_root / relative_path
|
||||||
|
if candidate.exists():
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# Try cwd (fetched docs via `axolotl fetch docs`)
|
||||||
|
docs_root = _find_docs_dir()
|
||||||
|
if docs_root:
|
||||||
|
candidate = docs_root / relative_path
|
||||||
|
if candidate.exists():
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# Also check cwd directly for AGENTS.md
|
||||||
|
if topic == "overview":
|
||||||
|
cwd_agents = Path.cwd() / "AGENTS.md"
|
||||||
|
if cwd_agents.exists():
|
||||||
|
return cwd_agents
|
||||||
|
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not find {relative_path!r}. "
|
||||||
|
f"If you installed axolotl via pip, run `axolotl fetch docs` first "
|
||||||
|
f"to download the documentation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_doc(topic: str = "overview") -> str:
|
||||||
|
"""Return the content of an agent doc by topic name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: One of the keys in TOPICS, or "overview" (default).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The markdown content of the doc.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the topic can't be found.
|
||||||
|
"""
|
||||||
|
return _resolve_path(topic).read_text()
|
||||||
|
|
||||||
|
|
||||||
|
def list_topics() -> dict[str, str]:
|
||||||
|
"""Return a dict of topic name -> first line (title) of each doc."""
|
||||||
|
result = {}
|
||||||
|
for topic in sorted(TOPICS.keys()):
|
||||||
|
try:
|
||||||
|
path = _resolve_path(topic)
|
||||||
|
first_line = path.read_text().split("\n", 1)[0].lstrip("# ").strip()
|
||||||
|
result[topic] = first_line
|
||||||
|
except FileNotFoundError:
|
||||||
|
result[topic] = "(not found — run `axolotl fetch docs`)"
|
||||||
|
return result
|
||||||
@@ -294,7 +294,9 @@ def merge_lora(config: str, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
@click.argument(
|
||||||
|
"directory", type=click.Choice(["examples", "deepspeed_configs", "docs"])
|
||||||
|
)
|
||||||
@click.option("--dest", help="Destination directory")
|
@click.option("--dest", help="Destination directory")
|
||||||
def fetch(directory: str, dest: Optional[str]):
|
def fetch(directory: str, dest: Optional[str]):
|
||||||
"""
|
"""
|
||||||
@@ -303,9 +305,10 @@ def fetch(directory: str, dest: Optional[str]):
|
|||||||
Available directories:
|
Available directories:
|
||||||
- examples: Example configuration files
|
- examples: Example configuration files
|
||||||
- deepspeed_configs: DeepSpeed configuration files
|
- deepspeed_configs: DeepSpeed configuration files
|
||||||
|
- docs: Full documentation (Quarto markdown files)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
directory: One of `examples`, `deepspeed_configs`.
|
directory: One of `examples`, `deepspeed_configs`, `docs`.
|
||||||
dest: Optional destination directory.
|
dest: Optional destination directory.
|
||||||
"""
|
"""
|
||||||
fetch_from_github(f"{directory}/", dest)
|
fetch_from_github(f"{directory}/", dest)
|
||||||
@@ -340,6 +343,112 @@ def delinearize_llama4(model: str, output: str):
|
|||||||
do_delinearize_llama4(model, output)
|
do_delinearize_llama4(model, output)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command("agent-docs")
|
||||||
|
@click.argument("topic", required=False, default=None)
|
||||||
|
@click.option("--list", "list_topics", is_flag=True, help="List available topics")
|
||||||
|
def agent_docs(topic: Optional[str], list_topics: bool):
|
||||||
|
"""Show agent-optimized documentation.
|
||||||
|
|
||||||
|
Prints reference docs designed for AI coding agents.
|
||||||
|
These docs are bundled with the package — no network access needed.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Examples:
|
||||||
|
axolotl agent-docs # overview (start here)
|
||||||
|
axolotl agent-docs grpo # GRPO reference
|
||||||
|
axolotl agent-docs sft # SFT reference
|
||||||
|
axolotl agent-docs --list # list all topics
|
||||||
|
"""
|
||||||
|
from axolotl.cli.agent_docs import get_doc, list_topics as _list_topics
|
||||||
|
|
||||||
|
if list_topics:
|
||||||
|
for name, title in _list_topics().items():
|
||||||
|
click.echo(f" {name:25s} {title}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if topic is None:
|
||||||
|
topic = "overview"
|
||||||
|
|
||||||
|
try:
|
||||||
|
click.echo(get_doc(topic))
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
raise click.BadParameter(str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command("config-schema")
|
||||||
|
@click.option(
|
||||||
|
"--format",
|
||||||
|
"output_format",
|
||||||
|
type=click.Choice(["json", "yaml"]),
|
||||||
|
default="json",
|
||||||
|
help="Output format (default: json)",
|
||||||
|
)
|
||||||
|
@click.option("--field", help="Show schema for a specific field only")
|
||||||
|
def config_schema(output_format: str, field: Optional[str]):
|
||||||
|
"""Dump the full config JSON schema.
|
||||||
|
|
||||||
|
Useful for AI agents and tooling to discover all available config options,
|
||||||
|
their types, defaults, and descriptions.
|
||||||
|
|
||||||
|
\b
|
||||||
|
Examples:
|
||||||
|
axolotl config-schema # full JSON schema
|
||||||
|
axolotl config-schema --format yaml # YAML format
|
||||||
|
axolotl config-schema --field adapter # single field
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema = AxolotlInputConfig.model_json_schema()
|
||||||
|
except (TypeError, ValueError, AttributeError) as exc:
|
||||||
|
# Fallback: dump field names, types, and defaults when full schema
|
||||||
|
# generation fails (e.g. torch.dtype not JSON-serializable)
|
||||||
|
LOG.warning(
|
||||||
|
"Full JSON schema generation failed, using simplified fallback: %s", exc
|
||||||
|
)
|
||||||
|
fields = {}
|
||||||
|
for name, field_info in AxolotlInputConfig.model_fields.items():
|
||||||
|
entry = {}
|
||||||
|
if field_info.description:
|
||||||
|
entry["description"] = field_info.description
|
||||||
|
if field_info.default is not None:
|
||||||
|
try:
|
||||||
|
json.dumps(field_info.default)
|
||||||
|
entry["default"] = field_info.default
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
entry["default"] = str(field_info.default)
|
||||||
|
annotation = field_info.annotation
|
||||||
|
if annotation is not None:
|
||||||
|
entry["type"] = str(annotation)
|
||||||
|
fields[name] = entry
|
||||||
|
schema = {
|
||||||
|
"properties": fields,
|
||||||
|
"_note": "simplified schema (full generation failed)",
|
||||||
|
}
|
||||||
|
|
||||||
|
if field:
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
if field not in props:
|
||||||
|
# Try case-insensitive match
|
||||||
|
matches = [k for k in props if k.lower() == field.lower()]
|
||||||
|
if matches:
|
||||||
|
field = matches[0]
|
||||||
|
else:
|
||||||
|
raise click.BadParameter(
|
||||||
|
f"Unknown field: {field!r}. "
|
||||||
|
f"Omit --field to dump the full schema, "
|
||||||
|
f"or pipe to jq: axolotl config-schema | jq '.properties | keys'"
|
||||||
|
)
|
||||||
|
schema = {field: props[field]}
|
||||||
|
|
||||||
|
if output_format == "yaml":
|
||||||
|
import yaml # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
click.echo(yaml.dump(schema, default_flow_style=False, sort_keys=False))
|
||||||
|
else:
|
||||||
|
click.echo(json.dumps(schema, indent=2))
|
||||||
|
|
||||||
|
|
||||||
cli.add_command(lm_eval)
|
cli.add_command(lm_eval)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
|
|||||||
simulate_nf4_experts=simulate_nf4_experts,
|
simulate_nf4_experts=simulate_nf4_experts,
|
||||||
nf4_blocksize=nf4_blocksize,
|
nf4_blocksize=nf4_blocksize,
|
||||||
nf4_double_quant=nf4_double_quant,
|
nf4_double_quant=nf4_double_quant,
|
||||||
|
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.debug("Memory-efficient LoRA merge completed successfully!")
|
LOG.debug("Memory-efficient LoRA merge completed successfully!")
|
||||||
|
|||||||
@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_layer_type_map(
|
||||||
|
base_model_path: Path, trust_remote_code: bool = False
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Build a map of module_name -> layer_type using a meta-device model.
|
||||||
|
|
||||||
|
Instantiates the model architecture on the meta device (zero memory)
|
||||||
|
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
|
||||||
|
This avoids relying on weight tensor ndim heuristics.
|
||||||
|
"""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
config_path = base_model_path / "config.json"
|
||||||
|
if not config_path.exists():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(config_path) as f:
|
||||||
|
model_config = _json.load(f)
|
||||||
|
except (OSError, _json.JSONDecodeError):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
architectures = model_config.get("architectures", [])
|
||||||
|
if not architectures:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
str(base_model_path), trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.debug("Could not load config for layer type introspection")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Determine the right Auto class from architectures
|
||||||
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
auto_classes = [AutoModelForCausalLM, AutoModel]
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForImageTextToText
|
||||||
|
|
||||||
|
auto_classes.insert(0, AutoModelForImageTextToText)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
model = None
|
||||||
|
for auto_cls in auto_classes:
|
||||||
|
try:
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = auto_cls.from_config(
|
||||||
|
config, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOG.debug(
|
||||||
|
"Could not instantiate meta model with %s, trying next",
|
||||||
|
auto_cls.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
LOG.debug("Could not instantiate meta model for layer type introspection")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
layer_types = {}
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, nn.Conv3d):
|
||||||
|
layer_types[name] = "Conv3d"
|
||||||
|
elif isinstance(module, nn.Conv2d):
|
||||||
|
layer_types[name] = "Conv2d"
|
||||||
|
elif isinstance(module, nn.Conv1d):
|
||||||
|
layer_types[name] = "Conv1d"
|
||||||
|
elif isinstance(module, nn.Linear):
|
||||||
|
layer_types[name] = "Linear"
|
||||||
|
|
||||||
|
del model
|
||||||
|
LOG.debug(
|
||||||
|
f"Layer type map: {len(layer_types)} modules "
|
||||||
|
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
|
||||||
|
)
|
||||||
|
return layer_types
|
||||||
|
|
||||||
|
|
||||||
def _simulate_nf4_roundtrip(
|
def _simulate_nf4_roundtrip(
|
||||||
tensor: torch.Tensor,
|
tensor: torch.Tensor,
|
||||||
blocksize: Optional[int] = None,
|
blocksize: Optional[int] = None,
|
||||||
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
|
|||||||
adapter_name: str = "default",
|
adapter_name: str = "default",
|
||||||
is_param_wrapper: bool = False,
|
is_param_wrapper: bool = False,
|
||||||
magnitude: Optional[torch.Tensor] = None,
|
magnitude: Optional[torch.Tensor] = None,
|
||||||
|
layer_type: Optional[str] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Use PEFT's own layer classes to compute the LoRA delta weight.
|
Use PEFT's own layer classes to compute the LoRA delta weight.
|
||||||
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
|
|||||||
out_features = lora_b.shape[0]
|
out_features = lora_b.shape[0]
|
||||||
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
|
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
|
||||||
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
||||||
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
|
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||||
|
|
||||||
if is_param_wrapper:
|
if is_param_wrapper:
|
||||||
from peft.tuners.lora.layer import ParamWrapper
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
@@ -227,18 +315,106 @@ def _build_peft_layer_and_get_delta(
|
|||||||
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
|
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ParamWrapper rejects dropout/fan_in_fan_out/lora_bias/use_dora, so
|
||||||
|
# build a minimal config with only the fields it accepts.
|
||||||
|
pw_config = LoraConfig(
|
||||||
|
r=r,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
lora_dropout=0.0,
|
||||||
|
fan_in_fan_out=False,
|
||||||
|
use_rslora=use_rslora,
|
||||||
|
use_dora=False,
|
||||||
|
lora_bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore", UserWarning)
|
warnings.simplefilter("ignore", UserWarning)
|
||||||
layer = ParamWrapper(
|
layer = ParamWrapper(
|
||||||
fake,
|
fake,
|
||||||
adapter_name=adapter_name,
|
adapter_name=adapter_name,
|
||||||
parameter_name="weight",
|
parameter_name="weight",
|
||||||
|
config=pw_config,
|
||||||
r=r,
|
r=r,
|
||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
use_rslora=use_rslora,
|
|
||||||
)
|
)
|
||||||
layer.lora_A[adapter_name].weight.data = lora_a
|
layer.lora_A[adapter_name].weight.data = lora_a
|
||||||
layer.lora_B[adapter_name].weight.data = lora_b
|
layer.lora_B[adapter_name].weight.data = lora_b
|
||||||
|
return layer.get_delta_weight(adapter_name)
|
||||||
|
elif (
|
||||||
|
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
|
||||||
|
):
|
||||||
|
# Conv layer detected via model introspection (or ndim fallback)
|
||||||
|
|
||||||
|
from peft.tuners.lora import layer as peft_lora_layer
|
||||||
|
|
||||||
|
# Determine conv type from layer_type map or fall back to ndim
|
||||||
|
if layer_type and "Conv" in layer_type:
|
||||||
|
conv_type: str = layer_type
|
||||||
|
else:
|
||||||
|
ndim = lora_a.ndim
|
||||||
|
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
|
||||||
|
if ndim not in _conv_map:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
|
||||||
|
)
|
||||||
|
conv_type = _conv_map[ndim]
|
||||||
|
LOG.warning(
|
||||||
|
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
|
||||||
|
f"Consider providing layer_type from meta-device introspection."
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
|
||||||
|
ConvCls = conv_cls_map[conv_type]
|
||||||
|
PeftConvCls = getattr(peft_lora_layer, conv_type)
|
||||||
|
|
||||||
|
# Reconstruct conv parameters from base tensor and lora_a shapes
|
||||||
|
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
|
||||||
|
# lora_a: [r, in_channels/groups, *kernel_size]
|
||||||
|
# lora_b: [out_channels, r, *ones]
|
||||||
|
out_channels = base_tensor.shape[0]
|
||||||
|
in_channels = base_tensor.shape[1]
|
||||||
|
kernel_size = tuple(base_tensor.shape[2:])
|
||||||
|
stride = (1,) * (base_tensor.ndim - 2)
|
||||||
|
padding = (0,) * (base_tensor.ndim - 2)
|
||||||
|
|
||||||
|
base_layer = ConvCls(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
base_layer.weight.data = base_tensor.clone()
|
||||||
|
|
||||||
|
conv_config = LoraConfig(
|
||||||
|
r=r_total,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
use_rslora=use_rslora,
|
||||||
|
use_dora=use_dora,
|
||||||
|
)
|
||||||
|
layer = PeftConvCls(
|
||||||
|
base_layer,
|
||||||
|
adapter_name=adapter_name,
|
||||||
|
config=conv_config,
|
||||||
|
r=r_total,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
)
|
||||||
|
layer.lora_A[adapter_name].weight.data = lora_a
|
||||||
|
layer.lora_B[adapter_name].weight.data = lora_b
|
||||||
|
|
||||||
|
if use_dora:
|
||||||
|
if magnitude is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"DoRA merge requires a magnitude vector but none was found "
|
||||||
|
f"for conv layer (adapter={adapter_name}). Check that the "
|
||||||
|
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||||
|
)
|
||||||
|
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||||
|
mag_layer.weight = nn.Parameter(magnitude)
|
||||||
|
layer.merge(adapter_names=[adapter_name])
|
||||||
|
return base_layer.weight.data - base_tensor
|
||||||
|
|
||||||
return layer.get_delta_weight(adapter_name)
|
return layer.get_delta_weight(adapter_name)
|
||||||
else:
|
else:
|
||||||
from peft.tuners.lora.layer import Linear as LoraLinear
|
from peft.tuners.lora.layer import Linear as LoraLinear
|
||||||
@@ -251,15 +427,20 @@ def _build_peft_layer_and_get_delta(
|
|||||||
or lora_config_dict.get("lora_fan_in_fan_out", False)
|
or lora_config_dict.get("lora_fan_in_fan_out", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
layer = LoraLinear(
|
linear_config = LoraConfig(
|
||||||
base_layer,
|
|
||||||
adapter_name=adapter_name,
|
|
||||||
r=r_total,
|
r=r_total,
|
||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
fan_in_fan_out=fan_in_fan_out,
|
fan_in_fan_out=fan_in_fan_out,
|
||||||
use_rslora=use_rslora,
|
use_rslora=use_rslora,
|
||||||
use_dora=use_dora,
|
use_dora=use_dora,
|
||||||
)
|
)
|
||||||
|
layer = LoraLinear(
|
||||||
|
base_layer,
|
||||||
|
adapter_name=adapter_name,
|
||||||
|
config=linear_config,
|
||||||
|
r=r_total,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
)
|
||||||
layer.lora_A[adapter_name].weight.data = lora_a
|
layer.lora_A[adapter_name].weight.data = lora_a
|
||||||
layer.lora_B[adapter_name].weight.data = lora_b
|
layer.lora_B[adapter_name].weight.data = lora_b
|
||||||
|
|
||||||
@@ -267,6 +448,12 @@ def _build_peft_layer_and_get_delta(
|
|||||||
# DoRA merges magnitude normalization into the weight directly.
|
# DoRA merges magnitude normalization into the weight directly.
|
||||||
# Use PEFT's merge() which handles DoRA internally, then
|
# Use PEFT's merge() which handles DoRA internally, then
|
||||||
# compute the delta as merged_weight - original_weight.
|
# compute the delta as merged_weight - original_weight.
|
||||||
|
if magnitude is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"DoRA merge requires a magnitude vector but none was found "
|
||||||
|
f"for linear layer (adapter={adapter_name}). Check that the "
|
||||||
|
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||||
|
)
|
||||||
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||||
mag_layer.weight = nn.Parameter(magnitude)
|
mag_layer.weight = nn.Parameter(magnitude)
|
||||||
layer.merge(adapter_names=[adapter_name])
|
layer.merge(adapter_names=[adapter_name])
|
||||||
@@ -382,6 +569,7 @@ def _merge_tensor_with_lora(
|
|||||||
nf4_double_quant: bool = True,
|
nf4_double_quant: bool = True,
|
||||||
use_dora: bool = False,
|
use_dora: bool = False,
|
||||||
weight_renamings: Optional[Dict[str, str]] = None,
|
weight_renamings: Optional[Dict[str, str]] = None,
|
||||||
|
layer_type_map: Optional[Dict[str, str]] = None,
|
||||||
) -> tuple[torch.Tensor, bool]:
|
) -> tuple[torch.Tensor, bool]:
|
||||||
"""
|
"""
|
||||||
Helper function to merge a single tensor with its corresponding LoRA weights.
|
Helper function to merge a single tensor with its corresponding LoRA weights.
|
||||||
@@ -426,12 +614,30 @@ def _merge_tensor_with_lora(
|
|||||||
if use_dora
|
if use_dora
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Look up layer type from meta-device model introspection
|
||||||
|
_layer_type = None
|
||||||
|
if layer_type_map:
|
||||||
|
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
|
||||||
|
_layer_type = layer_type_map.get(mod_path)
|
||||||
|
# Try common prefix variations (e.g. with/without "model." prefix)
|
||||||
|
if _layer_type is None:
|
||||||
|
for prefix in [
|
||||||
|
"model.",
|
||||||
|
"model.language_model.",
|
||||||
|
"model.language_model.model.",
|
||||||
|
]:
|
||||||
|
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||||
|
if _layer_type:
|
||||||
|
break
|
||||||
|
|
||||||
delta = _build_peft_layer_and_get_delta(
|
delta = _build_peft_layer_and_get_delta(
|
||||||
lora_a.to(device),
|
lora_a.to(device),
|
||||||
lora_b.to(device),
|
lora_b.to(device),
|
||||||
lora_config_dict,
|
lora_config_dict,
|
||||||
tensor.to(device),
|
tensor.to(device),
|
||||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||||
|
layer_type=_layer_type,
|
||||||
)
|
)
|
||||||
merged_tensor = (
|
merged_tensor = (
|
||||||
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
||||||
@@ -556,6 +762,7 @@ def _fuse_and_unfuse_with_merge(
|
|||||||
nf4_double_quant: bool = True,
|
nf4_double_quant: bool = True,
|
||||||
use_dora: bool = False,
|
use_dora: bool = False,
|
||||||
weight_renamings: Optional[Dict[str, str]] = None,
|
weight_renamings: Optional[Dict[str, str]] = None,
|
||||||
|
layer_type_map: Optional[Dict[str, str]] = None,
|
||||||
) -> tuple[Dict[str, torch.Tensor], int, set]:
|
) -> tuple[Dict[str, torch.Tensor], int, set]:
|
||||||
"""
|
"""
|
||||||
For tensors matching WeightConverter patterns (MoE expert weights):
|
For tensors matching WeightConverter patterns (MoE expert weights):
|
||||||
@@ -696,12 +903,32 @@ def _fuse_and_unfuse_with_merge(
|
|||||||
if use_dora
|
if use_dora
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
# Look up layer type for the fused key
|
||||||
|
_layer_type = None
|
||||||
|
if layer_type_map:
|
||||||
|
mod_path = (
|
||||||
|
fused_key.rsplit(".weight", 1)[0]
|
||||||
|
if fused_key.endswith(".weight")
|
||||||
|
else fused_key
|
||||||
|
)
|
||||||
|
_layer_type = layer_type_map.get(mod_path)
|
||||||
|
if _layer_type is None:
|
||||||
|
for prefix in [
|
||||||
|
"model.",
|
||||||
|
"model.language_model.",
|
||||||
|
"model.language_model.model.",
|
||||||
|
]:
|
||||||
|
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||||
|
if _layer_type:
|
||||||
|
break
|
||||||
|
|
||||||
delta = _build_peft_layer_and_get_delta(
|
delta = _build_peft_layer_and_get_delta(
|
||||||
lora_a.to(device),
|
lora_a.to(device),
|
||||||
lora_b.to(device),
|
lora_b.to(device),
|
||||||
lora_config_dict,
|
lora_config_dict,
|
||||||
fused_tensor.to(device),
|
fused_tensor.to(device),
|
||||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||||
|
layer_type=_layer_type,
|
||||||
)
|
)
|
||||||
fused_tensor = (
|
fused_tensor = (
|
||||||
(
|
(
|
||||||
@@ -740,6 +967,7 @@ def merge_lora_sharded_efficient(
|
|||||||
simulate_nf4_experts: bool = False,
|
simulate_nf4_experts: bool = False,
|
||||||
nf4_blocksize: Optional[int] = None,
|
nf4_blocksize: Optional[int] = None,
|
||||||
nf4_double_quant: bool = True,
|
nf4_double_quant: bool = True,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Memory-efficient LoRA merging that processes shards individually
|
Memory-efficient LoRA merging that processes shards individually
|
||||||
@@ -750,6 +978,8 @@ def merge_lora_sharded_efficient(
|
|||||||
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
|
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
|
||||||
(for quantize_moe_experts). Expert tensors are identified by having
|
(for quantize_moe_experts). Expert tensors are identified by having
|
||||||
"expert" in the key name and ndim >= 3.
|
"expert" in the key name and ndim >= 3.
|
||||||
|
trust_remote_code: Whether to trust remote code when loading model
|
||||||
|
config for layer-type introspection. Defaults to False for safety.
|
||||||
"""
|
"""
|
||||||
base_model_path = Path(base_model_path)
|
base_model_path = Path(base_model_path)
|
||||||
lora_adapter_path = Path(lora_adapter_path)
|
lora_adapter_path = Path(lora_adapter_path)
|
||||||
@@ -780,6 +1010,10 @@ def merge_lora_sharded_efficient(
|
|||||||
|
|
||||||
use_dora = bool(lora_config_dict.get("use_dora", False))
|
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||||
|
|
||||||
|
# Build layer type map via meta-device model introspection
|
||||||
|
layer_type_map = _build_layer_type_map(
|
||||||
|
base_model_path, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
unsupported_methods = []
|
unsupported_methods = []
|
||||||
|
|
||||||
# Check for AdaLoRA (Adaptive LoRA)
|
# Check for AdaLoRA (Adaptive LoRA)
|
||||||
@@ -904,6 +1138,7 @@ def merge_lora_sharded_efficient(
|
|||||||
nf4_double_quant=nf4_double_quant,
|
nf4_double_quant=nf4_double_quant,
|
||||||
use_dora=use_dora,
|
use_dora=use_dora,
|
||||||
weight_renamings=weight_renamings,
|
weight_renamings=weight_renamings,
|
||||||
|
layer_type_map=layer_type_map,
|
||||||
)
|
)
|
||||||
merged_count += fused_merged
|
merged_count += fused_merged
|
||||||
|
|
||||||
@@ -926,6 +1161,7 @@ def merge_lora_sharded_efficient(
|
|||||||
nf4_double_quant=nf4_double_quant,
|
nf4_double_quant=nf4_double_quant,
|
||||||
use_dora=use_dora,
|
use_dora=use_dora,
|
||||||
weight_renamings=weight_renamings,
|
weight_renamings=weight_renamings,
|
||||||
|
layer_type_map=layer_type_map,
|
||||||
)
|
)
|
||||||
merged_tensors[key] = merged_tensor
|
merged_tensors[key] = merged_tensor
|
||||||
if was_merged:
|
if was_merged:
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
|
|||||||
GCCallback,
|
GCCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
|
SkipEvalOnResumeCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.distributed import build_parallelism_config
|
from axolotl.utils.distributed import build_parallelism_config
|
||||||
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cfg.resume_from_checkpoint:
|
||||||
|
callbacks.append(SkipEvalOnResumeCallback())
|
||||||
|
|
||||||
if self.cfg.gc_steps:
|
if self.cfg.gc_steps:
|
||||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||||
|
|
||||||
|
|||||||
@@ -100,6 +100,27 @@ class AxolotlTrainer(
|
|||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
|
# Gemma4 (and similar multimodal models) declare **kwargs in forward() for
|
||||||
|
# extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
|
||||||
|
# "the model handles num_items_in_batch internally" and skips the loss ÷
|
||||||
|
# gradient_accumulation_steps normalisation, which inflates the *logged* loss
|
||||||
|
# (the gradient itself is still correct). Override to False when the model
|
||||||
|
# doesn't actually consume num_items_in_batch.
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
model_to_check = self.accelerator.unwrap_model(self.model)
|
||||||
|
if hasattr(model_to_check, "base_model"): # PEFT wrapper
|
||||||
|
model_to_check = model_to_check.base_model
|
||||||
|
if hasattr(model_to_check, "model"):
|
||||||
|
model_to_check = model_to_check.model
|
||||||
|
fwd = getattr(model_to_check, "forward", None)
|
||||||
|
if fwd is not None:
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
params = inspect.signature(fwd).parameters
|
||||||
|
if "num_items_in_batch" not in params:
|
||||||
|
self.model_accepts_loss_kwargs = False
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(
|
self._stored_metrics = defaultdict(
|
||||||
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
|
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
|
||||||
@@ -383,13 +404,29 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
# Gemma4 requires mm_token_type_ids during training (even for text-only).
|
# Gemma4 requires mm_token_type_ids during training (even for text-only).
|
||||||
# Inject zeros (= text token type) when not provided by the data collator.
|
# Inject zeros (= text token type) when not provided by the data collator.
|
||||||
|
# Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
|
||||||
|
_unwrapped = self.accelerator.unwrap_model(model)
|
||||||
|
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
|
||||||
if (
|
if (
|
||||||
"mm_token_type_ids" not in inputs
|
"mm_token_type_ids" not in inputs
|
||||||
and "input_ids" in inputs
|
and "input_ids" in inputs
|
||||||
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
|
and _model_type == "gemma4"
|
||||||
):
|
):
|
||||||
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||||
|
|
||||||
|
# Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
|
||||||
|
# from position_ids, but only when attention_mask is None. When sample
|
||||||
|
# packing is active the collator provides an all-ones attention_mask that
|
||||||
|
# prevents this detection — remove it so the model builds the correct
|
||||||
|
# per-sequence causal masks.
|
||||||
|
if (
|
||||||
|
self.args.sample_packing
|
||||||
|
and _model_type in ("gemma4", "gemma3")
|
||||||
|
and "attention_mask" in inputs
|
||||||
|
and "position_ids" in inputs
|
||||||
|
):
|
||||||
|
del inputs["attention_mask"]
|
||||||
|
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(
|
||||||
model,
|
model,
|
||||||
@@ -398,6 +435,23 @@ class AxolotlTrainer(
|
|||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Gemma4ForConditionalGeneration computes loss with a manual
|
||||||
|
# nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
|
||||||
|
# normalization and does redundant attention_mask filtering.
|
||||||
|
# Compute loss externally using the standard loss_function instead.
|
||||||
|
if _model_type == "gemma4" and "labels" in inputs:
|
||||||
|
labels = inputs.pop("labels")
|
||||||
|
outputs = model(**inputs)
|
||||||
|
logits = outputs.logits
|
||||||
|
unwrapped = self.accelerator.unwrap_model(model)
|
||||||
|
vocab_size = unwrapped.config.get_text_config().vocab_size
|
||||||
|
loss = unwrapped.loss_function(
|
||||||
|
logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
|
||||||
|
)
|
||||||
|
if return_outputs:
|
||||||
|
return loss, outputs
|
||||||
|
return loss
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -410,6 +464,21 @@ class AxolotlTrainer(
|
|||||||
LOG.info("Running evaluation step...")
|
LOG.info("Running evaluation step...")
|
||||||
return super().evaluate(*args, **kwargs)
|
return super().evaluate(*args, **kwargs)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||||
|
# Gemma4 requires mm_token_type_ids even during evaluation.
|
||||||
|
_unwrapped = self.accelerator.unwrap_model(model)
|
||||||
|
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
|
||||||
|
if (
|
||||||
|
"mm_token_type_ids" not in inputs
|
||||||
|
and "input_ids" in inputs
|
||||||
|
and _model_type == "gemma4"
|
||||||
|
):
|
||||||
|
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||||
|
return super().prediction_step(
|
||||||
|
model, inputs, prediction_loss_only, ignore_keys=ignore_keys
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||||
concatenated_batch = {}
|
concatenated_batch = {}
|
||||||
|
|||||||
@@ -242,6 +242,85 @@ class ProducerConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _GroupShardedSampler:
|
||||||
|
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
|
||||||
|
|
||||||
|
``RepeatSampler`` yields ``num_generations`` consecutive copies of
|
||||||
|
each prompt, forming a GRPO group. For distributed training each
|
||||||
|
rank must see a disjoint slice of prompts (otherwise every rank
|
||||||
|
dogpiles on the first 1/world_size of the batch) while keeping each
|
||||||
|
group intact on a single rank so advantage normalization sees all
|
||||||
|
peer generations.
|
||||||
|
|
||||||
|
``accelerator.prepare(DataLoader)`` does not handle this correctly
|
||||||
|
for custom samplers with ``split_batches=False`` (the default): it
|
||||||
|
leaves the sampler alone and every rank replays identical indices.
|
||||||
|
This wrapper fixes that by consuming the inner sampler's full
|
||||||
|
output, chunking it into ``num_generations``-sized groups, and
|
||||||
|
round-robining whole groups across ranks.
|
||||||
|
|
||||||
|
Intended to be used ONLY when distributed training is active
|
||||||
|
(``num_replicas > 1``); for single-rank it is a no-op but still
|
||||||
|
correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner: Any,
|
||||||
|
num_generations: int,
|
||||||
|
rank: int,
|
||||||
|
num_replicas: int,
|
||||||
|
):
|
||||||
|
if num_generations < 1:
|
||||||
|
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
|
||||||
|
if num_replicas < 1:
|
||||||
|
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
||||||
|
if not (0 <= rank < num_replicas):
|
||||||
|
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
|
||||||
|
self.inner = inner
|
||||||
|
self.num_generations = num_generations
|
||||||
|
self.rank = rank
|
||||||
|
self.num_replicas = num_replicas
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
all_indices = list(self.inner)
|
||||||
|
if len(all_indices) % self.num_generations != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"inner sampler yielded {len(all_indices)} indices, "
|
||||||
|
f"not a multiple of num_generations={self.num_generations}"
|
||||||
|
)
|
||||||
|
# Chunk the flat index sequence into groups of num_generations
|
||||||
|
# consecutive indices. ``RepeatSampler`` guarantees that each
|
||||||
|
# group contains num_generations copies of the same prompt id.
|
||||||
|
groups = [
|
||||||
|
all_indices[i : i + self.num_generations]
|
||||||
|
for i in range(0, len(all_indices), self.num_generations)
|
||||||
|
]
|
||||||
|
# Round-robin whole groups across ranks. Round-robin (vs.
|
||||||
|
# contiguous chunking) preserves approximate shuffled order on
|
||||||
|
# each rank even when the group count is small relative to the
|
||||||
|
# world size.
|
||||||
|
for group in groups[self.rank :: self.num_replicas]:
|
||||||
|
yield from group
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
try:
|
||||||
|
inner_len = len(self.inner)
|
||||||
|
except TypeError:
|
||||||
|
# Non-sized inner sampler — we can't know the per-rank
|
||||||
|
# length without materializing. Return 0 as a hint that the
|
||||||
|
# DataLoader should fall back to iteration.
|
||||||
|
return 0
|
||||||
|
total_groups = inner_len // self.num_generations
|
||||||
|
# Ceiling division for the trailing groups that don't divide
|
||||||
|
# evenly — extra groups go to the first ``total_groups %
|
||||||
|
# num_replicas`` ranks, matching the round-robin above.
|
||||||
|
my_groups = (
|
||||||
|
total_groups + self.num_replicas - self.rank - 1
|
||||||
|
) // self.num_replicas
|
||||||
|
return my_groups * self.num_generations
|
||||||
|
|
||||||
|
|
||||||
class DataProducer(ABC):
|
class DataProducer(ABC):
|
||||||
"""Abstract base class for online data producers.
|
"""Abstract base class for online data producers.
|
||||||
|
|
||||||
@@ -556,6 +635,34 @@ class GRPODataProducer(BaseDataProducer):
|
|||||||
seed=self._seed,
|
seed=self._seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Shard the sampler across distributed ranks so each rank sees
|
||||||
|
# a disjoint slice of prompts. ``RepeatSampler`` groups each
|
||||||
|
# prompt with ``num_generations`` consecutive copies — our
|
||||||
|
# wrapper round-robins WHOLE groups across ranks so all
|
||||||
|
# generations of a given prompt stay on the same rank (needed
|
||||||
|
# for GRPO advantage normalization within a group).
|
||||||
|
#
|
||||||
|
# Without this, ``accelerator.prepare(dl)`` with the default
|
||||||
|
# ``split_batches=False`` leaves the custom sampler alone, so
|
||||||
|
# every rank iterates the identical index sequence and the
|
||||||
|
# cluster dogpiles on the first 1/world_size of the prompts.
|
||||||
|
num_replicas = max(1, trainer.accelerator.num_processes)
|
||||||
|
if num_replicas > 1:
|
||||||
|
sampler = _GroupShardedSampler(
|
||||||
|
inner=sampler,
|
||||||
|
num_generations=self._num_generations,
|
||||||
|
rank=trainer.accelerator.process_index,
|
||||||
|
num_replicas=num_replicas,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[RANK:%d] _GroupShardedSampler active "
|
||||||
|
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
|
||||||
|
trainer.accelerator.process_index,
|
||||||
|
num_replicas,
|
||||||
|
self._num_generations,
|
||||||
|
self._generation_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Use identity collator (same as stock GRPOTrainer)
|
# Use identity collator (same as stock GRPOTrainer)
|
||||||
def _identity(x):
|
def _identity(x):
|
||||||
return x
|
return x
|
||||||
@@ -574,12 +681,11 @@ class GRPODataProducer(BaseDataProducer):
|
|||||||
rank=trainer.args.process_index,
|
rank=trainer.args.process_index,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._prompt_dl = trainer.accelerator.prepare(dl)
|
# Skip accelerator.prepare — we're handling per-rank sharding
|
||||||
|
# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
|
||||||
# Don't let accelerator track this dataloader
|
# otherwise try to wrap the DataLoader with its own sharding
|
||||||
acc_dls = trainer.accelerator._dataloaders
|
# logic which does not understand our group structure.
|
||||||
if self._prompt_dl in acc_dls:
|
self._prompt_dl = dl
|
||||||
acc_dls.remove(self._prompt_dl)
|
|
||||||
|
|
||||||
self._prompt_iter = iter(self._prompt_dl)
|
self._prompt_iter = iter(self._prompt_dl)
|
||||||
|
|
||||||
@@ -1103,11 +1209,22 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
|
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
|
||||||
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
|
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
|
||||||
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
|
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
|
||||||
|
|
||||||
|
This is the canonical sync trigger and runs in BOTH async and
|
||||||
|
synchronous modes from ``_prepare_inputs_with_data_producer`` /
|
||||||
|
``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
|
||||||
|
patch is a parallel backup for non-data-producer paths (vanilla
|
||||||
|
GRPO without NeMo Gym), where the data producer is bypassed
|
||||||
|
entirely and TRL's stock generate-then-sync flow is used instead.
|
||||||
"""
|
"""
|
||||||
if not (self.use_vllm and self.args.async_prefetch):
|
if not self.use_vllm:
|
||||||
return
|
return
|
||||||
step = self.state.global_step
|
step = self.state.global_step
|
||||||
interval = self.args.vllm_sync_interval
|
# Default to syncing every step when no interval is configured —
|
||||||
|
# otherwise ``step % None`` would TypeError, and the previous
|
||||||
|
# behavior of crashing on the first sync was strictly worse than
|
||||||
|
# the standard "sync every optimizer step".
|
||||||
|
interval = self.args.vllm_sync_interval or 1
|
||||||
if step != self._last_synced_step and step % interval == 0:
|
if step != self._last_synced_step and step % interval == 0:
|
||||||
if step == 0:
|
if step == 0:
|
||||||
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
||||||
@@ -1202,13 +1319,42 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
|
|
||||||
# Permanently replace vllm_generation.sync_weights with our custom
|
# Permanently replace vllm_generation.sync_weights with our custom
|
||||||
# sync to avoid merge_adapter (fails on FP8 / races with training).
|
# sync to avoid merge_adapter (fails on FP8 / races with training).
|
||||||
# For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
|
#
|
||||||
# handles the sync with proper interval tracking.
|
# The design has two modes that have to be threaded carefully:
|
||||||
|
#
|
||||||
|
# - Async prefetch ON: BG generation thread can't safely call
|
||||||
|
# sync_weights mid-rollout (it races with the trainer's optimizer
|
||||||
|
# step and can corrupt weights). We no-op the stock sync hook and
|
||||||
|
# drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
|
||||||
|
# optimizer step on the main thread.
|
||||||
|
#
|
||||||
|
# - Async prefetch OFF (synchronous mode): TRL's stock
|
||||||
|
# ``_generate_single_turn`` calls ``sync_weights`` once per step
|
||||||
|
# boundary. There's no BG thread to race with, and
|
||||||
|
# ``_maybe_sync_vllm_weights`` short-circuits with
|
||||||
|
# ``if not async_prefetch: return``, so we MUST wire the stock
|
||||||
|
# hook directly to our LoRA sync helper — otherwise nothing ever
|
||||||
|
# pushes weights to vLLM and the trainer becomes a no-op (vLLM
|
||||||
|
# keeps serving the base model, every rollout in every group
|
||||||
|
# produces identical outputs, advantages are zero, optimizer
|
||||||
|
# step gets skipped, repeat).
|
||||||
if not getattr(self, "_patched_sync_weights", False):
|
if not getattr(self, "_patched_sync_weights", False):
|
||||||
if self.use_vllm and hasattr(self, "vllm_generation"):
|
if self.use_vllm and hasattr(self, "vllm_generation"):
|
||||||
if getattr(self.args, "vllm_lora_sync", False):
|
if getattr(self.args, "vllm_lora_sync", False):
|
||||||
# No-op: LoRA sync is driven by _maybe_sync_vllm_weights
|
if getattr(self.args, "async_prefetch", False):
|
||||||
self.vllm_generation.sync_weights = lambda: None
|
# Async: drive sync from main thread via
|
||||||
|
# _maybe_sync_vllm_weights instead.
|
||||||
|
self.vllm_generation.sync_weights = lambda: None
|
||||||
|
else:
|
||||||
|
# Sync mode: TRL's _generate_single_turn already
|
||||||
|
# calls sync_weights once per step boundary. Wire
|
||||||
|
# it directly to our LoRA filesystem sync helper.
|
||||||
|
sync_helper = self._sync_lora_adapter
|
||||||
|
|
||||||
|
def _lora_filesystem_sync():
|
||||||
|
sync_helper()
|
||||||
|
|
||||||
|
self.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||||
self._patched_sync_weights = True
|
self._patched_sync_weights = True
|
||||||
else:
|
else:
|
||||||
from accelerate.utils import is_peft_model
|
from accelerate.utils import is_peft_model
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -44,6 +44,7 @@ plugins:
|
|||||||
- gemma3_text
|
- gemma3_text
|
||||||
- gemma3n
|
- gemma3n
|
||||||
- gemma3n_text
|
- gemma3n_text
|
||||||
|
- gemma4
|
||||||
- glm
|
- glm
|
||||||
- glm4
|
- glm4
|
||||||
- glm4_moe
|
- glm4_moe
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
|
'`pip uninstall -y cut-cross-entropy && pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
|
|||||||
|
|
||||||
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
|
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
|
||||||
|
|
||||||
**Important limitations:**
|
|
||||||
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
|
|
||||||
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
|
|
||||||
|
|
||||||
## Limitations
|
## Limitations
|
||||||
|
|
||||||
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
|
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
|
||||||
|
|||||||
@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def warn_sonicmoe_lora_overhead(cls, data):
|
|
||||||
if data.get("use_sonicmoe") is True and data.get("adapter") in (
|
|
||||||
"lora",
|
|
||||||
"qlora",
|
|
||||||
):
|
|
||||||
lora_target = data.get("lora_target_modules") or []
|
|
||||||
lora_linear = data.get("lora_target_linear_modules") or []
|
|
||||||
targets = (
|
|
||||||
lora_target if isinstance(lora_target, list) else [lora_target]
|
|
||||||
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
|
|
||||||
expert_keywords = ("gate_up_proj", "down_proj", "experts")
|
|
||||||
if any(kw in t for t in targets for kw in expert_keywords):
|
|
||||||
LOG.info(
|
|
||||||
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
|
|
||||||
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
|
|
||||||
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def disable_mlp_kernel(cls, data):
|
def disable_mlp_kernel(cls, data):
|
||||||
|
|||||||
@@ -2,17 +2,35 @@
|
|||||||
# Copyright (c) Axolotl AI
|
# Copyright (c) Axolotl AI
|
||||||
# Licensed under the Apache License, Version 2.0
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
from . import layers
|
from .lora_layout import (
|
||||||
from .lora_ops import ParallelExperts
|
peft_down_proj_lora_to_scattermoe,
|
||||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
peft_lora_B_to_scattermoe,
|
||||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
peft_lora_to_scattermoe,
|
||||||
|
validate_scattermoe_lora_shapes,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"layers",
|
"peft_down_proj_lora_to_scattermoe",
|
||||||
"ParallelExperts",
|
"peft_lora_B_to_scattermoe",
|
||||||
"flatten_sort_count",
|
"peft_lora_to_scattermoe",
|
||||||
"parallel_linear",
|
"validate_scattermoe_lora_shapes",
|
||||||
"ScatterMoELoRA",
|
|
||||||
"parallel_linear_lora",
|
|
||||||
"lora_ops",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import layers
|
||||||
|
from .lora_ops import ParallelExperts
|
||||||
|
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||||
|
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||||
|
except ModuleNotFoundError as exc:
|
||||||
|
if exc.name != "triton":
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
__all__ += [
|
||||||
|
"layers",
|
||||||
|
"ParallelExperts",
|
||||||
|
"flatten_sort_count",
|
||||||
|
"parallel_linear",
|
||||||
|
"ScatterMoELoRA",
|
||||||
|
"parallel_linear_lora",
|
||||||
|
"lora_ops",
|
||||||
|
]
|
||||||
|
|||||||
@@ -35,81 +35,19 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .lora_layout import (
|
||||||
|
peft_down_proj_lora_to_scattermoe,
|
||||||
|
peft_lora_B_to_scattermoe,
|
||||||
|
peft_lora_to_scattermoe,
|
||||||
|
)
|
||||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||||
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
||||||
|
|
||||||
# =============================================================================
|
__all__ = [
|
||||||
# LoRA layout conversion utilities (peft <-> scattermoe)
|
"peft_down_proj_lora_to_scattermoe",
|
||||||
# =============================================================================
|
"peft_lora_B_to_scattermoe",
|
||||||
|
"peft_lora_to_scattermoe",
|
||||||
|
]
|
||||||
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
|
||||||
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
|
||||||
expert-major ``[N, r*E]``.
|
|
||||||
|
|
||||||
peft reshapes B to ``[out, r, E]`` (rank-major).
|
|
||||||
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
|
||||||
"""
|
|
||||||
N = peft_B.shape[0]
|
|
||||||
return (
|
|
||||||
peft_B.reshape(N, rank, num_experts)
|
|
||||||
.permute(0, 2, 1)
|
|
||||||
.contiguous()
|
|
||||||
.reshape(N, num_experts * rank)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
|
||||||
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
|
|
||||||
|
|
||||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
|
||||||
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
|
|
||||||
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
|
|
||||||
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
|
|
||||||
are swapped relative to scattermoe's convention.
|
|
||||||
|
|
||||||
peft gives:
|
|
||||||
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
|
|
||||||
|
|
||||||
scattermoe needs:
|
|
||||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
|
||||||
|
|
||||||
This function swaps A<->B and converts B from rank-major to expert-major.
|
|
||||||
Uses vectorized tensor operations (no Python loop over experts).
|
|
||||||
|
|
||||||
Works for **both** gate_up_proj and down_proj since the transposition
|
|
||||||
issue is the same for any parameter.
|
|
||||||
"""
|
|
||||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
|
||||||
|
|
||||||
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
|
|
||||||
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
|
|
||||||
|
|
||||||
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
|
|
||||||
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
|
|
||||||
smoe_A = (
|
|
||||||
peft_B_em.reshape(dim2, num_experts, rank)
|
|
||||||
.permute(1, 2, 0)
|
|
||||||
.contiguous()
|
|
||||||
.reshape(rank * num_experts, dim2)
|
|
||||||
)
|
|
||||||
|
|
||||||
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
|
|
||||||
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
|
|
||||||
smoe_B = (
|
|
||||||
peft_A.reshape(num_experts, rank, dim1)
|
|
||||||
.permute(2, 0, 1)
|
|
||||||
.contiguous()
|
|
||||||
.reshape(dim1, num_experts * rank)
|
|
||||||
)
|
|
||||||
|
|
||||||
return smoe_A, smoe_B
|
|
||||||
|
|
||||||
|
|
||||||
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
|
||||||
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
|
||||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# ParamWrapper unwrapping
|
# ParamWrapper unwrapping
|
||||||
@@ -199,7 +137,7 @@ def _unwrap_experts_lora(experts_module):
|
|||||||
if gup is not None:
|
if gup is not None:
|
||||||
num_experts = gup.shape[0]
|
num_experts = gup.shape[0]
|
||||||
|
|
||||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
# Extract gate_up_proj LoRA
|
||||||
gup_lora = None
|
gup_lora = None
|
||||||
gup_wrapper = wrappers.get("gate_up_proj")
|
gup_wrapper = wrappers.get("gate_up_proj")
|
||||||
if gup_wrapper is not None:
|
if gup_wrapper is not None:
|
||||||
@@ -208,7 +146,7 @@ def _unwrap_experts_lora(experts_module):
|
|||||||
rank = lora_A.shape[0] // num_experts
|
rank = lora_A.shape[0] // num_experts
|
||||||
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
||||||
|
|
||||||
# Extract down_proj LoRA (needs A<->B swap due to transposition)
|
# Extract down_proj LoRA
|
||||||
down_lora = None
|
down_lora = None
|
||||||
down_wrapper = wrappers.get("down_proj")
|
down_wrapper = wrappers.get("down_proj")
|
||||||
if down_wrapper is not None:
|
if down_wrapper is not None:
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Pure tensor layout helpers for ScatterMoE LoRA weights."""
|
||||||
|
|
||||||
|
|
||||||
|
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||||
|
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
||||||
|
expert-major ``[N, r*E]``.
|
||||||
|
|
||||||
|
peft reshapes B to ``[out, r, E]`` (rank-major).
|
||||||
|
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
||||||
|
"""
|
||||||
|
N = peft_B.shape[0]
|
||||||
|
return (
|
||||||
|
peft_B.reshape(N, rank, num_experts)
|
||||||
|
.permute(0, 2, 1)
|
||||||
|
.contiguous()
|
||||||
|
.reshape(N, num_experts * rank)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
|
"""Convert peft LoRA weights to scattermoe layout.
|
||||||
|
|
||||||
|
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
||||||
|
where ``out_features=dim1, in_features=dim2``. ScatterMoE transposes the
|
||||||
|
parameter (``W = param.transpose(2, 1)``), giving ``[E, dim2, dim1]`` with
|
||||||
|
``K=dim2, N=dim1``.
|
||||||
|
|
||||||
|
peft gives:
|
||||||
|
lora_A ``[r*E, dim2]``, lora_B ``[dim1, r*E]``
|
||||||
|
|
||||||
|
scattermoe needs:
|
||||||
|
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
||||||
|
|
||||||
|
peft's A already matches ScatterMoE's A shape. Only B needs conversion from
|
||||||
|
peft's rank-major layout to ScatterMoE's expert-major layout.
|
||||||
|
"""
|
||||||
|
smoe_A = peft_A
|
||||||
|
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||||
|
|
||||||
|
return smoe_A, smoe_B
|
||||||
|
|
||||||
|
|
||||||
|
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
|
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
||||||
|
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B):
|
||||||
|
"""Validate LoRA tensor layout before dispatching ScatterMoE kernels."""
|
||||||
|
E, K, N = expert_weights.shape
|
||||||
|
if lora_A.dim() != 2 or lora_B.dim() != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"ScatterMoE LoRA expects 2D lora_A and lora_B tensors, got "
|
||||||
|
f"lora_A={tuple(lora_A.shape)} and lora_B={tuple(lora_B.shape)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if lora_A.size(0) % E != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"ScatterMoE LoRA expects lora_A rows to be divisible by the number "
|
||||||
|
f"of experts ({E}), got lora_A={tuple(lora_A.shape)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
rank = lora_A.size(0) // E
|
||||||
|
expected_A = (E * rank, K)
|
||||||
|
expected_B = (N, E * rank)
|
||||||
|
if tuple(lora_A.shape) != expected_A or tuple(lora_B.shape) != expected_B:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid ScatterMoE LoRA layout for expert_weights "
|
||||||
|
f"{tuple(expert_weights.shape)}. Expected lora_A={expected_A} and "
|
||||||
|
f"lora_B={expected_B}, got lora_A={tuple(lora_A.shape)} and "
|
||||||
|
f"lora_B={tuple(lora_B.shape)}. For PEFT target_parameters, keep "
|
||||||
|
"lora_A as [E*r, K] and only convert lora_B from rank-major to "
|
||||||
|
"expert-major layout."
|
||||||
|
)
|
||||||
@@ -34,6 +34,7 @@ from .kernels.lora_ops import (
|
|||||||
scatter2scatter_lora,
|
scatter2scatter_lora,
|
||||||
scatter2scatter_lora_dX,
|
scatter2scatter_lora_dX,
|
||||||
)
|
)
|
||||||
|
from .lora_layout import validate_scattermoe_lora_shapes
|
||||||
|
|
||||||
|
|
||||||
class ScatterMoELoRA(torch.autograd.Function):
|
class ScatterMoELoRA(torch.autograd.Function):
|
||||||
@@ -422,11 +423,6 @@ def get_lora_params_from_wrapper(module) -> tuple:
|
|||||||
return lora_A, lora_B, scaling
|
return lora_A, lora_B, scaling
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Drop-in replacement for parallel_linear
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def parallel_linear_lora(
|
def parallel_linear_lora(
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
expert_weights: torch.Tensor,
|
expert_weights: torch.Tensor,
|
||||||
@@ -451,6 +447,7 @@ def parallel_linear_lora(
|
|||||||
Otherwise falls back to standard scatter2scatter.
|
Otherwise falls back to standard scatter2scatter.
|
||||||
"""
|
"""
|
||||||
if lora_A is not None and lora_B is not None:
|
if lora_A is not None and lora_B is not None:
|
||||||
|
validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B)
|
||||||
return ScatterMoELoRA.apply(
|
return ScatterMoELoRA.apply(
|
||||||
inputs,
|
inputs,
|
||||||
expert_weights,
|
expert_weights,
|
||||||
|
|||||||
@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
swiglu=cfg.liger_glu_activation,
|
swiglu=cfg.liger_glu_activation,
|
||||||
)
|
)
|
||||||
|
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||||
|
# Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
|
||||||
|
# gradient checkpointing compatibility, RoPE incompatible (separate q/k).
|
||||||
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
|
from transformers.models.gemma4 import modeling_gemma4
|
||||||
|
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
_OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm
|
||||||
|
|
||||||
|
class _LigerGemma4RMSNorm(LigerRMSNorm):
|
||||||
|
"""LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""
|
||||||
|
|
||||||
|
def __new__(cls, dim, eps=1e-6, with_scale=True):
|
||||||
|
if not with_scale:
|
||||||
|
return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
|
||||||
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
def __init__(self, dim, eps=1e-6, with_scale=True):
|
||||||
|
if not with_scale:
|
||||||
|
return
|
||||||
|
# offset=0.0 (standard), in_place=False (gradient checkpointing safe)
|
||||||
|
super().__init__(
|
||||||
|
dim, eps, offset=0.0, casting_mode="llama", in_place=False
|
||||||
|
)
|
||||||
|
|
||||||
|
modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
|
||||||
|
class _LigerGemma4MLP(LigerGEGLUMLP):
|
||||||
|
def __init__(self, config, layer_idx=None):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
|
||||||
|
if cfg.liger_rope:
|
||||||
|
LOG.warning(
|
||||||
|
"Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
|
||||||
|
)
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_gemma4.nn.LayerNorm = LigerLayerNorm
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
LOG.warning(
|
||||||
|
"Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
f"Applied Liger kernels for gemma4: "
|
||||||
|
f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
|
||||||
|
f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
|
||||||
|
)
|
||||||
elif cfg.liger_fused_linear_cross_entropy:
|
elif cfg.liger_fused_linear_cross_entropy:
|
||||||
try:
|
try:
|
||||||
from .models.base import patch_lce_forward
|
from .models.base import patch_lce_forward
|
||||||
|
|||||||
@@ -110,11 +110,36 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
item["agent_ref"] = full_item["agent_ref"]
|
item["agent_ref"] = full_item["agent_ref"]
|
||||||
dataset_items.append(item)
|
dataset_items.append(item)
|
||||||
|
|
||||||
# Expand by num_generations (agent produces one rollout per call)
|
# NOTE: do NOT re-expand by num_generations here.
|
||||||
expanded_items = []
|
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
|
||||||
for item in dataset_items:
|
# yields ``num_generations`` consecutive copies of each unique
|
||||||
for _ in range(self._num_generations):
|
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
|
||||||
expanded_items.append(item)
|
# num_generations)`` items — one entry per rollout. Expanding
|
||||||
|
# again here would fire ``num_generations^2`` rollouts per
|
||||||
|
# prompt per rank and make every step dogpile on a handful of
|
||||||
|
# tasks.
|
||||||
|
expanded_items = dataset_items
|
||||||
|
|
||||||
|
# Diagnostic: log what this rank is about to fire.
|
||||||
|
try:
|
||||||
|
import collections
|
||||||
|
|
||||||
|
iid_counts: collections.Counter[str | None] = collections.Counter()
|
||||||
|
for it in dataset_items:
|
||||||
|
iid_counts[
|
||||||
|
(it.get("responses_create_params", {}).get("metadata") or {}).get(
|
||||||
|
"instance_id"
|
||||||
|
)
|
||||||
|
] += 1
|
||||||
|
LOG.info(
|
||||||
|
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
|
||||||
|
trainer.accelerator.process_index,
|
||||||
|
len(dataset_items),
|
||||||
|
len(iid_counts),
|
||||||
|
list(iid_counts.most_common(5)),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Call NeMo Gym agents
|
# Call NeMo Gym agents
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
@@ -140,6 +165,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
logprobs_list = []
|
logprobs_list = []
|
||||||
rewards_list = []
|
rewards_list = []
|
||||||
|
|
||||||
|
num_turns_list: list[int] = []
|
||||||
for resp in responses:
|
for resp in responses:
|
||||||
parsed = _parse_agent_response(resp, eos_token_id)
|
parsed = _parse_agent_response(resp, eos_token_id)
|
||||||
prompt_ids_list.append(parsed["prompt_ids"])
|
prompt_ids_list.append(parsed["prompt_ids"])
|
||||||
@@ -147,6 +173,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
env_mask_list.append(parsed["env_mask"])
|
env_mask_list.append(parsed["env_mask"])
|
||||||
logprobs_list.append(parsed["logprobs"])
|
logprobs_list.append(parsed["logprobs"])
|
||||||
rewards_list.append(parsed["reward"])
|
rewards_list.append(parsed["reward"])
|
||||||
|
num_turns_list.append(parsed.get("num_turns", 0))
|
||||||
|
|
||||||
# Pad to tensors
|
# Pad to tensors
|
||||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
||||||
@@ -179,22 +206,48 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
|
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
|
||||||
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
|
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
|
||||||
|
|
||||||
# Inject rewards into inputs so _compute_deferred_scores can use them
|
# Inject per-rollout reward + num_turns into each input. Since
|
||||||
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
|
# ``RepeatSampler`` already yields ``num_generations`` copies of
|
||||||
# Our passthrough reward_fn reads "env_reward" from kwargs.
|
# each prompt, ``inputs`` has ONE entry per rollout (matching
|
||||||
|
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
|
||||||
|
# GRPO advantage normalization is the trainer's job downstream.
|
||||||
|
assert len(inputs) == len(rewards_list), (
|
||||||
|
f"rewards/inputs length mismatch: "
|
||||||
|
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
|
||||||
|
)
|
||||||
for i, inp in enumerate(inputs):
|
for i, inp in enumerate(inputs):
|
||||||
# Each input gets rewards for its num_generations rollouts
|
inp["env_reward"] = rewards_list[i]
|
||||||
start = i * self._num_generations
|
inp["num_turns"] = num_turns_list[i]
|
||||||
end = start + self._num_generations
|
|
||||||
inp["env_reward"] = rewards_list[start:end]
|
|
||||||
|
|
||||||
# Expand inputs to match expanded rollouts (num_generations copies)
|
# One expanded_input per rollout (already correct count because
|
||||||
expanded_inputs = []
|
# inputs has num_generations copies baked in by the sampler).
|
||||||
for inp in inputs:
|
expanded_inputs = [dict(inp) for inp in inputs]
|
||||||
for g in range(self._num_generations):
|
|
||||||
expanded_inp = dict(inp)
|
# Log rollout-level stats to wandb from rank 0. These are the
|
||||||
expanded_inp["env_reward"] = inp["env_reward"][g]
|
# true agent-side metrics (not the tokenized TRL view) — so
|
||||||
expanded_inputs.append(expanded_inp)
|
# num_turns reflects how many /run iterations each rollout
|
||||||
|
# actually took before finishing or hitting max_turns.
|
||||||
|
if is_main and num_turns_list:
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
if wandb.run is not None:
|
||||||
|
import statistics as _stats
|
||||||
|
|
||||||
|
nonzero = sum(1 for r in rewards_list if r > 0)
|
||||||
|
log_payload = {
|
||||||
|
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
|
||||||
|
"rollout/num_turns/min": float(min(num_turns_list)),
|
||||||
|
"rollout/num_turns/max": float(max(num_turns_list)),
|
||||||
|
"rollout/reward/mean": float(_stats.mean(rewards_list)),
|
||||||
|
"rollout/reward/nonzero_frac": (
|
||||||
|
nonzero / len(rewards_list) if rewards_list else 0.0
|
||||||
|
),
|
||||||
|
"rollout/n_samples": float(len(rewards_list)),
|
||||||
|
}
|
||||||
|
wandb.log(log_payload, commit=False)
|
||||||
|
except Exception as exc: # never let metric logging break training
|
||||||
|
LOG.warning("rollout wandb log failed: %s", exc)
|
||||||
|
|
||||||
# Decode completions for reward functions
|
# Decode completions for reward functions
|
||||||
completions = trainer.processing_class.batch_decode(
|
completions = trainer.processing_class.batch_decode(
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ Supports two modes:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
@@ -30,6 +31,107 @@ if TYPE_CHECKING:
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- vLLM weight-sync transport probe ------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VLLMWeightSyncCapabilities:
|
||||||
|
"""What weight-sync routes a vLLM server actually exposes.
|
||||||
|
|
||||||
|
Discovered once at ``pre_model_load`` time by fetching the server's
|
||||||
|
``/openapi.json``. Drives the transport-selection table below.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nccl: bool = False # /init_communicator/ + /update_named_param/
|
||||||
|
lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native)
|
||||||
|
lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension)
|
||||||
|
http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension)
|
||||||
|
probed: bool = False
|
||||||
|
probe_error: str | None = None
|
||||||
|
routes: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def any_full_param_sync(self) -> bool:
|
||||||
|
"""True if at least one transport can push full-model weights."""
|
||||||
|
return self.nccl or self.http_full
|
||||||
|
|
||||||
|
@property
|
||||||
|
def any_lora_sync(self) -> bool:
|
||||||
|
"""True if at least one transport can push LoRA adapters."""
|
||||||
|
return self.lora_filesystem or self.lora_axolotl or self.nccl
|
||||||
|
|
||||||
|
|
||||||
|
def probe_vllm_weight_sync(
|
||||||
|
base_url: str, timeout: float = 5.0
|
||||||
|
) -> VLLMWeightSyncCapabilities:
|
||||||
|
"""Detect which weight-sync routes the configured vLLM server exposes.
|
||||||
|
|
||||||
|
Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
|
||||||
|
we care about is mounted as a POST route there. Falls back to all-False
|
||||||
|
on any error so the caller can still decide what to do (typically: raise
|
||||||
|
a clear error rather than silently no-op).
|
||||||
|
"""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
caps = VLLMWeightSyncCapabilities()
|
||||||
|
try:
|
||||||
|
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
|
||||||
|
r.raise_for_status()
|
||||||
|
spec = r.json()
|
||||||
|
routes = sorted((spec.get("paths") or {}).keys())
|
||||||
|
caps.routes = routes
|
||||||
|
caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
|
||||||
|
caps.lora_filesystem = "/v1/load_lora_adapter" in routes
|
||||||
|
caps.lora_axolotl = "/set_lora_adapter/" in routes
|
||||||
|
caps.http_full = "/http_update_weights/" in routes
|
||||||
|
caps.probed = True
|
||||||
|
except Exception as exc:
|
||||||
|
caps.probe_error = f"{type(exc).__name__}: {exc}"
|
||||||
|
LOG.warning(
|
||||||
|
"NeMo Gym: failed to probe vLLM /openapi.json at %s — %s. "
|
||||||
|
"Will fall back to LoRA-only behavior.",
|
||||||
|
base_url,
|
||||||
|
caps.probe_error,
|
||||||
|
)
|
||||||
|
return caps
|
||||||
|
|
||||||
|
|
||||||
|
def select_weight_sync_transport(
|
||||||
|
caps: VLLMWeightSyncCapabilities,
|
||||||
|
*,
|
||||||
|
has_lora: bool,
|
||||||
|
vllm_lora_sync_pref: bool,
|
||||||
|
) -> str:
|
||||||
|
"""Pick the right transport for a (server caps, model type) combo.
|
||||||
|
|
||||||
|
Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
|
||||||
|
``"none"``. The caller decides what to do with ``"none"`` (typically:
|
||||||
|
raise an error explaining the misconfiguration).
|
||||||
|
|
||||||
|
Selection table:
|
||||||
|
LoRA model + lora endpoint + lora-sync pref → lora_filesystem
|
||||||
|
LoRA model + lora endpoint → lora_filesystem
|
||||||
|
LoRA model + nccl endpoint → nccl (broadcast merged adapter)
|
||||||
|
Full model + nccl endpoint → nccl
|
||||||
|
Full model + http endpoint → http_full
|
||||||
|
anything else → none
|
||||||
|
"""
|
||||||
|
if has_lora:
|
||||||
|
if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
|
||||||
|
return "lora_filesystem"
|
||||||
|
if caps.lora_filesystem or caps.lora_axolotl:
|
||||||
|
return "lora_filesystem"
|
||||||
|
if caps.nccl:
|
||||||
|
return "nccl"
|
||||||
|
return "none"
|
||||||
|
# Full-parameter model
|
||||||
|
if caps.nccl:
|
||||||
|
return "nccl"
|
||||||
|
if caps.http_full:
|
||||||
|
return "http_full"
|
||||||
|
return "none"
|
||||||
|
|
||||||
|
|
||||||
class NemoGymPlugin(BasePlugin):
|
class NemoGymPlugin(BasePlugin):
|
||||||
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
|
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
|
||||||
|
|
||||||
@@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin):
|
|||||||
self._reward_fn = None
|
self._reward_fn = None
|
||||||
self._dataset_lookup = None
|
self._dataset_lookup = None
|
||||||
self._agent_servers = {}
|
self._agent_servers = {}
|
||||||
|
self._vllm_caps: VLLMWeightSyncCapabilities | None = None
|
||||||
|
|
||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.nemo_gym.NemoGymArgs"
|
return "axolotl.integrations.nemo_gym.NemoGymArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
"""Apply monkeypatches before trainer creation."""
|
"""Probe vLLM weight-sync routes and conditionally bypass NCCL init.
|
||||||
|
|
||||||
|
Replaces the previous unconditional ``init_communicator`` monkey-patch
|
||||||
|
with a probe of the configured vLLM server's ``/openapi.json``. We only
|
||||||
|
bypass NCCL init when the server we're talking to actually lacks the
|
||||||
|
``/init_communicator/`` route (i.e. stock ``vllm serve``); against
|
||||||
|
TRL/axolotl serve modules that DO expose NCCL routes, we leave the
|
||||||
|
standard TRL flow alone so full-finetune training can sync weights.
|
||||||
|
"""
|
||||||
if not cfg.nemo_gym_enabled:
|
if not cfg.nemo_gym_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Always skip NCCL communicator init in NeMo Gym mode.
|
|
||||||
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
|
|
||||||
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
|
|
||||||
trl_cfg = getattr(cfg, "trl", None)
|
trl_cfg = getattr(cfg, "trl", None)
|
||||||
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
|
if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
|
||||||
|
return
|
||||||
|
|
||||||
|
host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
|
||||||
|
port = getattr(trl_cfg, "vllm_server_port", None) or 8000
|
||||||
|
base_url = f"http://{host}:{port}"
|
||||||
|
self._vllm_caps = probe_vllm_weight_sync(base_url)
|
||||||
|
|
||||||
|
if self._vllm_caps.probed:
|
||||||
|
LOG.info(
|
||||||
|
"NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
|
||||||
|
"lora_axolotl=%s http_full=%s",
|
||||||
|
base_url,
|
||||||
|
self._vllm_caps.nccl,
|
||||||
|
self._vllm_caps.lora_filesystem,
|
||||||
|
self._vllm_caps.lora_axolotl,
|
||||||
|
self._vllm_caps.http_full,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only bypass NCCL init when the server doesn't speak it. If NCCL is
|
||||||
|
# available we leave VLLMClient.init_communicator alone so the
|
||||||
|
# standard TRL sync flow can run for full-parameter training.
|
||||||
|
if not self._vllm_caps.nccl:
|
||||||
self._patch_skip_nccl_init()
|
self._patch_skip_nccl_init()
|
||||||
|
|
||||||
def _patch_skip_nccl_init(self):
|
def _patch_skip_nccl_init(self):
|
||||||
"""Monkeypatch VLLMClient.init_communicator to no-op.
|
"""Monkeypatch VLLMClient.init_communicator to no-op.
|
||||||
|
|
||||||
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
|
Only called when the configured vLLM server doesn't expose
|
||||||
serve script). The NCCL communicator is not needed and fails with both
|
``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
|
||||||
vLLM V1 engine and standard OpenAI server mode.
|
TRL's standard ``init_communicator`` would 404 inside trainer
|
||||||
|
construction; we no-op it so the LoRA filesystem path can install
|
||||||
|
its own sync in ``post_trainer_create``.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from trl.generation.vllm_client import VLLMClient
|
from trl.generation.vllm_client import VLLMClient
|
||||||
|
|
||||||
VLLMClient._original_init_communicator = VLLMClient.init_communicator
|
VLLMClient._original_init_communicator = VLLMClient.init_communicator
|
||||||
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
|
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
|
||||||
"Skipping NCCL init_communicator (LoRA sync mode)"
|
"Skipping NCCL init_communicator (server has no /init_communicator/)"
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
"Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
|
||||||
)
|
)
|
||||||
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOG.warning(f"Failed to patch VLLMClient: {exc}")
|
LOG.warning(f"Failed to patch VLLMClient: {exc}")
|
||||||
|
|
||||||
@@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin):
|
|||||||
verify_timeout = cfg.nemo_gym_verify_timeout or 30
|
verify_timeout = cfg.nemo_gym_verify_timeout or 30
|
||||||
multi_turn = cfg.nemo_gym_multi_turn or False
|
multi_turn = cfg.nemo_gym_multi_turn or False
|
||||||
|
|
||||||
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
|
# Pick a weight-sync transport based on what the configured vLLM
|
||||||
# - Install LoRA sync (when vllm_lora_sync=True)
|
# server actually exposes (see ``pre_model_load`` probe) and what
|
||||||
# - Or no-op sync_weights (when using standard vLLM server)
|
# kind of model we're training. The selection table is documented
|
||||||
|
# in ``select_weight_sync_transport``.
|
||||||
trl_cfg = getattr(cfg, "trl", None)
|
trl_cfg = getattr(cfg, "trl", None)
|
||||||
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
|
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
|
||||||
vllm_gen = trainer.vllm_generation
|
vllm_gen = trainer.vllm_generation
|
||||||
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
|
adapter = getattr(cfg, "adapter", None)
|
||||||
|
has_lora = adapter in ("lora", "qlora")
|
||||||
|
vllm_lora_sync_pref = bool(
|
||||||
|
trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
|
||||||
|
)
|
||||||
|
caps = self._vllm_caps or VLLMWeightSyncCapabilities()
|
||||||
|
transport = select_weight_sync_transport(
|
||||||
|
caps,
|
||||||
|
has_lora=has_lora,
|
||||||
|
vllm_lora_sync_pref=vllm_lora_sync_pref,
|
||||||
|
)
|
||||||
|
|
||||||
|
if transport == "lora_filesystem":
|
||||||
self._setup_lora_sync(trainer)
|
self._setup_lora_sync(trainer)
|
||||||
# Verify the vLLM server supports runtime LoRA loading
|
|
||||||
self._check_lora_endpoint(vllm_gen)
|
self._check_lora_endpoint(vllm_gen)
|
||||||
else:
|
LOG.info("NeMo Gym weight sync: LoRA filesystem")
|
||||||
# No NCCL, no LoRA sync — skip all weight sync paths
|
elif transport == "nccl":
|
||||||
vllm_gen.sync_weights = lambda: LOG.debug(
|
# Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
|
||||||
"Weight sync skipped (NeMo Gym mode)"
|
# alone (pre_model_load only patched it when the probe found no
|
||||||
|
# NCCL route) so the trainer's normal weight-sync flow runs.
|
||||||
|
LOG.info(
|
||||||
|
"NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
|
||||||
)
|
)
|
||||||
type(vllm_gen).sync_weights = lambda self: LOG.debug(
|
elif transport == "http_full":
|
||||||
"Weight sync skipped (NeMo Gym mode)"
|
# Full-parameter HTTP sync — implementation lands in step 3.
|
||||||
|
# For now, fail loudly so users know the path is detected but
|
||||||
|
# not yet wired up, instead of silently no-oping like before.
|
||||||
|
raise NotImplementedError(
|
||||||
|
"NeMo Gym + full fine-tune + HTTP weight sync is detected "
|
||||||
|
"but the client-side sync helper is not yet implemented "
|
||||||
|
"(planned). Use `adapter: lora|qlora` for now, or use a "
|
||||||
|
"vLLM serve module that exposes /init_communicator/ for "
|
||||||
|
"NCCL sync."
|
||||||
)
|
)
|
||||||
# Also patch the async trainer's internal sync method
|
else: # transport == "none"
|
||||||
if hasattr(trainer, "_maybe_sync_vllm_weights"):
|
# No viable sync path. Build a precise error so the user knows
|
||||||
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
|
# exactly what's missing and how to fix it.
|
||||||
"Async weight sync skipped (NeMo Gym mode)"
|
if not caps.probed:
|
||||||
|
msg = (
|
||||||
|
"could not probe the vLLM server's "
|
||||||
|
f"/openapi.json: {caps.probe_error}. "
|
||||||
|
"Verify that vLLM is reachable at "
|
||||||
|
f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
|
||||||
|
f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
|
||||||
)
|
)
|
||||||
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
|
elif has_lora:
|
||||||
|
msg = (
|
||||||
|
"the vLLM server has neither NCCL routes "
|
||||||
|
"(/init_communicator/) nor a LoRA-loading route "
|
||||||
|
"(/v1/load_lora_adapter or /set_lora_adapter/). "
|
||||||
|
"Restart vLLM with `--enable-lora --max-lora-rank N "
|
||||||
|
"VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
|
||||||
|
"server, or use `axolotl vllm-serve` for the "
|
||||||
|
"NCCL-capable serve module."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
msg = (
|
||||||
|
"the vLLM server exposes no full-parameter sync route "
|
||||||
|
"(/init_communicator/ for NCCL or /http_update_weights/ "
|
||||||
|
"for HTTP). Use `axolotl vllm-serve` (which has both) "
|
||||||
|
"or set `adapter: lora|qlora`."
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"NeMo Gym: no usable weight-sync transport — {msg} Without "
|
||||||
|
"weight sync the trainer's gradient updates never reach the "
|
||||||
|
"rollout policy (functionally a no-op trainer)."
|
||||||
|
)
|
||||||
|
|
||||||
if multi_turn:
|
if multi_turn:
|
||||||
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
|
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
|
||||||
|
|||||||
@@ -130,21 +130,41 @@ def start_servers(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_server_configs(head_port: int = 11000) -> dict:
|
def get_server_configs(head_port: int = 11000, timeout: float = 30.0) -> dict:
|
||||||
"""Fetch the global config from the NeMo Gym head server.
|
"""Fetch the global config from the NeMo Gym head server.
|
||||||
|
|
||||||
|
Retries up to 3 times with exponential backoff. The default per-attempt
|
||||||
|
timeout is 30s (raised from the original 5s) because head servers can
|
||||||
|
be slow to respond when they're concurrently serving rollouts from a
|
||||||
|
prior training run. A 5s timeout was empirically too tight to survive
|
||||||
|
a kill-and-relaunch cycle.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping server_name -> server config.
|
Dict mapping server_name -> server config.
|
||||||
"""
|
"""
|
||||||
response = requests.get(
|
url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"
|
||||||
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
|
last_exc: Exception | None = None
|
||||||
|
for attempt in (1, 2, 3):
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=timeout)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = yaml.safe_load(response.text)
|
||||||
|
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
||||||
|
if isinstance(result, str):
|
||||||
|
result = yaml.safe_load(result)
|
||||||
|
return result
|
||||||
|
except (requests.exceptions.RequestException, OSError) as exc:
|
||||||
|
last_exc = exc
|
||||||
|
LOG.warning(
|
||||||
|
"NeMo Gym head probe attempt %d/3 failed: %s. Retrying...",
|
||||||
|
attempt,
|
||||||
|
type(exc).__name__,
|
||||||
|
)
|
||||||
|
if attempt < 3:
|
||||||
|
time.sleep(2.0 * attempt)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"NeMo Gym head server at {url} did not respond after 3 attempts: {last_exc}"
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
|
||||||
result = yaml.safe_load(response.text)
|
|
||||||
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
|
||||||
if isinstance(result, str):
|
|
||||||
result = yaml.safe_load(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_servers(
|
def get_agent_servers(
|
||||||
|
|||||||
593
src/axolotl/kernels/gemma4_fused_rope.py
Normal file
593
src/axolotl/kernels/gemma4_fused_rope.py
Normal file
@@ -0,0 +1,593 @@
|
|||||||
|
"""
|
||||||
|
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
|
||||||
|
|
||||||
|
Fuses three operations into one kernel launch:
|
||||||
|
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
|
||||||
|
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||||
|
3. (optional) RMSNorm without scale (for v_norm)
|
||||||
|
|
||||||
|
This eliminates two intermediate tensor materializations per Q/K path;
|
||||||
|
churn from rotate_half / apply_rotary_pos_emb.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
|
||||||
|
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
|
||||||
|
cos: (rows, head_dim) — flattened from (batch, seq_len, 1, head_dim) after broadcast
|
||||||
|
sin: (rows, head_dim) — same as cos
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import operator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from liger_kernel.ops.utils import (
|
||||||
|
calculate_settings,
|
||||||
|
compare_version,
|
||||||
|
ensure_contiguous,
|
||||||
|
torch_to_triton_dtype,
|
||||||
|
)
|
||||||
|
from liger_kernel.utils import is_npu_available
|
||||||
|
|
||||||
|
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||||
|
try:
|
||||||
|
from triton.language.extra.libdevice import rsqrt
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from triton.language.extra.cuda.libdevice import rsqrt
|
||||||
|
else:
|
||||||
|
from triton.language.math import rsqrt
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _rms_norm_rope_forward_kernel(
|
||||||
|
Y_ptr,
|
||||||
|
Y_row_stride,
|
||||||
|
X_ptr,
|
||||||
|
X_row_stride,
|
||||||
|
W_ptr,
|
||||||
|
COS_ptr,
|
||||||
|
COS_row_stride,
|
||||||
|
SIN_ptr,
|
||||||
|
SIN_row_stride,
|
||||||
|
RSTD_ptr,
|
||||||
|
RSTD_row_stride,
|
||||||
|
n_cols,
|
||||||
|
n_rot,
|
||||||
|
n_heads,
|
||||||
|
eps,
|
||||||
|
HAS_WEIGHT: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Fused forward:
|
||||||
|
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
|
||||||
|
y[..., :n_rot] = rope(x_norm[..., :n_rot])
|
||||||
|
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
|
||||||
|
|
||||||
|
rotate_half swaps first/second halves and negates the first, restricted
|
||||||
|
to the rotary span [0, n_rot):
|
||||||
|
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
|
||||||
|
|
||||||
|
For the partial-rotary pass-through region we load cos with default 1.0
|
||||||
|
and sin with default 0.0 outside [0, n_rot), so the same formula
|
||||||
|
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
|
||||||
|
|
||||||
|
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
|
||||||
|
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
|
||||||
|
"""
|
||||||
|
row_idx = tl.program_id(0).to(tl.int64)
|
||||||
|
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
|
||||||
|
cs_row_idx = row_idx // n_heads
|
||||||
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_offsets < n_cols
|
||||||
|
rot_mask_col = col_offsets < n_rot
|
||||||
|
half_rot = n_rot // 2
|
||||||
|
|
||||||
|
# Load input row
|
||||||
|
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||||
|
X_dtype = X_row.dtype
|
||||||
|
X_fp32 = X_row.to(tl.float32)
|
||||||
|
|
||||||
|
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
|
||||||
|
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||||
|
rstd = rsqrt(mean_sq + eps)
|
||||||
|
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||||
|
|
||||||
|
# Normalize
|
||||||
|
X_norm = X_fp32 * rstd
|
||||||
|
|
||||||
|
# Apply weight if present (with_scale=True)
|
||||||
|
if HAS_WEIGHT:
|
||||||
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||||
|
X_norm = X_norm * W_row
|
||||||
|
|
||||||
|
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
|
||||||
|
# cos=1, sin=0 so the formula leaves X_norm untouched.
|
||||||
|
cos_row = tl.load(
|
||||||
|
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||||
|
mask=rot_mask_col,
|
||||||
|
other=1.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
sin_row = tl.load(
|
||||||
|
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
|
||||||
|
mask=rot_mask_col,
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
|
||||||
|
# rotate_half within [0, n_rot):
|
||||||
|
# for col < half_rot: take -X_norm[col + half_rot]
|
||||||
|
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
|
||||||
|
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
|
||||||
|
rot_offsets = tl.where(
|
||||||
|
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||||
|
)
|
||||||
|
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||||
|
X_rot = tl.load(
|
||||||
|
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
|
||||||
|
).to(tl.float32)
|
||||||
|
# Re-normalize the rotated values
|
||||||
|
X_rot_norm = X_rot * rstd
|
||||||
|
if HAS_WEIGHT:
|
||||||
|
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
|
||||||
|
X_rot_norm = X_rot_norm * W_rot
|
||||||
|
|
||||||
|
# Negate the first half (rotate_half negates x2, which becomes the first half)
|
||||||
|
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
|
||||||
|
X_rot_norm = X_rot_norm * sign
|
||||||
|
|
||||||
|
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||||
|
Y_row = X_norm * cos_row + X_rot_norm * sin_row
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
||||||
|
Y_row.to(X_dtype),
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _rms_norm_rope_backward_kernel(
|
||||||
|
dY_ptr,
|
||||||
|
dY_row_stride,
|
||||||
|
dX_ptr,
|
||||||
|
dX_row_stride,
|
||||||
|
X_ptr,
|
||||||
|
X_row_stride,
|
||||||
|
X_dtype: tl.constexpr,
|
||||||
|
W_ptr,
|
||||||
|
COS_ptr,
|
||||||
|
COS_row_stride,
|
||||||
|
SIN_ptr,
|
||||||
|
SIN_row_stride,
|
||||||
|
RSTD_ptr,
|
||||||
|
RSTD_row_stride,
|
||||||
|
dW_ptr,
|
||||||
|
dW_row_stride,
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
n_rot,
|
||||||
|
n_heads,
|
||||||
|
rows_per_program,
|
||||||
|
HAS_WEIGHT: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
|
||||||
|
(`n_rot <= n_cols`).
|
||||||
|
|
||||||
|
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
|
||||||
|
output is just the normalized row, so dN[col] = dY[col] (achieved by
|
||||||
|
loading cos with default 1.0 and forcing the rotate-half contribution
|
||||||
|
to zero outside the rotary span).
|
||||||
|
|
||||||
|
cos/sin indexed by row_idx // n_heads for per-head broadcast.
|
||||||
|
"""
|
||||||
|
row_block_id = tl.program_id(0).to(tl.int64)
|
||||||
|
row_start = row_block_id * rows_per_program
|
||||||
|
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||||
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_offsets < n_cols
|
||||||
|
rot_mask_col = col_offsets < n_rot
|
||||||
|
half_rot = n_rot // 2
|
||||||
|
|
||||||
|
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||||
|
|
||||||
|
if HAS_WEIGHT:
|
||||||
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||||
|
|
||||||
|
for row_idx in range(row_start, row_end):
|
||||||
|
cs_row_idx = row_idx // n_heads
|
||||||
|
|
||||||
|
dY_row = tl.load(
|
||||||
|
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
|
||||||
|
).to(tl.float32)
|
||||||
|
X_row = tl.load(
|
||||||
|
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
|
||||||
|
).to(tl.float32)
|
||||||
|
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||||
|
|
||||||
|
cos_row = tl.load(
|
||||||
|
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||||
|
mask=rot_mask_col,
|
||||||
|
other=1.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
|
||||||
|
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
|
||||||
|
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
|
||||||
|
#
|
||||||
|
# For col >= n_rot the formula must collapse to dN = dY (since the
|
||||||
|
# forward is just a pass-through). cos defaults to 1.0 above; the
|
||||||
|
# rotate-half contribution is masked to zero below.
|
||||||
|
rot_offsets = tl.where(
|
||||||
|
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||||
|
)
|
||||||
|
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||||
|
dY_rot = tl.load(
|
||||||
|
dY_ptr + row_idx * dY_row_stride + rot_offsets,
|
||||||
|
mask=rot_load_mask,
|
||||||
|
other=0,
|
||||||
|
).to(tl.float32)
|
||||||
|
sin_rot = tl.load(
|
||||||
|
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
|
||||||
|
mask=rot_load_mask,
|
||||||
|
other=0,
|
||||||
|
).to(tl.float32)
|
||||||
|
|
||||||
|
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
|
||||||
|
rotate_term = dY_rot * sin_rot * adj_sign
|
||||||
|
# Zero out rotate-half contribution outside the rotary span.
|
||||||
|
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
|
||||||
|
dN = dY_row * cos_row + rotate_term
|
||||||
|
|
||||||
|
# Pre-weight normalized: n = rstd * x
|
||||||
|
n = X_row * rstd
|
||||||
|
|
||||||
|
if HAS_WEIGHT:
|
||||||
|
dW_acc += dN * n
|
||||||
|
dm = dN * W_row
|
||||||
|
else:
|
||||||
|
dm = dN
|
||||||
|
|
||||||
|
# RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
|
||||||
|
dot_dm_x = tl.sum(dm * X_row, axis=0)
|
||||||
|
dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
||||||
|
dX_row.to(X_dtype),
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
if HAS_WEIGHT:
|
||||||
|
tl.store(
|
||||||
|
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
||||||
|
dW_acc,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
|
||||||
|
W: (head_dim,) or None — RMSNorm weight
|
||||||
|
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||||
|
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||||
|
eps: float
|
||||||
|
n_heads: int — number of attention heads (for cos/sin indexing)
|
||||||
|
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
|
||||||
|
partial rotary). Must be even and ``<= head_dim``.
|
||||||
|
Returns:
|
||||||
|
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
|
||||||
|
"""
|
||||||
|
n_rows, n_cols = X.shape
|
||||||
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||||
|
has_weight = W is not None
|
||||||
|
|
||||||
|
Y = torch.empty_like(X)
|
||||||
|
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||||
|
|
||||||
|
_rms_norm_rope_forward_kernel[(n_rows,)](
|
||||||
|
Y,
|
||||||
|
Y.stride(0),
|
||||||
|
X,
|
||||||
|
X.stride(0),
|
||||||
|
W if has_weight else X, # dummy pointer when no weight
|
||||||
|
cos,
|
||||||
|
cos.stride(0),
|
||||||
|
sin,
|
||||||
|
sin.stride(0),
|
||||||
|
RSTD,
|
||||||
|
RSTD.stride(0),
|
||||||
|
n_cols,
|
||||||
|
n_rot,
|
||||||
|
n_heads,
|
||||||
|
eps,
|
||||||
|
HAS_WEIGHT=has_weight,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_rope_backward(
|
||||||
|
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
|
||||||
|
):
|
||||||
|
n_rows, n_cols = dY.shape
|
||||||
|
has_weight = W is not None
|
||||||
|
|
||||||
|
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
||||||
|
rows_per_program = math.ceil(n_rows / sm_count)
|
||||||
|
|
||||||
|
dX = torch.empty_like(X)
|
||||||
|
|
||||||
|
if has_weight:
|
||||||
|
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
|
||||||
|
else:
|
||||||
|
_dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)
|
||||||
|
|
||||||
|
_rms_norm_rope_backward_kernel[(sm_count,)](
|
||||||
|
dY,
|
||||||
|
dY.stride(0),
|
||||||
|
dX,
|
||||||
|
dX.stride(0),
|
||||||
|
X,
|
||||||
|
X.stride(0),
|
||||||
|
torch_to_triton_dtype[X.dtype],
|
||||||
|
W if has_weight else X, # dummy
|
||||||
|
cos,
|
||||||
|
cos.stride(0),
|
||||||
|
sin,
|
||||||
|
sin.stride(0),
|
||||||
|
RSTD,
|
||||||
|
RSTD.stride(0),
|
||||||
|
_dW,
|
||||||
|
_dW.stride(0),
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
n_rot,
|
||||||
|
n_heads,
|
||||||
|
rows_per_program,
|
||||||
|
HAS_WEIGHT=has_weight,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
|
||||||
|
dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
|
||||||
|
return dX, dW
|
||||||
|
|
||||||
|
|
||||||
|
class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
@ensure_contiguous
|
||||||
|
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
|
||||||
|
"""
|
||||||
|
X: (B*S*H, head_dim)
|
||||||
|
W: (head_dim,) or None
|
||||||
|
cos: (B*S, n_rot) — broadcast across heads
|
||||||
|
sin: (B*S, n_rot) — broadcast across heads
|
||||||
|
n_heads: int
|
||||||
|
n_rot: int — rotary dim (<= head_dim)
|
||||||
|
"""
|
||||||
|
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
|
||||||
|
X,
|
||||||
|
W,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
eps,
|
||||||
|
n_heads,
|
||||||
|
n_rot,
|
||||||
|
)
|
||||||
|
ctx.eps = eps
|
||||||
|
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||||
|
ctx.num_warps = num_warps
|
||||||
|
ctx.n_heads = n_heads
|
||||||
|
ctx.n_rot = n_rot
|
||||||
|
ctx.has_weight = W is not None
|
||||||
|
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
|
||||||
|
return Y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@ensure_contiguous
|
||||||
|
def backward(ctx, dY):
|
||||||
|
X, W, cos, sin, RSTD = ctx.saved_tensors
|
||||||
|
dX, dW = rms_norm_rope_backward(
|
||||||
|
dY,
|
||||||
|
X,
|
||||||
|
W,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
RSTD,
|
||||||
|
ctx.n_heads,
|
||||||
|
ctx.n_rot,
|
||||||
|
ctx.BLOCK_SIZE,
|
||||||
|
ctx.num_warps,
|
||||||
|
)
|
||||||
|
return dX, dW, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||||
|
"""
|
||||||
|
Apply fused RMSNorm + (partial) RoPE.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (batch, seq_len, num_heads, head_dim) — after projection + view
|
||||||
|
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
|
||||||
|
cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
|
||||||
|
must be even and ``<= head_dim``. When ``n_rot < head_dim``
|
||||||
|
the trailing ``head_dim - n_rot`` columns are RMSNorm-only
|
||||||
|
(partial-rotary pass-through), matching stock Gemma 4 with
|
||||||
|
``partial_rotary_factor < 1.0``.
|
||||||
|
sin: (batch, seq_len, n_rot) — same shape as ``cos``
|
||||||
|
eps: float — RMSNorm epsilon
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
|
||||||
|
"""
|
||||||
|
shape = x.shape # (B, S, H, D)
|
||||||
|
B, S, H, D = shape
|
||||||
|
n_rot = cos.shape[-1]
|
||||||
|
if sin.shape[-1] != n_rot:
|
||||||
|
raise ValueError(
|
||||||
|
f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
|
||||||
|
f"sin={sin.shape[-1]}"
|
||||||
|
)
|
||||||
|
if n_rot > D:
|
||||||
|
raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
|
||||||
|
if n_rot % 2 != 0:
|
||||||
|
raise ValueError(f"rotary dim must be even, got {n_rot}")
|
||||||
|
|
||||||
|
# Flatten to 2D: (B*S*H, D)
|
||||||
|
x_flat = x.reshape(-1, D).contiguous()
|
||||||
|
# cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
|
||||||
|
# all sequences share the same rotary positions). The kernel needs a
|
||||||
|
# dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
|
||||||
|
# onto a single (b, s) pair, so expand-then-contiguous to materialize
|
||||||
|
# the per-batch broadcast. Expand is a no-op when B == cos.shape[0].
|
||||||
|
if cos.shape[0] != B:
|
||||||
|
if cos.shape[0] != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
|
||||||
|
f"to x batch dim ({B})"
|
||||||
|
)
|
||||||
|
cos = cos.expand(B, S, n_rot)
|
||||||
|
sin = sin.expand(B, S, n_rot)
|
||||||
|
cos_flat = cos.reshape(B * S, n_rot).contiguous()
|
||||||
|
sin_flat = sin.reshape(B * S, n_rot).contiguous()
|
||||||
|
|
||||||
|
y_flat = FusedRMSNormRoPEFunction.apply(
|
||||||
|
x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
|
||||||
|
)
|
||||||
|
return y_flat.view(shape)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _rms_norm_forward_kernel(
|
||||||
|
Y_ptr,
|
||||||
|
Y_row_stride,
|
||||||
|
X_ptr,
|
||||||
|
X_row_stride,
|
||||||
|
RSTD_ptr,
|
||||||
|
RSTD_row_stride,
|
||||||
|
n_cols,
|
||||||
|
eps,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""RMSNorm without scale weight: y = x / rms(x)"""
|
||||||
|
row_idx = tl.program_id(0).to(tl.int64)
|
||||||
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_offsets < n_cols
|
||||||
|
|
||||||
|
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||||
|
X_dtype = X_row.dtype
|
||||||
|
X_fp32 = X_row.to(tl.float32)
|
||||||
|
|
||||||
|
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||||
|
rstd = rsqrt(mean_sq + eps)
|
||||||
|
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||||
|
|
||||||
|
Y_row = X_fp32 * rstd
|
||||||
|
tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _rms_norm_noscale_backward_kernel(
|
||||||
|
dY_ptr,
|
||||||
|
dY_row_stride,
|
||||||
|
dX_ptr,
|
||||||
|
dX_row_stride,
|
||||||
|
X_ptr,
|
||||||
|
X_row_stride,
|
||||||
|
X_dtype: tl.constexpr,
|
||||||
|
RSTD_ptr,
|
||||||
|
RSTD_row_stride,
|
||||||
|
n_cols,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Backward for y = x * rstd (no weight)."""
|
||||||
|
row_idx = tl.program_id(0).to(tl.int64)
|
||||||
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_offsets < n_cols
|
||||||
|
|
||||||
|
dY_row = tl.load(
|
||||||
|
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
|
||||||
|
).to(tl.float32)
|
||||||
|
X_row = tl.load(
|
||||||
|
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
|
||||||
|
).to(tl.float32)
|
||||||
|
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||||
|
|
||||||
|
dot_dy_x = tl.sum(dY_row * X_row, axis=0)
|
||||||
|
dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
|
||||||
|
"""RMSNorm without learnable scale — used for Gemma4's v_norm."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@ensure_contiguous
|
||||||
|
def forward(ctx, X, eps):
|
||||||
|
n_rows, n_cols = X.shape
|
||||||
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||||
|
Y = torch.empty_like(X)
|
||||||
|
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||||
|
|
||||||
|
_rms_norm_forward_kernel[(n_rows,)](
|
||||||
|
Y,
|
||||||
|
Y.stride(0),
|
||||||
|
X,
|
||||||
|
X.stride(0),
|
||||||
|
RSTD,
|
||||||
|
RSTD.stride(0),
|
||||||
|
n_cols,
|
||||||
|
eps,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||||
|
ctx.num_warps = num_warps
|
||||||
|
ctx.save_for_backward(X, RSTD)
|
||||||
|
ctx.n_cols = n_cols
|
||||||
|
return Y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@ensure_contiguous
|
||||||
|
def backward(ctx, dY):
|
||||||
|
X, RSTD = ctx.saved_tensors
|
||||||
|
n_rows = X.shape[0]
|
||||||
|
dX = torch.empty_like(X)
|
||||||
|
_rms_norm_noscale_backward_kernel[(n_rows,)](
|
||||||
|
dY,
|
||||||
|
dY.stride(0),
|
||||||
|
dX,
|
||||||
|
dX.stride(0),
|
||||||
|
X,
|
||||||
|
X.stride(0),
|
||||||
|
torch_to_triton_dtype[X.dtype],
|
||||||
|
RSTD,
|
||||||
|
RSTD.stride(0),
|
||||||
|
ctx.n_cols,
|
||||||
|
BLOCK_SIZE=ctx.BLOCK_SIZE,
|
||||||
|
num_warps=ctx.num_warps,
|
||||||
|
)
|
||||||
|
return dX, None
|
||||||
|
|
||||||
|
|
||||||
|
def fused_rms_norm_noscale(x, eps=1e-6):
|
||||||
|
"""
|
||||||
|
RMSNorm without scale for v_norm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (batch, seq_len, num_heads, head_dim)
|
||||||
|
Returns:
|
||||||
|
y: same shape, normalized
|
||||||
|
"""
|
||||||
|
shape = x.shape
|
||||||
|
x_flat = x.reshape(-1, shape[-1])
|
||||||
|
y_flat = FusedRMSNormNoScaleFunction.apply(x_flat, eps)
|
||||||
|
return y_flat.view(shape)
|
||||||
@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
|
|||||||
return Q, K, V
|
return Q, K, V
|
||||||
|
|
||||||
|
|
||||||
|
class LoRA_QK(torch.autograd.Function):
|
||||||
|
"""Optimized LoRA QK implementation for models where v_proj is None.
|
||||||
|
|
||||||
|
Used by models like Gemma4 with attention_k_eq_v=True, where key states are
|
||||||
|
reused as value states. Only Q and K projections are fused; the caller
|
||||||
|
returns K a second time as V so that autograd accumulates key+value gradients
|
||||||
|
into a single dK.
|
||||||
|
|
||||||
|
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_amp_custom_fwd
|
||||||
|
def forward(
|
||||||
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
|
X: torch.Tensor,
|
||||||
|
X_drop: torch.Tensor | None,
|
||||||
|
# Q params
|
||||||
|
q_weight: torch.Tensor,
|
||||||
|
q_bias: torch.Tensor | None,
|
||||||
|
q_quant: QuantState | None,
|
||||||
|
q_A: torch.Tensor | None,
|
||||||
|
q_B: torch.Tensor | None,
|
||||||
|
q_scale: float,
|
||||||
|
q_lora_bias: torch.Tensor | None,
|
||||||
|
q_magnitude: torch.Tensor | None,
|
||||||
|
# K params
|
||||||
|
k_weight: torch.Tensor,
|
||||||
|
k_bias: torch.Tensor | None,
|
||||||
|
k_quant: QuantState | None,
|
||||||
|
k_A: torch.Tensor | None,
|
||||||
|
k_B: torch.Tensor | None,
|
||||||
|
k_scale: float,
|
||||||
|
k_lora_bias: torch.Tensor | None,
|
||||||
|
k_magnitude: torch.Tensor | None,
|
||||||
|
# Flags
|
||||||
|
inplace: bool = True,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
has_dropout = X_drop is not None
|
||||||
|
has_dora = q_magnitude is not None
|
||||||
|
|
||||||
|
if has_dora:
|
||||||
|
dtype = X.dtype
|
||||||
|
X_lora = X_drop if has_dropout else X
|
||||||
|
|
||||||
|
# Compute Q with DoRA
|
||||||
|
Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
|
||||||
|
Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
|
||||||
|
q_mag_scale = _compute_dora_scale(
|
||||||
|
q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
|
||||||
|
)
|
||||||
|
Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
|
||||||
|
if q_bias is not None:
|
||||||
|
Q = Q + q_bias
|
||||||
|
|
||||||
|
# Compute K with DoRA
|
||||||
|
K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
|
||||||
|
K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
|
||||||
|
k_mag_scale = _compute_dora_scale(
|
||||||
|
k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
|
||||||
|
)
|
||||||
|
K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
|
||||||
|
if k_bias is not None:
|
||||||
|
K = K + k_bias
|
||||||
|
|
||||||
|
Q_combined = Q_base + Q_lora
|
||||||
|
K_combined = K_base + K_lora
|
||||||
|
|
||||||
|
ctx.save_for_backward(
|
||||||
|
X,
|
||||||
|
X_drop if has_dropout else X,
|
||||||
|
q_A.to(dtype) if q_A is not None else q_A,
|
||||||
|
q_B.to(dtype) if q_B is not None else q_B,
|
||||||
|
k_A.to(dtype) if k_A is not None else k_A,
|
||||||
|
k_B.to(dtype) if k_B is not None else k_B,
|
||||||
|
q_magnitude,
|
||||||
|
k_magnitude,
|
||||||
|
q_mag_scale,
|
||||||
|
k_mag_scale,
|
||||||
|
Q_combined,
|
||||||
|
K_combined,
|
||||||
|
q_lora_bias,
|
||||||
|
k_lora_bias,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Standard LoRA (with optional dropout and bias)
|
||||||
|
Q = matmul_lora(
|
||||||
|
X,
|
||||||
|
q_weight,
|
||||||
|
q_bias,
|
||||||
|
q_quant,
|
||||||
|
q_A,
|
||||||
|
q_B,
|
||||||
|
q_scale,
|
||||||
|
X_drop=X_drop,
|
||||||
|
lora_bias=q_lora_bias,
|
||||||
|
)
|
||||||
|
K = matmul_lora(
|
||||||
|
X,
|
||||||
|
k_weight,
|
||||||
|
k_bias,
|
||||||
|
k_quant,
|
||||||
|
k_A,
|
||||||
|
k_B,
|
||||||
|
k_scale,
|
||||||
|
X_drop=X_drop,
|
||||||
|
lora_bias=k_lora_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
dtype = X.dtype
|
||||||
|
ctx.save_for_backward(
|
||||||
|
X,
|
||||||
|
X_drop if has_dropout else X,
|
||||||
|
q_A.to(dtype) if q_A is not None else q_A,
|
||||||
|
q_B.to(dtype) if q_B is not None else q_B,
|
||||||
|
k_A.to(dtype) if k_A is not None else k_A,
|
||||||
|
k_B.to(dtype) if k_B is not None else k_B,
|
||||||
|
q_lora_bias,
|
||||||
|
k_lora_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.scales = (q_scale, k_scale)
|
||||||
|
ctx.quants = (q_quant, k_quant)
|
||||||
|
ctx.weights = (q_weight, k_weight)
|
||||||
|
ctx.inplace = inplace
|
||||||
|
ctx.has_dropout = has_dropout
|
||||||
|
ctx.has_dora = has_dora
|
||||||
|
|
||||||
|
return Q, K
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_amp_custom_bwd
|
||||||
|
def backward(
|
||||||
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
|
q_grad: torch.Tensor,
|
||||||
|
k_grad: torch.Tensor,
|
||||||
|
):
|
||||||
|
q_weight, k_weight = ctx.weights
|
||||||
|
q_quant, k_quant = ctx.quants
|
||||||
|
q_scale, k_scale = ctx.scales
|
||||||
|
has_dropout = ctx.has_dropout
|
||||||
|
has_dora = ctx.has_dora
|
||||||
|
|
||||||
|
if has_dora:
|
||||||
|
(
|
||||||
|
X,
|
||||||
|
X_lora,
|
||||||
|
A_q,
|
||||||
|
B_q,
|
||||||
|
A_k,
|
||||||
|
B_k,
|
||||||
|
q_magnitude,
|
||||||
|
k_magnitude,
|
||||||
|
q_mag_scale,
|
||||||
|
k_mag_scale,
|
||||||
|
Q_combined,
|
||||||
|
K_combined,
|
||||||
|
q_lora_bias,
|
||||||
|
k_lora_bias,
|
||||||
|
) = ctx.saved_tensors
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
X,
|
||||||
|
X_lora,
|
||||||
|
A_q,
|
||||||
|
B_q,
|
||||||
|
A_k,
|
||||||
|
B_k,
|
||||||
|
q_lora_bias,
|
||||||
|
k_lora_bias,
|
||||||
|
) = ctx.saved_tensors
|
||||||
|
q_magnitude = k_magnitude = None
|
||||||
|
q_mag_scale = k_mag_scale = None
|
||||||
|
Q_combined = K_combined = None
|
||||||
|
|
||||||
|
batch, seq_len = X.shape[:2]
|
||||||
|
q_grad = q_grad.view(-1, q_grad.shape[-1])
|
||||||
|
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
|
||||||
|
X = X.view(-1, X.shape[-1])
|
||||||
|
X_lora = X_lora.view(-1, X_lora.shape[-1])
|
||||||
|
|
||||||
|
d_q_mag = d_k_mag = None
|
||||||
|
d_q_lora_bias = d_k_lora_bias = None
|
||||||
|
|
||||||
|
if has_dora:
|
||||||
|
Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
|
||||||
|
K_combined = K_combined.view(-1, K_combined.shape[-1])
|
||||||
|
|
||||||
|
d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
|
||||||
|
d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude
|
||||||
|
|
||||||
|
q_grad = q_grad * q_mag_scale.unsqueeze(0)
|
||||||
|
k_grad = k_grad * k_mag_scale.unsqueeze(0)
|
||||||
|
|
||||||
|
# LoRA bias gradients
|
||||||
|
if q_lora_bias is not None:
|
||||||
|
d_q_lora_bias = q_scale * q_grad.sum(dim=0)
|
||||||
|
if k_lora_bias is not None:
|
||||||
|
d_k_lora_bias = k_scale * k_grad.sum(dim=0)
|
||||||
|
|
||||||
|
X_lora_t = X_lora.t()
|
||||||
|
|
||||||
|
d_A_q = d_B_q = d_A_k = d_B_k = None
|
||||||
|
grad_B_q = grad_B_k = None
|
||||||
|
|
||||||
|
if A_q is not None and B_q is not None:
|
||||||
|
grad_B_q = q_grad @ B_q
|
||||||
|
d_A_q = torch.empty_like(A_q.t())
|
||||||
|
d_B_q = torch.empty_like(B_q.t())
|
||||||
|
d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
|
||||||
|
d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)
|
||||||
|
|
||||||
|
if A_k is not None and B_k is not None:
|
||||||
|
grad_B_k = k_grad @ B_k
|
||||||
|
d_A_k = torch.empty_like(A_k.t())
|
||||||
|
d_B_k = torch.empty_like(B_k.t())
|
||||||
|
d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
|
||||||
|
d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)
|
||||||
|
|
||||||
|
# Base path input gradient
|
||||||
|
out_buffer = X if ctx.inplace else None
|
||||||
|
|
||||||
|
q_weight_t = dequantize(q_weight, q_quant)
|
||||||
|
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
|
||||||
|
del q_weight_t
|
||||||
|
|
||||||
|
k_weight_t = dequantize(k_weight, k_quant)
|
||||||
|
grad_X.addmm_(k_grad, k_weight_t)
|
||||||
|
del k_weight_t
|
||||||
|
|
||||||
|
# LoRA path input gradient
|
||||||
|
if has_dropout:
|
||||||
|
grad_X_drop = torch.zeros_like(X_lora)
|
||||||
|
if grad_B_q is not None:
|
||||||
|
grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
|
||||||
|
if grad_B_k is not None:
|
||||||
|
grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
|
||||||
|
else:
|
||||||
|
grad_X_drop = None
|
||||||
|
if grad_B_q is not None:
|
||||||
|
grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
|
||||||
|
if grad_B_k is not None:
|
||||||
|
grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)
|
||||||
|
|
||||||
|
if d_A_q is not None:
|
||||||
|
d_A_q = d_A_q.t()
|
||||||
|
d_B_q = d_B_q.t() # type: ignore[union-attr]
|
||||||
|
if d_A_k is not None:
|
||||||
|
d_A_k = d_A_k.t()
|
||||||
|
d_B_k = d_B_k.t() # type: ignore[union-attr]
|
||||||
|
|
||||||
|
grad_X = grad_X.view(batch, seq_len, -1)
|
||||||
|
if grad_X_drop is not None:
|
||||||
|
grad_X_drop = grad_X_drop.view(batch, seq_len, -1)
|
||||||
|
|
||||||
|
# Return gradients for all forward inputs:
|
||||||
|
# X, X_drop,
|
||||||
|
# q: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||||
|
# k: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||||
|
# inplace
|
||||||
|
return (
|
||||||
|
grad_X,
|
||||||
|
grad_X_drop,
|
||||||
|
# Q
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
d_A_q,
|
||||||
|
d_B_q,
|
||||||
|
None,
|
||||||
|
d_q_lora_bias,
|
||||||
|
d_q_mag,
|
||||||
|
# K
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
d_A_k,
|
||||||
|
d_B_k,
|
||||||
|
None,
|
||||||
|
d_k_lora_bias,
|
||||||
|
d_k_mag,
|
||||||
|
# inplace
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lora_qk(
|
||||||
|
self, X: torch.Tensor, inplace: bool = True
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Applies LoRA to compute Query and Key projections for models where v_proj is None.
|
||||||
|
|
||||||
|
When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
|
||||||
|
value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
|
||||||
|
Because K is returned twice, autograd accumulates gradients from both the key and
|
||||||
|
value paths into dK before calling LoRA_QK.backward.
|
||||||
|
|
||||||
|
Supports bias, dropout, and DoRA.
|
||||||
|
"""
|
||||||
|
QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
|
||||||
|
KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)
|
||||||
|
|
||||||
|
# Apply dropout outside autograd.Function (shared mask for Q, K)
|
||||||
|
X_drop = _apply_dropout(Qdrop, X, self.training)
|
||||||
|
|
||||||
|
Q, K = LoRA_QK.apply(
|
||||||
|
X,
|
||||||
|
X_drop,
|
||||||
|
# Q
|
||||||
|
QW,
|
||||||
|
Qb,
|
||||||
|
QW_quant,
|
||||||
|
QA,
|
||||||
|
QB,
|
||||||
|
QS,
|
||||||
|
Qlb,
|
||||||
|
Qmag,
|
||||||
|
# K
|
||||||
|
KW,
|
||||||
|
Kb,
|
||||||
|
KW_quant,
|
||||||
|
KA,
|
||||||
|
KB,
|
||||||
|
KS,
|
||||||
|
Klb,
|
||||||
|
Kmag,
|
||||||
|
# Flags
|
||||||
|
inplace,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Q, K, K
|
||||||
|
|
||||||
|
|
||||||
class LoRA_O(torch.autograd.Function):
|
class LoRA_O(torch.autograd.Function):
|
||||||
"""Optimized LoRA implementation for output projection.
|
"""Optimized LoRA implementation for output projection.
|
||||||
|
|
||||||
|
|||||||
@@ -67,12 +67,165 @@ def find_all_linear_names(model):
|
|||||||
return list(lora_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_peft_clippable_linear():
|
||||||
|
"""Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.
|
||||||
|
|
||||||
|
Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
|
||||||
|
that clips activations). PEFT doesn't recognise it as a supported layer type,
|
||||||
|
so we redirect LoRA injection to the inner ``.linear`` child instead.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import (
|
||||||
|
Gemma4ClippableLinear as _cls,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
from peft.tuners.lora.model import LoraModel
|
||||||
|
|
||||||
|
if getattr(LoraModel, "_axolotl_clippable_patched", False):
|
||||||
|
return
|
||||||
|
_orig = LoraModel._create_and_replace
|
||||||
|
|
||||||
|
def _patched(
|
||||||
|
self,
|
||||||
|
peft_config,
|
||||||
|
adapter_name,
|
||||||
|
target,
|
||||||
|
target_name,
|
||||||
|
parent,
|
||||||
|
current_key=None,
|
||||||
|
**kw,
|
||||||
|
):
|
||||||
|
if isinstance(target, _cls):
|
||||||
|
# Redirect to the inner nn.Linear so PEFT can wrap it normally.
|
||||||
|
return _orig(
|
||||||
|
self,
|
||||||
|
peft_config,
|
||||||
|
adapter_name,
|
||||||
|
target.linear,
|
||||||
|
"linear",
|
||||||
|
target,
|
||||||
|
current_key=current_key,
|
||||||
|
**kw,
|
||||||
|
)
|
||||||
|
return _orig(
|
||||||
|
self,
|
||||||
|
peft_config,
|
||||||
|
adapter_name,
|
||||||
|
target,
|
||||||
|
target_name,
|
||||||
|
parent,
|
||||||
|
current_key=current_key,
|
||||||
|
**kw,
|
||||||
|
)
|
||||||
|
|
||||||
|
LoraModel._create_and_replace = _patched
|
||||||
|
LoraModel._axolotl_clippable_patched = True
|
||||||
|
|
||||||
|
|
||||||
|
def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
|
||||||
|
"""Check whether PEFT will auto-populate target_parameters for this model.
|
||||||
|
|
||||||
|
PEFT 0.19's ``convert_peft_config_for_transformers`` rewrites old MoE
|
||||||
|
``target_modules`` (e.g. ``w1``/``w2``/``w3`` on Mixtral) into
|
||||||
|
``target_parameters`` (``gate_up_proj``/``down_proj``) because
|
||||||
|
transformers v5 fused those expert linears into 3D ``nn.Parameter``
|
||||||
|
tensors. PEFT wraps the resulting 3D params with ``ParamWrapper``,
|
||||||
|
which rejects ``lora_dropout != 0``. This probe runs the conversion on
|
||||||
|
a copy of the config so we can detect the situation before
|
||||||
|
``get_peft_model`` blows up.
|
||||||
|
"""
|
||||||
|
if getattr(lora_config, "target_parameters", None):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from peft.utils.transformers_weight_conversion import (
|
||||||
|
convert_peft_config_for_transformers,
|
||||||
|
get_model_conversion_mapping,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
probe_cfg = copy.deepcopy(lora_config)
|
||||||
|
try:
|
||||||
|
convert_peft_config_for_transformers(
|
||||||
|
probe_cfg,
|
||||||
|
model=model,
|
||||||
|
conversions=get_model_conversion_mapping(model),
|
||||||
|
)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
return False
|
||||||
|
|
||||||
|
return bool(getattr(probe_cfg, "target_parameters", None))
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_peft_param_wrapper_dropout():
|
||||||
|
"""Let PEFT's ``ParamWrapper`` silently accept ``lora_dropout != 0``.
|
||||||
|
|
||||||
|
``ParamWrapper`` wraps 3D expert ``nn.Parameter`` tensors and rejects
|
||||||
|
non-zero dropout because dropout can't be factored out of
|
||||||
|
``lora_B(lora_A(dropout(x)))`` when the inner op is an expert-indexed
|
||||||
|
matmul. For mixed configs (attention + MoE experts) this is too
|
||||||
|
aggressive — the non-expert ``Linear`` LoRA layers *can* apply dropout
|
||||||
|
and that's usually what the user intended. We pass a copy of the
|
||||||
|
``LoraConfig`` with ``lora_dropout=0`` only to ``ParamWrapper.__init__``
|
||||||
|
so it builds with ``nn.Identity`` for its internal dropout slot while
|
||||||
|
every other layer type still receives the real dropout value.
|
||||||
|
"""
|
||||||
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
|
|
||||||
|
if getattr(ParamWrapper, "_axolotl_dropout_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
_orig_init = ParamWrapper.__init__
|
||||||
|
|
||||||
|
def _patched_init(
|
||||||
|
self,
|
||||||
|
base_layer,
|
||||||
|
adapter_name,
|
||||||
|
parameter_name,
|
||||||
|
config,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if getattr(config, "lora_dropout", 0):
|
||||||
|
import copy as _copy
|
||||||
|
|
||||||
|
patched_config = _copy.copy(config)
|
||||||
|
patched_config.lora_dropout = 0.0
|
||||||
|
return _orig_init(
|
||||||
|
self,
|
||||||
|
base_layer,
|
||||||
|
adapter_name,
|
||||||
|
parameter_name,
|
||||||
|
patched_config,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return _orig_init(
|
||||||
|
self,
|
||||||
|
base_layer,
|
||||||
|
adapter_name,
|
||||||
|
parameter_name,
|
||||||
|
config,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
ParamWrapper.__init__ = _patched_init
|
||||||
|
ParamWrapper._axolotl_dropout_patched = True
|
||||||
|
|
||||||
|
|
||||||
def load_lora(
|
def load_lora(
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
config_only: bool = False,
|
config_only: bool = False,
|
||||||
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
|
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
|
||||||
|
_patch_peft_clippable_linear()
|
||||||
lora_target_modules = cfg.lora_target_modules or []
|
lora_target_modules = cfg.lora_target_modules or []
|
||||||
lora_target_parameters = cfg.lora_target_parameters or []
|
lora_target_parameters = cfg.lora_target_parameters or []
|
||||||
|
|
||||||
@@ -124,6 +277,7 @@ def load_lora(
|
|||||||
lora_dropout=cfg.lora_dropout,
|
lora_dropout=cfg.lora_dropout,
|
||||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||||
|
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
|
||||||
bias="none",
|
bias="none",
|
||||||
task_type=task_type,
|
task_type=task_type,
|
||||||
**lora_config_kwargs,
|
**lora_config_kwargs,
|
||||||
@@ -132,6 +286,20 @@ def load_lora(
|
|||||||
if config_only:
|
if config_only:
|
||||||
return None, lora_config
|
return None, lora_config
|
||||||
|
|
||||||
|
if getattr(
|
||||||
|
lora_config, "lora_dropout", 0
|
||||||
|
) and _peft_will_auto_convert_target_params(model, lora_config):
|
||||||
|
LOG.warning(
|
||||||
|
"lora_dropout=%s requested but PEFT will wrap this model's fused "
|
||||||
|
"MoE expert parameters with ParamWrapper, which cannot apply "
|
||||||
|
"dropout (the 3D einsum can't factor dropout out of "
|
||||||
|
"lora_B(lora_A(dropout(x)))). Dropout will still be applied to "
|
||||||
|
"non-expert LoRA layers (e.g. attention), and expert LoRA layers "
|
||||||
|
"will use nn.Identity for the dropout slot.",
|
||||||
|
lora_config.lora_dropout,
|
||||||
|
)
|
||||||
|
_patch_peft_param_wrapper_dropout()
|
||||||
|
|
||||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -547,6 +547,16 @@ class ModelLoader:
|
|||||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||||
|
|
||||||
|
if self.cfg.model_quantization_config == "FineGrainedFP8Config":
|
||||||
|
from transformers import FineGrainedFP8Config
|
||||||
|
|
||||||
|
fp8_kwargs = {}
|
||||||
|
if self.cfg.model_quantization_config_kwargs:
|
||||||
|
fp8_kwargs = self.cfg.model_quantization_config_kwargs
|
||||||
|
self.model_kwargs["quantization_config"] = FineGrainedFP8Config(
|
||||||
|
**fp8_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if not hasattr(self.model_config, "quantization_config"):
|
if not hasattr(self.model_config, "quantization_config"):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -624,7 +634,14 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _set_attention_config(self):
|
def _set_attention_config(self):
|
||||||
"""Sample packing uses custom FA2 patch"""
|
"""Sample packing uses custom FA2 patch"""
|
||||||
if self.cfg.attn_implementation:
|
if self.cfg.gemma4_hybrid_attn_impl:
|
||||||
|
# Load model with flash_attention_2 for sliding window layers;
|
||||||
|
# global layers will be patched to sdpa post-load.
|
||||||
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
self.model_config._attn_implementation = "flash_attention_2"
|
||||||
|
# Set flash_attention so multipack/sample_packing patches activate
|
||||||
|
self.cfg.flash_attention = True
|
||||||
|
elif self.cfg.attn_implementation:
|
||||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||||
elif self.cfg.flex_attention:
|
elif self.cfg.flex_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||||
|
|||||||
@@ -156,6 +156,15 @@ class PatchManager:
|
|||||||
# which would clobber any earlier fix.
|
# which would clobber any earlier fix.
|
||||||
self._fix_nemotron_h_conversion_mapping()
|
self._fix_nemotron_h_conversion_mapping()
|
||||||
|
|
||||||
|
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
|
||||||
|
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
|
||||||
|
# override needs to walk the raw model tree, which is broken by
|
||||||
|
# the post-load PEFT wrapping. The accompanying
|
||||||
|
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and
|
||||||
|
# installation-time-independent, so both halves of the fix live
|
||||||
|
# cleanly in the same call even though one is instance-scoped
|
||||||
|
# and the other is module-scoped.
|
||||||
|
self._apply_gemma_hybrid_attention(model)
|
||||||
self._finalize_moe_expert_quantization(model)
|
self._finalize_moe_expert_quantization(model)
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
@@ -165,6 +174,83 @@ class PatchManager:
|
|||||||
self._apply_lora_kernel_patch(model)
|
self._apply_lora_kernel_patch(model)
|
||||||
self._apply_scaling_softmax_patch(model)
|
self._apply_scaling_softmax_patch(model)
|
||||||
|
|
||||||
|
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
|
||||||
|
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
|
||||||
|
|
||||||
|
Gemma 4 has global (full_attention) layers with head_dim=512
|
||||||
|
which exceeds flash attention's supported size. This patch loads the model
|
||||||
|
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
||||||
|
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
||||||
|
|
||||||
|
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
|
||||||
|
which fixes the corresponding mask construction inside
|
||||||
|
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
|
||||||
|
override is not enough — the forward still builds a 2D FA2-format mask
|
||||||
|
at the model level and the SDPA layers crash at long context lengths
|
||||||
|
with ``RuntimeError: The expanded size of the tensor ... must match``.
|
||||||
|
"""
|
||||||
|
if not self.cfg.gemma4_hybrid_attn_impl:
|
||||||
|
return
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
|
||||||
|
# Navigate to the module that has 'layers' - varies by model structure:
|
||||||
|
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
||||||
|
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
|
||||||
|
layers = None
|
||||||
|
config_source = None
|
||||||
|
for candidate in [model, getattr(model, "model", None)]:
|
||||||
|
if candidate is None:
|
||||||
|
continue
|
||||||
|
# Check direct layers
|
||||||
|
if hasattr(candidate, "layers"):
|
||||||
|
layers = candidate.layers
|
||||||
|
config_source = candidate
|
||||||
|
break
|
||||||
|
# Check language_model.layers (multimodal wrapper)
|
||||||
|
lang_model = getattr(candidate, "language_model", None)
|
||||||
|
if lang_model is not None and hasattr(lang_model, "layers"):
|
||||||
|
layers = lang_model.layers
|
||||||
|
config_source = lang_model
|
||||||
|
break
|
||||||
|
|
||||||
|
if layers is None:
|
||||||
|
LOG.warning(
|
||||||
|
"gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
config = getattr(config_source, "config", self.model_config)
|
||||||
|
layer_types = getattr(config, "layer_types", None)
|
||||||
|
if layer_types is None:
|
||||||
|
LOG.warning(
|
||||||
|
"gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
|
||||||
|
"This feature requires a model with mixed sliding/global attention layers."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
patched_count = 0
|
||||||
|
for layer_idx, layer in enumerate(layers):
|
||||||
|
if layer_types[layer_idx] != "sliding_attention":
|
||||||
|
# Global / full_attention layer - use SDPA instead of FA2
|
||||||
|
attn_module = getattr(layer, "self_attn", None)
|
||||||
|
if attn_module is not None and hasattr(attn_module, "config"):
|
||||||
|
sdpa_config = copy.copy(attn_module.config)
|
||||||
|
sdpa_config._attn_implementation = "sdpa"
|
||||||
|
attn_module.config = sdpa_config
|
||||||
|
patched_count += 1
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
|
||||||
|
"(remaining %d sliding layers use flash_attention_2)",
|
||||||
|
patched_count,
|
||||||
|
len(layers) - patched_count,
|
||||||
|
)
|
||||||
|
|
||||||
def _apply_flash_attention_patches(self):
|
def _apply_flash_attention_patches(self):
|
||||||
"""Apply patches related to Flash Attention."""
|
"""Apply patches related to Flash Attention."""
|
||||||
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||||
@@ -324,6 +410,21 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_qwen3_5_vlm_flash_attention()
|
patch_qwen3_5_vlm_flash_attention()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||||
|
# The fused attn path is now compatible with
|
||||||
|
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
|
||||||
|
# rotary (cos.shape[-1] < head_dim) and the fused forward
|
||||||
|
# mirrors the current ``Gemma4TextAttention.forward`` API
|
||||||
|
# for shared kv (read from / write to
|
||||||
|
# ``past_key_values.shared_layers``). See
|
||||||
|
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
|
||||||
|
# for the history.
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
|
patch_gemma4_fused_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fix_nemotron_h_conversion_mapping():
|
def _fix_nemotron_h_conversion_mapping():
|
||||||
"""Remove the spurious embedding→embeddings WeightRenaming from the
|
"""Remove the spurious embedding→embeddings WeightRenaming from the
|
||||||
|
|||||||
@@ -221,14 +221,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
if getattr(tokenizer, attr_name) is None:
|
if getattr(tokenizer, attr_name) is None:
|
||||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||||
|
|
||||||
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
|
||||||
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
LOG.warning(
|
|
||||||
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
|
||||||
tokenizer.eos_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
additional_special_tokens = None
|
additional_special_tokens = None
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
special_tokens = cfg.special_tokens.to_dict()
|
special_tokens = cfg.special_tokens.to_dict()
|
||||||
@@ -303,6 +295,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
{"additional_special_tokens": additional_special_tokens}
|
{"additional_special_tokens": additional_special_tokens}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
||||||
|
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
LOG.warning(
|
||||||
|
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
||||||
|
tokenizer.eos_token,
|
||||||
|
)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
|
|||||||
@@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict(
|
|||||||
sharded_meta_param.placements,
|
sharded_meta_param.placements,
|
||||||
src_data_rank=0,
|
src_data_rank=0,
|
||||||
)
|
)
|
||||||
|
# Clone the local shard to allow full_tensor to be freed.
|
||||||
|
if (
|
||||||
|
sharded_param._local_tensor.untyped_storage().size()
|
||||||
|
> sharded_param._local_tensor.nelement()
|
||||||
|
* sharded_param._local_tensor.element_size()
|
||||||
|
):
|
||||||
|
sharded_param = sharded_param.clone()
|
||||||
else:
|
else:
|
||||||
# Non-sharded parameters
|
# Non-sharded parameters
|
||||||
if _accelerator.is_main_process:
|
if _accelerator.is_main_process:
|
||||||
|
|||||||
@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
|
|||||||
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
|
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# flash-attn-4>=4.0.0b7
|
||||||
|
from flash_attn.cute import flash_attn_with_kvcache
|
||||||
|
except ImportError:
|
||||||
|
flash_attn_with_kvcache = None
|
||||||
|
|
||||||
def _patched_lazy_imports(
|
def _patched_lazy_imports(
|
||||||
implementation, attention_wrapper=None, allow_all_kernels=False
|
implementation, attention_wrapper=None, allow_all_kernels=False
|
||||||
):
|
):
|
||||||
return (
|
return (
|
||||||
flash_attn_func,
|
flash_attn_func,
|
||||||
flash_attn_varlen_func,
|
flash_attn_varlen_func,
|
||||||
|
flash_attn_with_kvcache,
|
||||||
fa_utils._pad_input,
|
fa_utils._pad_input,
|
||||||
fa_utils._unpad_input,
|
fa_utils._unpad_input,
|
||||||
)
|
)
|
||||||
|
|||||||
115
src/axolotl/monkeypatch/gemma4_hybrid_mask.py
Normal file
115
src/axolotl/monkeypatch/gemma4_hybrid_mask.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""Hybrid attention mask fix for Gemma 4.
|
||||||
|
|
||||||
|
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
|
||||||
|
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
|
||||||
|
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
|
||||||
|
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
|
||||||
|
``self_attn.config``, leaving sliding-window layers on FA2.
|
||||||
|
|
||||||
|
The per-layer config override alone is insufficient, however:
|
||||||
|
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
|
||||||
|
using the **model-level** config and passes the mapped mask to each
|
||||||
|
decoder layer. With FA2 still set at the model level, the ``full_attention``
|
||||||
|
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
|
||||||
|
The global layers then fail with::
|
||||||
|
|
||||||
|
RuntimeError: The expanded size of the tensor (S) must match the existing
|
||||||
|
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
|
||||||
|
sizes: [B, S]
|
||||||
|
|
||||||
|
...when the sequence length grows past roughly 7k tokens.
|
||||||
|
|
||||||
|
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
|
||||||
|
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
|
||||||
|
the original in ``masking_utils``. The wrapper forces
|
||||||
|
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
|
||||||
|
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
|
||||||
|
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
|
||||||
|
alone, so sliding-window layers continue to receive FA2-format masks.
|
||||||
|
|
||||||
|
The patch is idempotent. Install once per process, before any Gemma 4
|
||||||
|
forward pass runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
_PATCH_APPLIED = False
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma4_hybrid_mask() -> bool:
|
||||||
|
"""Install the Gemma 4 hybrid-attention mask fix.
|
||||||
|
|
||||||
|
Returns ``True`` if the patch was installed (or was already installed),
|
||||||
|
``False`` if the target module could not be imported (e.g. transformers
|
||||||
|
version predates Gemma 4) — in which case nothing is done and the
|
||||||
|
caller can continue unaffected.
|
||||||
|
"""
|
||||||
|
global _PATCH_APPLIED
|
||||||
|
if _PATCH_APPLIED:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.models.gemma4 import modeling_gemma4
|
||||||
|
except ImportError:
|
||||||
|
LOG.debug(
|
||||||
|
"gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
|
||||||
|
"skipping. This is fine for non-Gemma4 training."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not hasattr(modeling_gemma4, "create_causal_mask"):
|
||||||
|
LOG.warning(
|
||||||
|
"gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
|
||||||
|
"binding, skipping. Transformers API may have changed."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
original = modeling_gemma4.create_causal_mask
|
||||||
|
|
||||||
|
def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
|
||||||
|
"""Wrapper that forces SDPA format for the full-attention mask.
|
||||||
|
|
||||||
|
The global layers were patched to SDPA by
|
||||||
|
``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
|
||||||
|
original ``create_causal_mask`` dispatches on
|
||||||
|
``config._attn_implementation``; we shadow that with a local
|
||||||
|
override.
|
||||||
|
"""
|
||||||
|
sdpa_config = copy.copy(config)
|
||||||
|
sdpa_config._attn_implementation = "sdpa"
|
||||||
|
return original(sdpa_config, *args, **kwargs)
|
||||||
|
|
||||||
|
# Preserve the original reference on the wrapper for tests / teardown.
|
||||||
|
hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
|
||||||
|
_PATCH_APPLIED = True
|
||||||
|
LOG.info(
|
||||||
|
"gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
|
||||||
|
"force SDPA-format masks for full-attention layers"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def unpatch_gemma4_hybrid_mask() -> None:
|
||||||
|
"""Restore the original ``create_causal_mask``. Useful for tests."""
|
||||||
|
global _PATCH_APPLIED
|
||||||
|
if not _PATCH_APPLIED:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from transformers.models.gemma4 import modeling_gemma4
|
||||||
|
except ImportError:
|
||||||
|
_PATCH_APPLIED = False
|
||||||
|
return
|
||||||
|
current = modeling_gemma4.create_causal_mask
|
||||||
|
original = getattr(current, "_axolotl_original", None)
|
||||||
|
if original is not None:
|
||||||
|
modeling_gemma4.create_causal_mask = original
|
||||||
|
_PATCH_APPLIED = False
|
||||||
@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
|
|||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
apply_lora_mlp_swiglu,
|
apply_lora_mlp_swiglu,
|
||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
|
apply_lora_qk,
|
||||||
apply_lora_qkv,
|
apply_lora_qkv,
|
||||||
)
|
)
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
@@ -111,6 +112,47 @@ QKV_PATCHES = [
|
|||||||
else:
|
else:
|
||||||
key_states = key_states.view(hidden_shape)
|
key_states = key_states.view(hidden_shape)
|
||||||
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
||||||
|
""".lstrip("\n"),
|
||||||
|
),
|
||||||
|
# Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
|
||||||
|
# past_key_values.shared_layers, and v_norm added after k_norm.
|
||||||
|
(
|
||||||
|
"""
|
||||||
|
query_states = self.q_proj(hidden_states).view(hidden_shape)
|
||||||
|
query_states = self.q_norm(query_states)
|
||||||
|
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
|
||||||
|
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
|
||||||
|
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
|
||||||
|
if self.is_kv_shared_layer:
|
||||||
|
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||||
|
# Device of past layer may be different from current one
|
||||||
|
key_states = key_states.to(query_states.device)
|
||||||
|
value_states = value_states.to(query_states.device)
|
||||||
|
else:
|
||||||
|
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||||
|
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
|
||||||
|
""".lstrip("\n"),
|
||||||
|
"""
|
||||||
|
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||||
|
query_states = query_states.view(hidden_shape)
|
||||||
|
query_states = self.q_norm(query_states)
|
||||||
|
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
|
||||||
|
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
|
||||||
|
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
|
||||||
|
if self.is_kv_shared_layer:
|
||||||
|
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||||
|
# Device of past layer may be different from current one
|
||||||
|
key_states = key_states.to(query_states.device)
|
||||||
|
value_states = value_states.to(query_states.device)
|
||||||
|
else:
|
||||||
|
key_states = key_states.view(hidden_shape)
|
||||||
|
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
||||||
""".lstrip("\n"),
|
""".lstrip("\n"),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -483,18 +525,24 @@ def apply_lora_kernel_patches(
|
|||||||
if cfg.lora_qkv_kernel:
|
if cfg.lora_qkv_kernel:
|
||||||
# Query, key, value patching
|
# Query, key, value patching
|
||||||
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
|
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
|
||||||
proj_names = ["q_proj", "k_proj", "v_proj"]
|
has_v_proj = getattr(self_attn, "v_proj", None) is not None
|
||||||
layer_modules = [
|
proj_names = (
|
||||||
getattr(self_attn, name)
|
["q_proj", "k_proj", "v_proj"]
|
||||||
for name in proj_names
|
if has_v_proj
|
||||||
if getattr(self_attn, name, None) is not None
|
else ["q_proj", "k_proj"]
|
||||||
]
|
)
|
||||||
|
layer_modules = [getattr(self_attn, name) for name in proj_names]
|
||||||
can_patch_qkv = all(
|
can_patch_qkv = all(
|
||||||
hasattr(module, "lora_A") for module in layer_modules
|
hasattr(module, "lora_A") for module in layer_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_qkv:
|
if can_patch_qkv:
|
||||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
if has_v_proj:
|
||||||
|
self_attn.apply_qkv = types.MethodType(
|
||||||
|
apply_lora_qkv, self_attn
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
||||||
|
|||||||
147
src/axolotl/monkeypatch/models/gemma4/fused_attn.py
Normal file
147
src/axolotl/monkeypatch/models/gemma4/fused_attn.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""
|
||||||
|
Gemma 4 fused attention monkeypatch.
|
||||||
|
|
||||||
|
Replaces the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
|
||||||
|
kernels, eliminating intermediate tensor allocations from rotate_half / apply_rotary_pos_emb
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fused_forward(original_forward):
|
||||||
|
"""Create a patched forward that uses fused RMSNorm+RoPE kernels."""
|
||||||
|
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import (
|
||||||
|
fused_rms_norm_noscale,
|
||||||
|
fused_rms_norm_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fused_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None,
|
||||||
|
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
past_key_values=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import (
|
||||||
|
eager_attention_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||||
|
eps = self.config.rms_norm_eps
|
||||||
|
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
|
# ---- Projections ----
|
||||||
|
# Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
|
||||||
|
has_lora_qkv = hasattr(self, "apply_qkv")
|
||||||
|
|
||||||
|
if has_lora_qkv:
|
||||||
|
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||||
|
query_states = query_states.view(hidden_shape)
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states).view(hidden_shape)
|
||||||
|
|
||||||
|
# ---- Q path: fused q_norm + RoPE ----
|
||||||
|
query_states = fused_rms_norm_rope(
|
||||||
|
query_states,
|
||||||
|
self.q_norm.weight,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# ---- K/V path ----
|
||||||
|
if self.is_kv_shared_layer:
|
||||||
|
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||||
|
key_states = key_states.to(query_states.device)
|
||||||
|
value_states = value_states.to(query_states.device)
|
||||||
|
else:
|
||||||
|
if has_lora_qkv:
|
||||||
|
# apply_qkv already computed k/v projections
|
||||||
|
key_states = key_states.view(hidden_shape)
|
||||||
|
value_states = (
|
||||||
|
value_states.view(hidden_shape)
|
||||||
|
if self.v_proj is not None
|
||||||
|
else key_states
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||||
|
value_states = (
|
||||||
|
self.v_proj(hidden_states).view(hidden_shape)
|
||||||
|
if self.v_proj is not None
|
||||||
|
else key_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fused k_norm + RoPE
|
||||||
|
key_states = fused_rms_norm_rope(
|
||||||
|
key_states,
|
||||||
|
self.k_norm.weight,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
|
||||||
|
# Fused v_norm (no scale, no RoPE)
|
||||||
|
value_states = fused_rms_norm_noscale(value_states, eps=eps)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
if past_key_values is not None and not self.is_kv_shared_layer:
|
||||||
|
key_states, value_states = past_key_values.update(
|
||||||
|
key_states, value_states, self.layer_idx
|
||||||
|
)
|
||||||
|
if self.store_full_length_kv:
|
||||||
|
shared_kv_states[self.layer_idx] = key_states, value_states
|
||||||
|
|
||||||
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
if self.config._attn_implementation != "eager":
|
||||||
|
attention_interface = ALL_ATTENTION_FUNCTIONS[
|
||||||
|
self.config._attn_implementation
|
||||||
|
]
|
||||||
|
|
||||||
|
attn_output, attn_weights = attention_interface(
|
||||||
|
self,
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
dropout=self.attention_dropout if self.training else 0.0,
|
||||||
|
scaling=self.scaling,
|
||||||
|
sliding_window=self.sliding_window,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
return fused_forward
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma4_fused_attn():
|
||||||
|
"""
|
||||||
|
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
|
||||||
|
"""
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
||||||
|
|
||||||
|
original_forward = Gemma4TextAttention.forward
|
||||||
|
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
|
||||||
|
)
|
||||||
@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||||
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
|
||||||
|
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
|
||||||
|
# language-side module is separated from the vision tower. Try
|
||||||
|
# both names before giving up.
|
||||||
|
mlp_cls = getattr(
|
||||||
|
module,
|
||||||
|
f"{model_cls_prefix}MLP",
|
||||||
|
None,
|
||||||
|
) or getattr(module, f"{model_cls_prefix}TextMLP")
|
||||||
|
|
||||||
if use_original_mlp:
|
if use_original_mlp:
|
||||||
mlp_forward = mlp_cls.forward
|
mlp_forward = mlp_cls.forward
|
||||||
|
|||||||
@@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
self._validate_eot_and_eos_tokens()
|
self._validate_eot_and_eos_tokens()
|
||||||
|
|
||||||
|
# Pre-cache EOT token IDs to avoid re-encoding on every call
|
||||||
|
self._eot_token_ids = set()
|
||||||
|
for token in self.eot_tokens:
|
||||||
|
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||||
|
if len(token_ids) == 1:
|
||||||
|
self._eot_token_ids.add(token_ids[0])
|
||||||
|
|
||||||
def _validate_eot_and_eos_tokens(self):
|
def _validate_eot_and_eos_tokens(self):
|
||||||
"""
|
"""
|
||||||
- Validates that EOT tokens (or eos_token) are in the chat_template
|
- Validates that EOT tokens (or eos_token) are in the chat_template
|
||||||
@@ -471,6 +478,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
content = turn.get("content")
|
content = turn.get("content")
|
||||||
train_turn = turn.get("training")
|
train_turn = turn.get("training")
|
||||||
train_detail = turn.get("training_detail")
|
train_detail = turn.get("training_detail")
|
||||||
|
reasoning_train_detail = turn.get("reasoning_training_detail")
|
||||||
|
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
|
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
|
||||||
@@ -479,8 +487,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
should_train = None
|
should_train = None
|
||||||
if train_turn is not None:
|
if train_turn is not None:
|
||||||
should_train = train_turn
|
should_train = train_turn
|
||||||
elif train_detail is not None:
|
elif train_detail is not None or reasoning_train_detail is not None:
|
||||||
should_train = bool(train_detail)
|
should_train = bool(train_detail) or bool(reasoning_train_detail)
|
||||||
else:
|
else:
|
||||||
should_train = self.train_on_inputs or role in self.roles_to_train
|
should_train = self.train_on_inputs or role in self.roles_to_train
|
||||||
|
|
||||||
@@ -500,15 +508,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
thinking_key = self.prompter.template_thinking_key
|
||||||
|
has_reasoning = thinking_key and turn.get(thinking_key) is not None
|
||||||
|
has_any_detail = train_detail or reasoning_train_detail
|
||||||
|
|
||||||
|
# When train_detail is present and the turn has reasoning_content,
|
||||||
|
# use content_only=True so find_turn returns content-only boundaries
|
||||||
|
# (excluding reasoning_content + template separator tokens).
|
||||||
|
use_content_only = bool(has_any_detail and has_reasoning)
|
||||||
|
|
||||||
turn_start_idx, turn_end_idx = self.find_turn(
|
turn_start_idx, turn_end_idx = self.find_turn(
|
||||||
turns=turns, turn_idx=index, tools=tools
|
turns=turns,
|
||||||
|
turn_idx=index,
|
||||||
|
tools=tools,
|
||||||
|
content_only=use_content_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||||
|
|
||||||
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
||||||
if train_detail:
|
if train_detail:
|
||||||
# Block multi-content for now
|
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`train_detail` is not supported when `content` is not a string."
|
"`train_detail` is not supported when `content` is not a string."
|
||||||
@@ -526,7 +545,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
|
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
|
||||||
)
|
)
|
||||||
else:
|
elif not reasoning_train_detail:
|
||||||
|
# No per-part detail on either field — train the whole span
|
||||||
labels[turn_start_idx:turn_end_idx] = input_ids[
|
labels[turn_start_idx:turn_end_idx] = input_ids[
|
||||||
turn_start_idx:turn_end_idx
|
turn_start_idx:turn_end_idx
|
||||||
]
|
]
|
||||||
@@ -534,6 +554,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
|
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Handle reasoning_content training_detail separately
|
||||||
|
if should_train and reasoning_train_detail and has_reasoning:
|
||||||
|
reasoning_text = turn[thinking_key]
|
||||||
|
if not isinstance(reasoning_text, str):
|
||||||
|
raise ValueError(
|
||||||
|
"`reasoning_training_detail` is not supported when reasoning_content is not a string."
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_start, reasoning_end = self.find_turn(
|
||||||
|
turns=turns,
|
||||||
|
turn_idx=index,
|
||||||
|
tools=tools,
|
||||||
|
reasoning_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if reasoning_start != -1 and reasoning_end != -1:
|
||||||
|
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
|
||||||
|
reasoning_text, reasoning_train_detail
|
||||||
|
)
|
||||||
|
LOG.debug(f"Reasoning token offsets: {token_offsets}")
|
||||||
|
for i, offset in enumerate(token_offsets):
|
||||||
|
if offset != IGNORE_TOKEN_ID and reasoning_start + i < len(
|
||||||
|
input_ids
|
||||||
|
):
|
||||||
|
labels[reasoning_start + i] = input_ids[reasoning_start + i]
|
||||||
|
|
||||||
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||||
|
|
||||||
# Handle special tokens (EOT and EOS)
|
# Handle special tokens (EOT and EOS)
|
||||||
@@ -593,28 +639,31 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def find_first_eot_token(self, input_ids, start_idx):
|
def find_first_eot_token(self, input_ids, start_idx):
|
||||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||||
# Get token IDs for all EOT tokens
|
# Use pre-cached EOT token IDs (computed once in __init__)
|
||||||
eot_token_ids = []
|
|
||||||
for token in self.eot_tokens:
|
|
||||||
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
|
||||||
if len(token_ids) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
|
|
||||||
)
|
|
||||||
|
|
||||||
eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple
|
|
||||||
|
|
||||||
# Search for any of the EOT token IDs
|
|
||||||
for i in range(start_idx, len(input_ids)):
|
for i in range(start_idx, len(input_ids)):
|
||||||
if input_ids[i] in eot_token_ids:
|
if input_ids[i] in self._eot_token_ids:
|
||||||
return i
|
return i
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def find_turn(
|
def find_turn(
|
||||||
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
|
self,
|
||||||
|
turns: list[dict],
|
||||||
|
turn_idx: int,
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
content_only: bool = False,
|
||||||
|
reasoning_only: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Locate the starting and ending indices of the specified turn in a conversation.
|
Locate the starting and ending indices of the specified turn in a conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_only: If True and the turn has reasoning_content (template_thinking_key),
|
||||||
|
preserve reasoning_content in the dummy turn so the diff only captures the
|
||||||
|
content field boundaries. This is needed for correct training_detail alignment
|
||||||
|
when reasoning_content is present.
|
||||||
|
reasoning_only: If True, preserve content in the dummy turn and replace
|
||||||
|
reasoning_content with a dummy, so the diff only captures the
|
||||||
|
reasoning_content field boundaries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if turn_idx >= len(turns):
|
if turn_idx >= len(turns):
|
||||||
@@ -628,10 +677,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
):
|
):
|
||||||
return -1, -1
|
return -1, -1
|
||||||
|
|
||||||
empty_turn = {
|
thinking_key = self.prompter.template_thinking_key
|
||||||
"role": turns[turn_idx].get("role"),
|
|
||||||
"content": "[[dummy_message]]",
|
if reasoning_only:
|
||||||
}
|
# Keep content as-is, replace reasoning with dummy
|
||||||
|
empty_turn = {
|
||||||
|
"role": turns[turn_idx].get("role"),
|
||||||
|
"content": turns[turn_idx].get("content", ""),
|
||||||
|
}
|
||||||
|
if thinking_key and thinking_key in turns[turn_idx]:
|
||||||
|
empty_turn[thinking_key] = "[[dummy_reasoning]]"
|
||||||
|
else:
|
||||||
|
empty_turn = {
|
||||||
|
"role": turns[turn_idx].get("role"),
|
||||||
|
"content": "[[dummy_message]]",
|
||||||
|
}
|
||||||
|
|
||||||
|
# When content_only is True, copy reasoning_content to the dummy turn so
|
||||||
|
# the diff only captures the content field (not reasoning + separator).
|
||||||
|
if content_only and thinking_key and thinking_key in turns[turn_idx]:
|
||||||
|
empty_turn[thinking_key] = turns[turn_idx][thinking_key]
|
||||||
|
|
||||||
# Create conversation versions
|
# Create conversation versions
|
||||||
turns_with_empty = turns[:turn_idx] + [empty_turn]
|
turns_with_empty = turns[:turn_idx] + [empty_turn]
|
||||||
@@ -697,6 +762,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
return start_idx, end_idx
|
return start_idx, end_idx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_content_parts(
|
||||||
|
content,
|
||||||
|
) -> tuple[str, list[dict] | None] | None:
|
||||||
|
"""Convert list content to concatenated string + optional training_detail.
|
||||||
|
|
||||||
|
When content is a list of dicts (content parts), each part can specify:
|
||||||
|
- ``text``, ``content``, or ``value``: the text string
|
||||||
|
- ``train`` (bool) or ``weight`` (0/1): per-part training flag
|
||||||
|
|
||||||
|
Returns ``(concatenated_text, training_details_or_None)`` if content was
|
||||||
|
a list, or ``None`` if content was not a list (no conversion needed).
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
**Whitespace at part boundaries matters.** BPE tokenizers prepend
|
||||||
|
spaces to word tokens (e.g. ``" answer"`` is one token). Always
|
||||||
|
split BEFORE spaces::
|
||||||
|
|
||||||
|
GOOD: ["Let me think...", " The answer is 4."]
|
||||||
|
BAD: ["Let me think... ", "The answer is 4."]
|
||||||
|
|
||||||
|
Tokens that straddle a boundary are conservatively masked.
|
||||||
|
Newlines typically merge with preceding punctuation (``":\\n"`` is
|
||||||
|
one token), so keep newlines with the preceding part.
|
||||||
|
"""
|
||||||
|
if not isinstance(content, list):
|
||||||
|
return None
|
||||||
|
|
||||||
|
text_parts: list[str] = []
|
||||||
|
training_details: list[dict] = []
|
||||||
|
has_explicit_training = False
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict):
|
||||||
|
# Extract text (HF uses "text", also support "content"/"value")
|
||||||
|
text = (
|
||||||
|
part.get("text") or part.get("content") or part.get("value") or ""
|
||||||
|
)
|
||||||
|
text_parts.append(text)
|
||||||
|
|
||||||
|
# Check for per-part training flags
|
||||||
|
part_train = part.get("train")
|
||||||
|
part_weight = part.get("weight")
|
||||||
|
if part_train is not None or part_weight is not None:
|
||||||
|
has_explicit_training = True
|
||||||
|
train = (
|
||||||
|
part_train
|
||||||
|
if part_train is not None
|
||||||
|
else (part_weight not in (0, 0.0))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train = True # default trainable, gated by turn-level should_train
|
||||||
|
|
||||||
|
if text:
|
||||||
|
training_details.append(
|
||||||
|
{
|
||||||
|
"begin_offset": offset,
|
||||||
|
"end_offset": offset + len(text) - 1,
|
||||||
|
"train": train,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
offset += len(text)
|
||||||
|
|
||||||
|
# Warn about trailing whitespace at boundaries between parts with
|
||||||
|
# different training flags — this almost always causes token straddling
|
||||||
|
if has_explicit_training and len(training_details) > 1:
|
||||||
|
for i in range(len(training_details) - 1):
|
||||||
|
cur = training_details[i]
|
||||||
|
nxt = training_details[i + 1]
|
||||||
|
if cur["train"] != nxt["train"]:
|
||||||
|
boundary_text = text_parts[i]
|
||||||
|
if boundary_text and boundary_text[-1] in (" ", "\t"):
|
||||||
|
LOG.warning(
|
||||||
|
"Content part %d ends with whitespace at a train/mask boundary. "
|
||||||
|
"BPE tokenizers typically prepend spaces to word tokens, so "
|
||||||
|
"the space will merge with the next part's first word and the "
|
||||||
|
"resulting token will be MASKED (not trained). Move the "
|
||||||
|
"whitespace to the start of the next content part instead. "
|
||||||
|
"Part text: %r",
|
||||||
|
i,
|
||||||
|
boundary_text[-20:],
|
||||||
|
)
|
||||||
|
|
||||||
|
concatenated = "".join(text_parts)
|
||||||
|
details = training_details if has_explicit_training else None
|
||||||
|
return concatenated, details
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
turns = []
|
turns = []
|
||||||
|
|
||||||
@@ -723,6 +876,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if training_detail is not None:
|
if training_detail is not None:
|
||||||
turn["training_detail"] = training_detail
|
turn["training_detail"] = training_detail
|
||||||
|
|
||||||
|
# Convert list content/reasoning_content to string + auto-generated
|
||||||
|
# training_detail. See _convert_content_parts for whitespace guidance.
|
||||||
|
content_result = self._convert_content_parts(turn.get("content"))
|
||||||
|
if content_result is not None:
|
||||||
|
turn["content"] = content_result[0]
|
||||||
|
if content_result[1] is not None:
|
||||||
|
turn["training_detail"] = content_result[1]
|
||||||
|
|
||||||
|
# Also convert reasoning_content (template_thinking_key) if it's a list
|
||||||
|
thinking_key = self.prompter.template_thinking_key
|
||||||
|
if thinking_key and thinking_key in turn:
|
||||||
|
reasoning_result = self._convert_content_parts(turn[thinking_key])
|
||||||
|
if reasoning_result is not None:
|
||||||
|
turn[thinking_key] = reasoning_result[0]
|
||||||
|
if reasoning_result[1] is not None:
|
||||||
|
turn["reasoning_training_detail"] = reasoning_result[1]
|
||||||
|
|
||||||
turns.append(turn)
|
turns.append(turn)
|
||||||
|
|
||||||
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
||||||
|
|||||||
@@ -320,6 +320,15 @@ def main(script_args: ScriptArguments):
|
|||||||
# --- Active LoRA state (shared across endpoints via closure) ---
|
# --- Active LoRA state (shared across endpoints via closure) ---
|
||||||
active_lora: dict = {"request": None}
|
active_lora: dict = {"request": None}
|
||||||
|
|
||||||
|
# Serializes access to the worker pipe. The underlying
|
||||||
|
# multiprocessing.Connection is a single full-duplex stream shared
|
||||||
|
# across all HTTP handlers; concurrent requests interleave bytes on
|
||||||
|
# the wire and corrupt the pickle framing (seen as
|
||||||
|
# ``UnpicklingError: pickle data was truncated``). Any endpoint that
|
||||||
|
# does ``conn.send(...); conn.recv()`` MUST hold this lock across
|
||||||
|
# the round-trip so only one inflight call at a time per pipe.
|
||||||
|
worker_pipe_lock = asyncio.Lock()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# LoRA-specific endpoints
|
# LoRA-specific endpoints
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -631,6 +640,150 @@ def main(script_args: ScriptArguments):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
async def openai_completions(request_body: dict):
|
||||||
|
"""OpenAI-compatible text-completions endpoint.
|
||||||
|
|
||||||
|
Accepts either a string ``prompt`` or a list-of-int
|
||||||
|
``prompt_token_ids`` (as the text-completions spec allows). Routes
|
||||||
|
to the internal vLLM generate method with the active LoRA adapter
|
||||||
|
and returns an OpenAI /v1/completions-shaped response including
|
||||||
|
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
|
||||||
|
``generation_log_probs`` for NeMo Gym agents that need raw
|
||||||
|
tokens + logprobs.
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
prompt_raw = request_body.get("prompt")
|
||||||
|
temperature = request_body.get("temperature", 1.0)
|
||||||
|
max_tokens = request_body.get("max_tokens", 512)
|
||||||
|
top_p = request_body.get("top_p", 1.0)
|
||||||
|
n = request_body.get("n", 1)
|
||||||
|
logprobs = request_body.get("logprobs") or 0
|
||||||
|
stop_token_ids = request_body.get("stop_token_ids") or None
|
||||||
|
|
||||||
|
# Accept either a string or a list[int] token id prompt. Lists
|
||||||
|
# must contain ints only (raise on lists of strings so callers get
|
||||||
|
# a clear error). Also accept [[int, int, ...]] nesting for the
|
||||||
|
# rare case callers pass a single-prompt batch.
|
||||||
|
if (
|
||||||
|
isinstance(prompt_raw, list)
|
||||||
|
and prompt_raw
|
||||||
|
and isinstance(prompt_raw[0], list)
|
||||||
|
):
|
||||||
|
prompt_raw = prompt_raw[0]
|
||||||
|
|
||||||
|
prompt_dict: dict[str, Any] = {}
|
||||||
|
if isinstance(prompt_raw, list):
|
||||||
|
prompt_dict = {"prompt_token_ids": prompt_raw}
|
||||||
|
elif isinstance(prompt_raw, str):
|
||||||
|
prompt_dict = {"prompt": prompt_raw}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": {
|
||||||
|
"message": ("prompt must be a string or a list of token ids"),
|
||||||
|
"type": "invalid_request",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
generation_kwargs: dict[str, Any] = {
|
||||||
|
"n": n,
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_p": top_p,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"logprobs": logprobs,
|
||||||
|
}
|
||||||
|
if stop_token_ids:
|
||||||
|
generation_kwargs["stop_token_ids"] = stop_token_ids
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
**{k: v for k, v in generation_kwargs.items() if v is not None}
|
||||||
|
)
|
||||||
|
|
||||||
|
chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
|
||||||
|
|
||||||
|
# Hold the pipe lock across send+recv — concurrent requests would
|
||||||
|
# otherwise interleave pickle frames on the worker connection.
|
||||||
|
async with worker_pipe_lock:
|
||||||
|
for conn, chunk in zip(connections, chunked, strict=True):
|
||||||
|
if not chunk:
|
||||||
|
chunk = [{"prompt": "<placeholder>"}]
|
||||||
|
kwargs = {
|
||||||
|
"prompts": chunk,
|
||||||
|
"sampling_params": sampling_params,
|
||||||
|
"lora_request": active_lora["request"],
|
||||||
|
}
|
||||||
|
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
all_outputs = await asyncio.gather(
|
||||||
|
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
||||||
|
)
|
||||||
|
|
||||||
|
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
||||||
|
for o in all_outputs:
|
||||||
|
if isinstance(o, dict) and "error" in o:
|
||||||
|
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
||||||
|
all_outputs = list(chain.from_iterable(all_outputs))
|
||||||
|
|
||||||
|
if not all_outputs:
|
||||||
|
return {"choices": [], "model": script_args.model}
|
||||||
|
|
||||||
|
choices = []
|
||||||
|
for i, output in enumerate(all_outputs):
|
||||||
|
for j, out in enumerate(output.outputs):
|
||||||
|
text = out.text
|
||||||
|
# OpenAI-style `logprobs` block for text-completions:
|
||||||
|
# { "tokens": [...], "token_logprobs": [...] }
|
||||||
|
lp_block = None
|
||||||
|
if out.logprobs:
|
||||||
|
tokens_str: list[str] = []
|
||||||
|
token_lps: list[float] = []
|
||||||
|
for step in out.logprobs:
|
||||||
|
chosen = next(iter(step.values()))
|
||||||
|
tokens_str.append(getattr(chosen, "decoded_token", "") or "")
|
||||||
|
token_lps.append(float(chosen.logprob))
|
||||||
|
lp_block = {
|
||||||
|
"tokens": tokens_str,
|
||||||
|
"token_logprobs": token_lps,
|
||||||
|
}
|
||||||
|
|
||||||
|
choice = {
|
||||||
|
"index": i * n + j,
|
||||||
|
"text": text,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
if out.finish_reason == "stop"
|
||||||
|
else "length",
|
||||||
|
"logprobs": lp_block,
|
||||||
|
# NeMo-Gym / retrace agent extras — preserved on the
|
||||||
|
# choice so callers with raw-token pipelines don't
|
||||||
|
# have to re-tokenize.
|
||||||
|
"prompt_token_ids": output.prompt_token_ids,
|
||||||
|
"generation_token_ids": list(out.token_ids),
|
||||||
|
"generation_log_probs": (
|
||||||
|
[float(next(iter(lp.values())).logprob) for lp in out.logprobs]
|
||||||
|
if out.logprobs
|
||||||
|
else []
|
||||||
|
),
|
||||||
|
}
|
||||||
|
choices.append(choice)
|
||||||
|
|
||||||
|
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
|
||||||
|
completion_tokens = sum(
|
||||||
|
len(out.token_ids) for o in all_outputs for out in o.outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
|
||||||
|
"object": "text_completion",
|
||||||
|
"model": script_args.model,
|
||||||
|
"choices": choices,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": prompt_tokens + completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
||||||
|
|
||||||
@app.post("/init_communicator/")
|
@app.post("/init_communicator/")
|
||||||
|
|||||||
@@ -160,29 +160,16 @@ class TelemetryManager:
|
|||||||
if not is_main_process():
|
if not is_main_process():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Parse relevant env vars
|
def is_truthy_env(var_name: str) -> bool:
|
||||||
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
value = os.getenv(var_name)
|
||||||
do_not_track = os.getenv("DO_NOT_TRACK")
|
if value is None:
|
||||||
|
return False
|
||||||
|
return value.strip().lower() in ("1", "true")
|
||||||
|
|
||||||
# Default to enabled (opt-out model)
|
# Telemetry is enabled by default unless either opt-out var is set
|
||||||
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
|
return not (
|
||||||
"0",
|
is_truthy_env("AXOLOTL_DO_NOT_TRACK") or is_truthy_env("DO_NOT_TRACK")
|
||||||
"1",
|
)
|
||||||
"false",
|
|
||||||
"true",
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if do_not_track is None:
|
|
||||||
do_not_track = "0"
|
|
||||||
|
|
||||||
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
|
|
||||||
enabled = axolotl_do_not_track.lower() not in (
|
|
||||||
"1",
|
|
||||||
"true",
|
|
||||||
) and do_not_track.lower() not in ("1", "true")
|
|
||||||
|
|
||||||
return enabled
|
|
||||||
|
|
||||||
def _load_whitelist(self) -> dict:
|
def _load_whitelist(self) -> dict:
|
||||||
"""Load HuggingFace Hub organization whitelist"""
|
"""Load HuggingFace Hub organization whitelist"""
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
|
|||||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
from axolotl.utils.train import determine_last_checkpoint
|
from axolotl.utils.train import determine_last_checkpoint
|
||||||
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
|
|||||||
):
|
):
|
||||||
model.enable_input_require_grads()
|
model.enable_input_require_grads()
|
||||||
|
|
||||||
|
# Freeze multimodal modules for text-only training of multimodal models
|
||||||
|
if cfg.freeze_mm_modules:
|
||||||
|
freeze_mm_modules(model)
|
||||||
|
|
||||||
return model, tokenizer, peft_config, processor
|
return model, tokenizer, peft_config, processor
|
||||||
|
|
||||||
|
|
||||||
@@ -225,6 +229,28 @@ def execute_training(
|
|||||||
PLUGIN_MANAGER.post_train(cfg, trainer.model)
|
PLUGIN_MANAGER.post_train(cfg, trainer.model)
|
||||||
|
|
||||||
|
|
||||||
|
def _rename_fsdp_merged_to_adapter(merged_dir: Path):
|
||||||
|
"""Rename model*.safetensors files to adapter_model* in place.
|
||||||
|
|
||||||
|
Also rewrites the index JSON weight_map if sharded output was produced.
|
||||||
|
"""
|
||||||
|
for file in sorted(merged_dir.iterdir()):
|
||||||
|
if file.name.startswith("model") and ".safetensors" in file.name:
|
||||||
|
file.rename(merged_dir / file.name.replace("model", "adapter_model", 1))
|
||||||
|
|
||||||
|
index = merged_dir / "adapter_model.safetensors.index.json"
|
||||||
|
if index.exists():
|
||||||
|
data = json.loads(index.read_text(encoding="utf-8"))
|
||||||
|
if "weight_map" in data:
|
||||||
|
data["weight_map"] = {
|
||||||
|
k: v.replace("model", "adapter_model", 1)
|
||||||
|
for k, v in data["weight_map"].items()
|
||||||
|
}
|
||||||
|
index.write_text(
|
||||||
|
json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_trained_model(
|
def save_trained_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
trainer: Any,
|
trainer: Any,
|
||||||
@@ -294,12 +320,17 @@ def save_trained_model(
|
|||||||
)
|
)
|
||||||
trainer.accelerator.wait_for_everyone()
|
trainer.accelerator.wait_for_everyone()
|
||||||
if trainer.accelerator.is_main_process:
|
if trainer.accelerator.is_main_process:
|
||||||
# move all files in merged_path to cfg.output_dir
|
# FSDP checkpoints for PEFT only contain adapter weights;
|
||||||
|
# rename model* → adapter_model* so it loads correctly.
|
||||||
|
is_peft = cfg.adapter and not cfg.relora
|
||||||
|
if is_peft:
|
||||||
|
_rename_fsdp_merged_to_adapter(Path(merged_path))
|
||||||
for merged_file in Path(merged_path).iterdir():
|
for merged_file in Path(merged_path).iterdir():
|
||||||
if (Path(cfg.output_dir) / merged_file.name).exists():
|
dest = Path(cfg.output_dir) / merged_file.name
|
||||||
(Path(cfg.output_dir) / merged_file.name).unlink()
|
if dest.exists():
|
||||||
shutil.move(str(merged_file), cfg.output_dir)
|
dest.unlink()
|
||||||
shutil.rmtree(merged_path) # remove what should be an empty dir
|
shutil.move(str(merged_file), dest)
|
||||||
|
shutil.rmtree(merged_path)
|
||||||
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
||||||
# cleanup the FSDP prefix in the model config.json
|
# cleanup the FSDP prefix in the model config.json
|
||||||
if trainer.accelerator.is_main_process:
|
if trainer.accelerator.is_main_process:
|
||||||
|
|||||||
@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class SkipEvalOnResumeCallback(TrainerCallback):
|
||||||
|
"""Skip the redundant evaluation that fires when resuming from a checkpoint
|
||||||
|
whose step aligns with ``eval_steps``.
|
||||||
|
|
||||||
|
When HuggingFace Trainer resumes, it restores ``global_step`` from the
|
||||||
|
checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
|
||||||
|
step. Because the evaluation was already performed during the original
|
||||||
|
run, repeating it wastes time and pollutes metric logs.
|
||||||
|
|
||||||
|
This callback records the ``global_step`` at the start of training (i.e.
|
||||||
|
the checkpoint step when resuming, or 0 for a fresh run) and suppresses
|
||||||
|
any evaluation request on that exact step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._resume_step: int | None = None
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**_kwargs,
|
||||||
|
):
|
||||||
|
# ``global_step`` is already restored from the checkpoint at this
|
||||||
|
# point. For a fresh run it will be 0, so the guard below becomes a
|
||||||
|
# no-op.
|
||||||
|
self._resume_step = state.global_step
|
||||||
|
|
||||||
|
def on_step_end(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**_kwargs,
|
||||||
|
) -> TrainerControl:
|
||||||
|
if (
|
||||||
|
self._resume_step
|
||||||
|
and state.global_step <= self._resume_step
|
||||||
|
and control.should_evaluate
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
"Skipping evaluation at step %d (already completed before resume)",
|
||||||
|
state.global_step,
|
||||||
|
)
|
||||||
|
control.should_evaluate = False
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
def bench_eval_callback_factory(trainer, tokenizer):
|
def bench_eval_callback_factory(trainer, tokenizer):
|
||||||
accuracy = evaluate.load("accuracy")
|
accuracy = evaluate.load("accuracy")
|
||||||
abcd_idx = [
|
abcd_idx = [
|
||||||
|
|||||||
@@ -1,7 +1,19 @@
|
|||||||
{%- if tools %}
|
{%- if tools %}
|
||||||
{{- '<|im_start|>system\n' }}
|
{{- '<|im_start|>system\n' }}
|
||||||
{%- if messages[0].role == 'system' %}
|
{%- if messages[0].role == 'system' %}
|
||||||
{{- messages[0].content + '\n\n' }}
|
{%- if messages[0].content is string %}
|
||||||
|
{{- messages[0].content + '\n\n' }}
|
||||||
|
{%- else %}
|
||||||
|
{%- for part in messages[0].content %}
|
||||||
|
{%- if part is mapping %}
|
||||||
|
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
|
||||||
|
{%- if system_text %}{{- system_text }}{%- endif %}
|
||||||
|
{%- elif part is string %}
|
||||||
|
{{- part }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- '\n\n' }}
|
||||||
|
{%- endif %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||||
{%- for tool in tools %}
|
{%- for tool in tools %}
|
||||||
@@ -11,7 +23,20 @@
|
|||||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||||
{%- else %}
|
{%- else %}
|
||||||
{%- if messages[0].role == 'system' %}
|
{%- if messages[0].role == 'system' %}
|
||||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
{%- if messages[0].content is string %}
|
||||||
|
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>system\n' }}
|
||||||
|
{%- for part in messages[0].content %}
|
||||||
|
{%- if part is mapping %}
|
||||||
|
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
|
||||||
|
{%- if system_text %}{{- system_text }}{%- endif %}
|
||||||
|
{%- elif part is string %}
|
||||||
|
{{- part }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||||
|
|||||||
@@ -268,6 +268,37 @@ def normalize_config(cfg):
|
|||||||
):
|
):
|
||||||
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||||
|
|
||||||
|
# Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
|
||||||
|
# "marked ready twice" errors with reentrant checkpointing) and
|
||||||
|
# ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
|
||||||
|
# receive gradients on every step).
|
||||||
|
if cfg.model_config_type == "gemma4":
|
||||||
|
if cfg.gradient_checkpointing:
|
||||||
|
if cfg.gradient_checkpointing_kwargs is None:
|
||||||
|
cfg.gradient_checkpointing_kwargs = {}
|
||||||
|
if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
|
||||||
|
LOG.warning(
|
||||||
|
"Gemma4 requires use_reentrant=False for gradient checkpointing "
|
||||||
|
"in distributed training. Setting use_reentrant=False."
|
||||||
|
)
|
||||||
|
cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
|
||||||
|
if cfg.ddp and cfg.ddp_find_unused_parameters is None:
|
||||||
|
if cfg.activation_offloading is True:
|
||||||
|
# activation_offloading uses checkpoint wrappers that conflict
|
||||||
|
# with find_unused_parameters (causes "marked ready twice").
|
||||||
|
# Use freeze_mm_modules instead to eliminate unused params.
|
||||||
|
LOG.info(
|
||||||
|
"Gemma4 + DDP + activation_offloading: skipping "
|
||||||
|
"ddp_find_unused_parameters (use freeze_mm_modules to "
|
||||||
|
"handle unused vision/audio params)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
"Gemma4 requires ddp_find_unused_parameters=True for DDP. "
|
||||||
|
"Auto-enabling."
|
||||||
|
)
|
||||||
|
cfg.ddp_find_unused_parameters = True
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -180,6 +180,119 @@ def _drop_long_sequences(
|
|||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_on_long_sequences(
|
||||||
|
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||||
|
) -> bool:
|
||||||
|
"""Check sequence length and raise ValueError if exceeded.
|
||||||
|
|
||||||
|
Used as a filter function for ``excess_length_strategy: raise``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: Dataset sample to check.
|
||||||
|
rl: Reinforcement learning type.
|
||||||
|
tokenizer: Tokenizer for length calculation.
|
||||||
|
sequence_len: Maximum allowed sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Always True (raises before returning False).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any sample exceeds the configured sequence length.
|
||||||
|
"""
|
||||||
|
is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sample exceeds configured sequence_len ({sequence_len}). "
|
||||||
|
"Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
|
||||||
|
"to handle long sequences automatically."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_long_sequences_rl(
|
||||||
|
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Truncate RL samples that exceed maximum sequence length.
|
||||||
|
|
||||||
|
For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
|
||||||
|
responses to fit within ``sequence_len`` when combined with the prompt.
|
||||||
|
For KTO, truncates the completion similarly.
|
||||||
|
GRPO/GDPO/EBFT samples are returned unchanged.
|
||||||
|
|
||||||
|
Samples where the prompt alone exceeds ``sequence_len`` cannot be
|
||||||
|
meaningfully truncated and are returned unchanged. The caller should
|
||||||
|
follow up with a drop filter to remove them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: Dataset sample to potentially truncate.
|
||||||
|
rl: Reinforcement learning type.
|
||||||
|
tokenizer: Tokenizer for encoding/decoding.
|
||||||
|
sequence_len: Maximum allowed sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sample with text fields truncated to fit within sequence_len.
|
||||||
|
"""
|
||||||
|
# Fast path: if sample already fits, return unchanged (avoids decode overhead)
|
||||||
|
if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
|
||||||
|
return sample
|
||||||
|
|
||||||
|
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
|
||||||
|
if not (
|
||||||
|
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
|
||||||
|
chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
|
||||||
|
rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
|
||||||
|
max_response_len = sequence_len - len(prompt_ids)
|
||||||
|
if max_response_len <= 0:
|
||||||
|
# Prompt alone exceeds limit; cannot meaningfully truncate.
|
||||||
|
# Returned unchanged — the follow-up drop filter will remove it.
|
||||||
|
return sample
|
||||||
|
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if len(chosen_ids) > max_response_len:
|
||||||
|
updates["chosen"] = tokenizer.decode(
|
||||||
|
chosen_ids[:max_response_len], skip_special_tokens=False
|
||||||
|
)
|
||||||
|
if len(rejected_ids) > max_response_len:
|
||||||
|
updates["rejected"] = tokenizer.decode(
|
||||||
|
rejected_ids[:max_response_len], skip_special_tokens=False
|
||||||
|
)
|
||||||
|
if updates:
|
||||||
|
sample = {**sample, **updates}
|
||||||
|
|
||||||
|
elif rl is RLType.KTO:
|
||||||
|
if not (sample.get("prompt") and sample.get("completion")):
|
||||||
|
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||||
|
|
||||||
|
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
|
||||||
|
completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
|
||||||
|
max_completion_len = sequence_len - len(prompt_ids)
|
||||||
|
if max_completion_len <= 0:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
if len(completion_ids) > max_completion_len:
|
||||||
|
sample = {
|
||||||
|
**sample,
|
||||||
|
"completion": tokenizer.decode(
|
||||||
|
completion_ids[:max_completion_len], skip_special_tokens=False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||||
"""Load and process dataset split for RL training.
|
"""Load and process dataset split for RL training.
|
||||||
|
|
||||||
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
|||||||
split_datasets[i] = dataset
|
split_datasets[i] = dataset
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
drop_long = partial(
|
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||||
_drop_long_sequences,
|
|
||||||
rl=cfg.rl,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
if excess_length_strategy == "truncate":
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
truncate_fn = partial(
|
||||||
drop_long,
|
_truncate_long_sequences_rl,
|
||||||
num_proc=cfg.dataset_num_proc,
|
rl=cfg.rl,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
tokenizer=tokenizer,
|
||||||
desc="Dropping Long Sequences",
|
sequence_len=cfg.sequence_len,
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(split_datasets[i])
|
prior_len = len(split_datasets[i])
|
||||||
if dropped:
|
split_datasets[i] = split_datasets[i].map(
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
truncate_fn,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Truncating Long Sequences",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Drop samples that could not be truncated (e.g. prompt
|
||||||
|
# alone exceeds sequence_len)
|
||||||
|
drop_long = partial(
|
||||||
|
_drop_long_sequences,
|
||||||
|
rl=cfg.rl,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Dropping Un-truncatable Sequences",
|
||||||
|
)
|
||||||
|
dropped = prior_len - len(split_datasets[i])
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} samples from dataset index {i} "
|
||||||
|
f"that could not be truncated to fit sequence_len "
|
||||||
|
f"(prompt alone exceeds limit)"
|
||||||
|
)
|
||||||
|
elif excess_length_strategy == "raise":
|
||||||
|
raise_fn = partial(
|
||||||
|
_raise_on_long_sequences,
|
||||||
|
rl=cfg.rl,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
|
raise_fn,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Checking Sequence Lengths",
|
||||||
|
)
|
||||||
|
else: # "drop" (default)
|
||||||
|
drop_long = partial(
|
||||||
|
_drop_long_sequences,
|
||||||
|
rl=cfg.rl,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
prior_len = len(split_datasets[i])
|
||||||
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Dropping Long Sequences",
|
||||||
|
)
|
||||||
|
dropped = prior_len - len(split_datasets[i])
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} long samples from dataset index {i}"
|
||||||
|
)
|
||||||
|
|
||||||
# Merge datasets
|
# Merge datasets
|
||||||
dataset = merge_datasets(split_datasets, cfg)
|
dataset = merge_datasets(split_datasets, cfg)
|
||||||
|
|||||||
@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
# Top-level module name prefixes that belong to vision/audio/multimodal encoders
|
||||||
|
# rather than the language backbone. These are matched against the first component
|
||||||
|
# of each ``named_parameter`` path (e.g. "model.vision_tower." -> "vision_tower").
|
||||||
|
_MM_MODULE_PREFIXES = (
|
||||||
|
"vision_tower",
|
||||||
|
"vision_model",
|
||||||
|
"vision_encoder",
|
||||||
|
"embed_vision",
|
||||||
|
"multi_modal_projector",
|
||||||
|
"visual",
|
||||||
|
"audio_tower",
|
||||||
|
"audio_model",
|
||||||
|
"embed_audio",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_mm_modules(model):
|
||||||
|
"""Freeze all vision/audio/multimodal-projector parameters.
|
||||||
|
|
||||||
|
Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
|
||||||
|
for any parameter whose name contains a known vision/audio module prefix.
|
||||||
|
This is useful when fine-tuning only the language backbone of a multimodal
|
||||||
|
model and avoids the need for ``ddp_find_unused_parameters=True``.
|
||||||
|
"""
|
||||||
|
frozen_count = 0
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
# Check if any path component matches a vision/audio prefix
|
||||||
|
parts = name.split(".")
|
||||||
|
if any(part in _MM_MODULE_PREFIXES for part in parts):
|
||||||
|
if param.requires_grad:
|
||||||
|
param.requires_grad = False
|
||||||
|
frozen_count += 1
|
||||||
|
if is_main_process():
|
||||||
|
LOG.debug(f"freeze_mm_modules: froze {name}")
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
|
||||||
|
|
||||||
|
|
||||||
def freeze_layers_except(model, regex_patterns):
|
def freeze_layers_except(model, regex_patterns):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -578,6 +578,17 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
freeze_mm_modules: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
|
||||||
|
"text-only training of multimodal models. When True, parameters belonging to "
|
||||||
|
"vision towers, audio towers, multimodal projectors, and similar non-language "
|
||||||
|
"modules are frozen (requires_grad=False). This allows DDP training without "
|
||||||
|
"ddp_find_unused_parameters=True."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
unfrozen_parameters: list[str] | None = Field(
|
unfrozen_parameters: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -766,6 +777,15 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gemma4_hybrid_attn_impl: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
|
||||||
|
"and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
|
||||||
|
"exceeds flash attention's supported size."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
experts_implementation: str | None = Field(
|
experts_implementation: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -87,9 +87,11 @@ class ModelInputConfig(BaseModel):
|
|||||||
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
|
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
|
||||||
)
|
)
|
||||||
|
|
||||||
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
|
model_quantization_config: Literal["Mxfp4Config", "FineGrainedFP8Config"] | None = (
|
||||||
default=None,
|
Field(
|
||||||
json_schema_extra={"description": "Model loading quantization config"},
|
default=None,
|
||||||
|
json_schema_extra={"description": "Model loading quantization config"},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
model_quantization_config_kwargs: dict[str, Any] | None = Field(
|
model_quantization_config_kwargs: dict[str, Any] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -770,6 +770,88 @@ class RLValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_grpo_batch_size_divisibility(cls, data):
|
||||||
|
"""Surface GRPO batch-shape mismatches at config-parse time.
|
||||||
|
|
||||||
|
TRL's GRPOTrainer requires that the per-step generation batch size be
|
||||||
|
evenly divisible by ``num_generations`` so that every prompt can be
|
||||||
|
replicated exactly ``num_generations`` times. The runtime check inside
|
||||||
|
``GRPOTrainer.__init__`` only fires after the model has been loaded —
|
||||||
|
too late and too cryptic for the user. We replicate the check here so
|
||||||
|
the failure is immediate and actionable.
|
||||||
|
|
||||||
|
Also enforces:
|
||||||
|
- ``num_generations >= 2`` (group-relative advantage needs variance)
|
||||||
|
- ``effective_gbs >= num_generations * world_size`` when capabilities
|
||||||
|
indicate multiple ranks (each rank needs at least one full group)
|
||||||
|
"""
|
||||||
|
if data.get("rl") != "grpo":
|
||||||
|
return data
|
||||||
|
|
||||||
|
trl_cfg = data.get("trl") or {}
|
||||||
|
num_gen = trl_cfg.get("num_generations")
|
||||||
|
if num_gen is None:
|
||||||
|
# TRL's own default is 8 — but if the user didn't set it, we
|
||||||
|
# don't have enough info to validate anything. Let TRL's own
|
||||||
|
# init handle the default-vs-batch interaction.
|
||||||
|
return data
|
||||||
|
if num_gen < 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
|
||||||
|
"With num_generations=1, every group has zero advantage and "
|
||||||
|
"the policy never updates."
|
||||||
|
)
|
||||||
|
|
||||||
|
explicit_gbs = trl_cfg.get("generation_batch_size")
|
||||||
|
if explicit_gbs is not None:
|
||||||
|
effective_gbs = int(explicit_gbs)
|
||||||
|
gbs_source = "trl.generation_batch_size"
|
||||||
|
else:
|
||||||
|
mb = data.get("micro_batch_size") or 1
|
||||||
|
ga = data.get("gradient_accumulation_steps") or 1
|
||||||
|
effective_gbs = int(mb) * int(ga)
|
||||||
|
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
|
||||||
|
|
||||||
|
if effective_gbs % num_gen != 0:
|
||||||
|
# Suggest the smallest GA bump that fixes it for the common case
|
||||||
|
# where the user hasn't set generation_batch_size explicitly.
|
||||||
|
hint = ""
|
||||||
|
if explicit_gbs is None:
|
||||||
|
from math import gcd
|
||||||
|
|
||||||
|
mb_val = int(data.get("micro_batch_size") or 1)
|
||||||
|
# smallest GA such that mb*GA is a multiple of num_gen
|
||||||
|
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
|
||||||
|
suggested_ga = lcm // mb_val
|
||||||
|
hint = (
|
||||||
|
f" Smallest fix: set `gradient_accumulation_steps: "
|
||||||
|
f"{suggested_ga}` (so micro_batch_size * GA = "
|
||||||
|
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"GRPO: generation batch size must be divisible by "
|
||||||
|
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
|
||||||
|
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Multi-rank check: each rank must receive at least one full group
|
||||||
|
# per step. Without `capabilities` populated yet (mode='before'), we
|
||||||
|
# fall back to user-set distributed fields.
|
||||||
|
world_size = (
|
||||||
|
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
|
||||||
|
)
|
||||||
|
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"GRPO with world_size={world_size} requires effective_gbs "
|
||||||
|
f">= num_generations * world_size = {num_gen * world_size}, "
|
||||||
|
f"got {effective_gbs}. Increase gradient_accumulation_steps "
|
||||||
|
f"or micro_batch_size."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class OptimizationValidationMixin:
|
class OptimizationValidationMixin:
|
||||||
"""Validation methods related to optimization and performance."""
|
"""Validation methods related to optimization and performance."""
|
||||||
|
|||||||
@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
|
|||||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVllmLoraSyncPatch(unittest.TestCase):
|
||||||
|
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
|
||||||
|
|
||||||
|
These tests exercise the patch-installation branch in isolation. They build
|
||||||
|
a stub trainer with just enough attributes to look like
|
||||||
|
``AsyncGRPOTrainer`` for the duration of the relevant code path.
|
||||||
|
|
||||||
|
Background — there are two correct behaviors and we historically had a bug
|
||||||
|
where both modes used the same one:
|
||||||
|
|
||||||
|
- Async prefetch ON: the BG generation thread can't safely call
|
||||||
|
sync_weights mid-rollout. We no-op the stock hook and drive sync from
|
||||||
|
the main thread via ``_maybe_sync_vllm_weights``.
|
||||||
|
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
|
||||||
|
calls ``sync_weights`` once per step boundary on the main thread. We
|
||||||
|
wire that hook directly to ``_sync_lora_adapter`` because
|
||||||
|
``_maybe_sync_vllm_weights`` short-circuits when async is off.
|
||||||
|
|
||||||
|
Before the fix, both modes installed ``lambda: None``, so sync mode never
|
||||||
|
pushed any LoRA adapter to vLLM and the trainer was a no-op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
|
||||||
|
from axolotl.core.trainers.grpo.async_trainer import (
|
||||||
|
AsyncGRPOTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeArgs:
|
||||||
|
pass
|
||||||
|
|
||||||
|
args = FakeArgs()
|
||||||
|
args.vllm_lora_sync = vllm_lora_sync
|
||||||
|
args.async_prefetch = async_prefetch
|
||||||
|
|
||||||
|
class FakeVllmGen:
|
||||||
|
sync_weights = staticmethod(lambda: None)
|
||||||
|
model = MagicMock()
|
||||||
|
|
||||||
|
# Use object.__new__ so we don't run __init__ (which needs a real
|
||||||
|
# model, dataset, etc.). We only need the `_generate_single_turn`
|
||||||
|
# method's patch branch to run, so we set up the minimum state.
|
||||||
|
trainer = object.__new__(AsyncGRPOTrainer)
|
||||||
|
trainer.args = args
|
||||||
|
trainer.use_vllm = True
|
||||||
|
trainer.vllm_generation = FakeVllmGen()
|
||||||
|
trainer._patched_sync_weights = False
|
||||||
|
# Spy on _sync_lora_adapter so we can assert it's the function the
|
||||||
|
# hook delegates to in sync mode.
|
||||||
|
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
|
||||||
|
trainer._sync_peft_weights_no_merge = MagicMock(
|
||||||
|
name="_sync_peft_weights_no_merge_spy"
|
||||||
|
)
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _run_patch_branch(trainer):
|
||||||
|
"""Execute just the sync_weights-patching branch in isolation.
|
||||||
|
|
||||||
|
We can't easily call the real ``_generate_single_turn`` because it
|
||||||
|
does a full vLLM generate. Instead we copy the exact branch out of
|
||||||
|
the source so the test verifies the same logic the trainer runs.
|
||||||
|
"""
|
||||||
|
if not getattr(trainer, "_patched_sync_weights", False):
|
||||||
|
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
|
||||||
|
if getattr(trainer.args, "vllm_lora_sync", False):
|
||||||
|
if getattr(trainer.args, "async_prefetch", False):
|
||||||
|
trainer.vllm_generation.sync_weights = lambda: None
|
||||||
|
else:
|
||||||
|
sync_helper = trainer._sync_lora_adapter
|
||||||
|
|
||||||
|
def _lora_filesystem_sync():
|
||||||
|
sync_helper()
|
||||||
|
|
||||||
|
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||||
|
trainer._patched_sync_weights = True
|
||||||
|
|
||||||
|
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
|
||||||
|
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||||
|
self._run_patch_branch(trainer)
|
||||||
|
|
||||||
|
assert trainer._patched_sync_weights is True
|
||||||
|
# Trigger the patched hook — it must call _sync_lora_adapter.
|
||||||
|
trainer.vllm_generation.sync_weights()
|
||||||
|
trainer._sync_lora_adapter.assert_called_once()
|
||||||
|
|
||||||
|
def test_async_mode_with_lora_sync_installs_noop_hook(self):
|
||||||
|
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
|
||||||
|
self._run_patch_branch(trainer)
|
||||||
|
|
||||||
|
assert trainer._patched_sync_weights is True
|
||||||
|
# Hook must be a no-op so BG-thread generation doesn't fight the
|
||||||
|
# main-thread optimizer step over the model weights.
|
||||||
|
trainer.vllm_generation.sync_weights()
|
||||||
|
trainer._sync_lora_adapter.assert_not_called()
|
||||||
|
|
||||||
|
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
|
||||||
|
"""Installing the patch should not pre-emptively sync."""
|
||||||
|
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||||
|
self._run_patch_branch(trainer)
|
||||||
|
# _sync_lora_adapter should only be called when the patched hook
|
||||||
|
# itself is invoked (e.g., from TRL's _generate_single_turn).
|
||||||
|
trainer._sync_lora_adapter.assert_not_called()
|
||||||
|
|
||||||
|
def test_patch_is_idempotent(self):
|
||||||
|
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||||
|
self._run_patch_branch(trainer)
|
||||||
|
first_hook = trainer.vllm_generation.sync_weights
|
||||||
|
# Second call must not re-patch (otherwise we'd lose the original).
|
||||||
|
self._run_patch_branch(trainer)
|
||||||
|
assert trainer.vllm_generation.sync_weights is first_hook
|
||||||
|
|
||||||
|
|
||||||
|
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
|
||||||
|
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
|
||||||
|
|
||||||
|
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
|
||||||
|
on the very first call when ``vllm_sync_interval`` was ``None`` (which
|
||||||
|
is the default for any config that doesn't explicitly set it). We now
|
||||||
|
fall back to interval=1 so unset means "sync every step", matching the
|
||||||
|
behavior of TRL's own ``_generate_single_turn``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_stub_trainer(interval, async_prefetch):
|
||||||
|
from axolotl.core.trainers.grpo.async_trainer import (
|
||||||
|
AsyncGRPOTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeArgs:
|
||||||
|
pass
|
||||||
|
|
||||||
|
args = FakeArgs()
|
||||||
|
args.async_prefetch = async_prefetch
|
||||||
|
args.vllm_sync_interval = interval
|
||||||
|
args.vllm_lora_sync = True
|
||||||
|
|
||||||
|
class FakeState:
|
||||||
|
global_step = 1
|
||||||
|
|
||||||
|
trainer = object.__new__(AsyncGRPOTrainer)
|
||||||
|
trainer.args = args
|
||||||
|
trainer.use_vllm = True
|
||||||
|
trainer.state = FakeState()
|
||||||
|
trainer._last_synced_step = 0
|
||||||
|
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
def test_interval_none_in_async_mode_does_not_crash(self):
|
||||||
|
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
|
||||||
|
from axolotl.core.trainers.grpo.async_trainer import (
|
||||||
|
AsyncGRPOTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise TypeError — defaults to every-step sync
|
||||||
|
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||||
|
trainer._sync_lora_adapter.assert_called_once()
|
||||||
|
|
||||||
|
def test_sync_mode_drives_sync(self):
|
||||||
|
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
|
||||||
|
|
||||||
|
The previous behavior (early return when ``not async_prefetch``)
|
||||||
|
assumed TRL's stock ``_generate_single_turn`` would handle sync.
|
||||||
|
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
|
||||||
|
where the data producer bypasses ``_generate_single_turn``
|
||||||
|
entirely. Without this trigger no sync ever happens and the
|
||||||
|
trainer becomes a no-op.
|
||||||
|
"""
|
||||||
|
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
|
||||||
|
from axolotl.core.trainers.grpo.async_trainer import (
|
||||||
|
AsyncGRPOTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||||
|
trainer._sync_lora_adapter.assert_called_once()
|
||||||
|
|
||||||
|
def test_async_mode_with_explicit_interval_respects_modulo(self):
|
||||||
|
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
|
||||||
|
from axolotl.core.trainers.grpo.async_trainer import (
|
||||||
|
AsyncGRPOTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
|
||||||
|
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||||
|
trainer._sync_lora_adapter.assert_not_called()
|
||||||
|
|
||||||
|
# global_step=4 → 4 % 4 == 0 → sync
|
||||||
|
trainer.state.global_step = 4
|
||||||
|
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||||
|
trainer._sync_lora_adapter.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -54,25 +54,7 @@ except (ImportError, ModuleNotFoundError):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||||
K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
|
|
||||||
smoe_A = torch.zeros(
|
|
||||||
rank * num_experts,
|
|
||||||
K_inter,
|
|
||||||
device=peft_A.device,
|
|
||||||
dtype=peft_A.dtype,
|
|
||||||
)
|
|
||||||
smoe_B = torch.zeros(
|
|
||||||
N_hidden,
|
|
||||||
rank * num_experts,
|
|
||||||
device=peft_A.device,
|
|
||||||
dtype=peft_A.dtype,
|
|
||||||
)
|
|
||||||
for e in range(num_experts):
|
|
||||||
s = e * rank
|
|
||||||
smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T
|
|
||||||
smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T
|
|
||||||
return smoe_A, smoe_B
|
|
||||||
|
|
||||||
def _unwrap_experts_lora(experts_module):
|
def _unwrap_experts_lora(experts_module):
|
||||||
return experts_module, None, None
|
return experts_module, None, None
|
||||||
@@ -145,11 +127,7 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
|
|||||||
|
|
||||||
|
|
||||||
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
"""Convert peft LoRA for gate_up_proj to scattermoe layout.
|
"""Convert peft LoRA for gate_up_proj to scattermoe layout."""
|
||||||
|
|
||||||
Both gate_up_proj and down_proj need the A<->B swap because
|
|
||||||
scattermoe transposes the parameter (W = param.T).
|
|
||||||
"""
|
|
||||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||||
|
|
||||||
|
|
||||||
@@ -322,14 +300,16 @@ class TestLoRABLayoutConversion:
|
|||||||
hidden, inter = 32, 16
|
hidden, inter = 32, 16
|
||||||
scaling = 2.0
|
scaling = 2.0
|
||||||
|
|
||||||
peft_A = torch.randn(E * r, hidden)
|
peft_A = torch.randn(E * r, inter)
|
||||||
peft_B = torch.randn(inter, E * r)
|
peft_B = torch.randn(hidden, E * r)
|
||||||
|
|
||||||
A_r = peft_A.reshape(E, r, hidden)
|
A_r = peft_A.reshape(E, r, inter)
|
||||||
B_r = peft_B.reshape(inter, r, E)
|
B_r = peft_B.reshape(hidden, r, E)
|
||||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||||
|
|
||||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||||
|
assert smoe_A.shape == (E * r, inter)
|
||||||
|
assert smoe_B.shape == (hidden, E * r)
|
||||||
for e in range(E):
|
for e in range(E):
|
||||||
A_e = smoe_A[e * r : (e + 1) * r, :]
|
A_e = smoe_A[e * r : (e + 1) * r, :]
|
||||||
B_e = smoe_B[:, e * r : (e + 1) * r]
|
B_e = smoe_B[:, e * r : (e + 1) * r]
|
||||||
@@ -342,27 +322,26 @@ class TestLoRABLayoutConversion:
|
|||||||
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
|
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
|
||||||
|
|
||||||
gate_up_proj param: [E, 2*inter, hidden].
|
gate_up_proj param: [E, 2*inter, hidden].
|
||||||
peft: in_features=2*inter, out_features=hidden.
|
peft: in_features=hidden, out_features=2*inter.
|
||||||
peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].
|
peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E].
|
||||||
|
|
||||||
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
|
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
|
||||||
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
|
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
|
||||||
|
|
||||||
Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
|
Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs.
|
||||||
"""
|
"""
|
||||||
E, r = 4, 2
|
E, r = 4, 2
|
||||||
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
|
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
|
||||||
scaling = 2.0
|
scaling = 2.0
|
||||||
|
|
||||||
# peft assigns: in_features=2*inter, out_features=hidden
|
# peft assigns: in_features=hidden, out_features=2*inter
|
||||||
peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter]
|
peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden]
|
||||||
peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E]
|
peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E]
|
||||||
|
|
||||||
# peft delta via einsum: "o r e, e r i -> e i o"
|
A_r = peft_A.reshape(E, r, hidden)
|
||||||
A_r = peft_A.reshape(E, r, 2 * inter)
|
B_r = peft_B.reshape(2 * inter, r, E)
|
||||||
B_r = peft_B.reshape(hidden, r, E)
|
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
# delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden]
|
||||||
# delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]
|
|
||||||
# = param[e] shape [2*inter, hidden]
|
# = param[e] shape [2*inter, hidden]
|
||||||
|
|
||||||
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
|
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||||
@@ -422,22 +401,22 @@ class TestPeftLoRAWeightExtraction:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# gate_up_proj [E, 2*inter, hidden]
|
# gate_up_proj [E, 2*inter, hidden]
|
||||||
# peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)
|
# peft: in_features=hidden (last dim), out_features=2*inter (middle dim)
|
||||||
assert trainable[
|
assert trainable[
|
||||||
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
|
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
|
||||||
].shape == (E * r, 2 * config.intermediate_size)
|
|
||||||
assert trainable[
|
|
||||||
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
|
||||||
].shape == (config.hidden_size, E * r)
|
|
||||||
|
|
||||||
# down_proj [E, hidden, inter]
|
|
||||||
# peft: in_features=hidden (dim 1), out_features=inter (dim 2)
|
|
||||||
assert trainable[
|
|
||||||
"base_model.model.moe.experts.lora_A.default.weight"
|
|
||||||
].shape == (E * r, config.hidden_size)
|
].shape == (E * r, config.hidden_size)
|
||||||
|
assert trainable[
|
||||||
|
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
||||||
|
].shape == (2 * config.intermediate_size, E * r)
|
||||||
|
|
||||||
|
# down_proj [E, hidden, inter]
|
||||||
|
# peft: in_features=inter (last dim), out_features=hidden (middle dim)
|
||||||
|
assert trainable[
|
||||||
|
"base_model.model.moe.experts.lora_A.default.weight"
|
||||||
|
].shape == (E * r, config.intermediate_size)
|
||||||
assert trainable[
|
assert trainable[
|
||||||
"base_model.model.moe.experts.lora_B.default.weight"
|
"base_model.model.moe.experts.lora_B.default.weight"
|
||||||
].shape == (config.intermediate_size, E * r)
|
].shape == (config.hidden_size, E * r)
|
||||||
|
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
def test_peft_forward_runs(self):
|
def test_peft_forward_runs(self):
|
||||||
@@ -488,27 +467,29 @@ class TestPeftLoRAWeightExtraction:
|
|||||||
assert gup_lora is not None, "gate_up_proj LoRA not detected"
|
assert gup_lora is not None, "gate_up_proj LoRA not detected"
|
||||||
assert down_lora is not None, "down_proj LoRA not detected"
|
assert down_lora is not None, "down_proj LoRA not detected"
|
||||||
|
|
||||||
# Check shapes (after peft->scattermoe conversion with A<->B swap)
|
# Check shapes after peft->scattermoe conversion.
|
||||||
# gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter
|
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
|
||||||
|
# scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r]
|
||||||
E, r = config.num_experts, 4
|
E, r = config.num_experts, 4
|
||||||
gup_A, gup_B, gup_s = gup_lora
|
gup_A, gup_B, gup_s = gup_lora
|
||||||
assert gup_A.shape == (E * r, config.hidden_size), (
|
assert gup_A.shape == (E * r, config.hidden_size), (
|
||||||
f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
|
f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
|
||||||
f"got {gup_A.shape}"
|
f"got {gup_A.shape}"
|
||||||
)
|
)
|
||||||
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
|
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
|
||||||
f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
|
f"gate_up_proj smoe_B: expected [2*inter, r*E]="
|
||||||
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
|
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden
|
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
|
||||||
|
# scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r]
|
||||||
down_A, down_B, down_s = down_lora
|
down_A, down_B, down_s = down_lora
|
||||||
assert down_A.shape == (E * r, config.intermediate_size), (
|
assert down_A.shape == (E * r, config.intermediate_size), (
|
||||||
f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
|
f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, "
|
||||||
f"got {down_A.shape}"
|
f"got {down_A.shape}"
|
||||||
)
|
)
|
||||||
assert down_B.shape == (config.hidden_size, E * r), (
|
assert down_B.shape == (config.hidden_size, E * r), (
|
||||||
f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
|
f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, "
|
||||||
f"got {down_B.shape}"
|
f"got {down_B.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
|
|||||||
assert cfg.dataloader_num_workers == 0
|
assert cfg.dataloader_num_workers == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSelectWeightSyncTransport(unittest.TestCase):
|
||||||
|
"""Pure-logic table tests for ``select_weight_sync_transport``."""
|
||||||
|
|
||||||
|
def _caps(self, **kwargs):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
||||||
|
|
||||||
|
c = VLLMWeightSyncCapabilities(probed=True)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(c, k, v)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def test_lora_with_native_endpoint(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||||
|
|
||||||
|
caps = self._caps(lora_filesystem=True)
|
||||||
|
assert (
|
||||||
|
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
||||||
|
== "lora_filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lora_with_axolotl_endpoint(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||||
|
|
||||||
|
caps = self._caps(lora_axolotl=True)
|
||||||
|
assert (
|
||||||
|
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
||||||
|
== "lora_filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||||
|
|
||||||
|
caps = self._caps(nccl=True)
|
||||||
|
assert (
|
||||||
|
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
||||||
|
== "nccl"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_full_param_prefers_nccl(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||||
|
|
||||||
|
caps = self._caps(nccl=True, http_full=True)
|
||||||
|
assert (
|
||||||
|
select_weight_sync_transport(
|
||||||
|
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||||
|
)
|
||||||
|
== "nccl"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_full_param_falls_back_to_http(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||||
|
|
||||||
|
caps = self._caps(http_full=True)
|
||||||
|
assert (
|
||||||
|
select_weight_sync_transport(
|
||||||
|
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||||
|
)
|
||||||
|
== "http_full"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_full_param_no_routes_returns_none(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||||
|
|
||||||
|
caps = self._caps() # all False
|
||||||
|
assert (
|
||||||
|
select_weight_sync_transport(
|
||||||
|
caps, has_lora=False, vllm_lora_sync_pref=False
|
||||||
|
)
|
||||||
|
== "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lora_no_routes_returns_none(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
||||||
|
|
||||||
|
caps = self._caps()
|
||||||
|
assert (
|
||||||
|
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
||||||
|
== "none"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProbeVllmWeightSync(unittest.TestCase):
|
||||||
|
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
|
||||||
|
|
||||||
|
def test_stock_vllm_with_lora_enabled(self):
|
||||||
|
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"paths": {
|
||||||
|
"/v1/models": {"get": {}},
|
||||||
|
"/v1/load_lora_adapter": {"post": {}},
|
||||||
|
"/v1/unload_lora_adapter": {"post": {}},
|
||||||
|
"/v1/completions": {"post": {}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with patch("requests.get") as mock_get:
|
||||||
|
mock_get.return_value.raise_for_status = lambda: None
|
||||||
|
mock_get.return_value.json = lambda: spec
|
||||||
|
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||||
|
|
||||||
|
assert caps.probed is True
|
||||||
|
assert caps.lora_filesystem is True
|
||||||
|
assert caps.lora_axolotl is False
|
||||||
|
assert caps.nccl is False
|
||||||
|
assert caps.http_full is False
|
||||||
|
|
||||||
|
def test_axolotl_serve_lora_full_capabilities(self):
|
||||||
|
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"paths": {
|
||||||
|
"/init_communicator/": {"post": {}},
|
||||||
|
"/update_named_param/": {"post": {}},
|
||||||
|
"/batch_update_named_params/": {"post": {}},
|
||||||
|
"/set_lora_adapter/": {"post": {}},
|
||||||
|
"/clear_lora_adapter/": {"post": {}},
|
||||||
|
"/http_update_weights/": {"post": {}},
|
||||||
|
"/v1/load_lora_adapter": {"post": {}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with patch("requests.get") as mock_get:
|
||||||
|
mock_get.return_value.raise_for_status = lambda: None
|
||||||
|
mock_get.return_value.json = lambda: spec
|
||||||
|
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||||
|
|
||||||
|
assert caps.probed is True
|
||||||
|
assert caps.nccl is True
|
||||||
|
assert caps.lora_axolotl is True
|
||||||
|
assert caps.lora_filesystem is True
|
||||||
|
assert caps.http_full is True
|
||||||
|
|
||||||
|
def test_trl_vllm_serve_nccl_only(self):
|
||||||
|
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"paths": {
|
||||||
|
"/init_communicator/": {"post": {}},
|
||||||
|
"/update_named_param/": {"post": {}},
|
||||||
|
"/batch_update_named_params/": {"post": {}},
|
||||||
|
"/close_communicator/": {"post": {}},
|
||||||
|
"/generate/": {"post": {}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with patch("requests.get") as mock_get:
|
||||||
|
mock_get.return_value.raise_for_status = lambda: None
|
||||||
|
mock_get.return_value.json = lambda: spec
|
||||||
|
caps = probe_vllm_weight_sync("http://localhost:8000")
|
||||||
|
|
||||||
|
assert caps.probed is True
|
||||||
|
assert caps.nccl is True
|
||||||
|
assert caps.lora_filesystem is False
|
||||||
|
assert caps.lora_axolotl is False
|
||||||
|
assert caps.http_full is False
|
||||||
|
|
||||||
|
def test_unreachable_server_records_error(self):
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
||||||
|
|
||||||
|
with patch("requests.get") as mock_get:
|
||||||
|
mock_get.side_effect = ConnectionError("Connection refused")
|
||||||
|
caps = probe_vllm_weight_sync("http://localhost:9999")
|
||||||
|
|
||||||
|
assert caps.probed is False
|
||||||
|
assert caps.probe_error is not None
|
||||||
|
assert "ConnectionError" in caps.probe_error
|
||||||
|
assert caps.nccl is False
|
||||||
|
assert caps.lora_filesystem is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginWeightSyncEnforcement(unittest.TestCase):
|
||||||
|
"""End-to-end test of post_trainer_create's transport-selection branch.
|
||||||
|
|
||||||
|
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
|
||||||
|
leaving the trainer learning in isolation while vLLM kept serving the
|
||||||
|
unmodified base model. After the fix:
|
||||||
|
|
||||||
|
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
|
||||||
|
- LoRA + only NCCL endpoint → uses NCCL broadcast
|
||||||
|
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
|
||||||
|
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
|
||||||
|
- No usable transport → raises ValueError with a precise diagnosis
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fake_cfg(adapter, vllm_lora_sync):
|
||||||
|
class FakeTRL:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FakeCfg:
|
||||||
|
pass
|
||||||
|
|
||||||
|
trl = FakeTRL()
|
||||||
|
trl.vllm_lora_sync = vllm_lora_sync
|
||||||
|
trl.vllm_server_host = "127.0.0.1"
|
||||||
|
trl.vllm_server_port = 8000
|
||||||
|
|
||||||
|
cfg = FakeCfg()
|
||||||
|
cfg.nemo_gym_enabled = True
|
||||||
|
cfg.nemo_gym_model_name = None
|
||||||
|
cfg.base_model = "test/model"
|
||||||
|
cfg.nemo_gym_verify_timeout = 30
|
||||||
|
cfg.nemo_gym_multi_turn = True
|
||||||
|
cfg.adapter = adapter
|
||||||
|
cfg.trl = trl
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fake_trainer():
|
||||||
|
class FakeVLLMGen:
|
||||||
|
sync_weights = staticmethod(lambda: None)
|
||||||
|
|
||||||
|
class FakeTrainer:
|
||||||
|
vllm_generation = FakeVLLMGen()
|
||||||
|
|
||||||
|
return FakeTrainer()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _caps(**kwargs):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
||||||
|
|
||||||
|
c = VLLMWeightSyncCapabilities(probed=True)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(c, k, v)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||||
|
|
||||||
|
plugin = NemoGymPlugin()
|
||||||
|
plugin._vllm_caps = self._caps(lora_filesystem=True)
|
||||||
|
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
||||||
|
trainer = self._fake_trainer()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(plugin, "_setup_lora_sync") as setup,
|
||||||
|
patch.object(plugin, "_check_lora_endpoint") as check,
|
||||||
|
patch.object(plugin, "_wire_multi_turn") as wire,
|
||||||
|
):
|
||||||
|
plugin.post_trainer_create(cfg, trainer)
|
||||||
|
setup.assert_called_once()
|
||||||
|
check.assert_called_once()
|
||||||
|
wire.assert_called_once()
|
||||||
|
|
||||||
|
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||||
|
|
||||||
|
plugin = NemoGymPlugin()
|
||||||
|
plugin._vllm_caps = self._caps() # all False, but probed
|
||||||
|
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
|
||||||
|
trainer = self._fake_trainer()
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as ctx:
|
||||||
|
plugin.post_trainer_create(cfg, trainer)
|
||||||
|
msg = str(ctx.exception)
|
||||||
|
assert "no-op trainer" in msg
|
||||||
|
assert "load_lora_adapter" in msg
|
||||||
|
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
|
||||||
|
|
||||||
|
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||||
|
|
||||||
|
plugin = NemoGymPlugin()
|
||||||
|
plugin._vllm_caps = self._caps(nccl=True)
|
||||||
|
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||||
|
trainer = self._fake_trainer()
|
||||||
|
|
||||||
|
with patch.object(plugin, "_wire_multi_turn") as wire:
|
||||||
|
plugin.post_trainer_create(cfg, trainer)
|
||||||
|
wire.assert_called_once()
|
||||||
|
|
||||||
|
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||||
|
|
||||||
|
plugin = NemoGymPlugin()
|
||||||
|
plugin._vllm_caps = self._caps(http_full=True)
|
||||||
|
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||||
|
trainer = self._fake_trainer()
|
||||||
|
with self.assertRaises(NotImplementedError) as ctx:
|
||||||
|
plugin.post_trainer_create(cfg, trainer)
|
||||||
|
assert "HTTP weight sync" in str(ctx.exception)
|
||||||
|
|
||||||
|
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||||
|
|
||||||
|
plugin = NemoGymPlugin()
|
||||||
|
plugin._vllm_caps = self._caps()
|
||||||
|
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
||||||
|
trainer = self._fake_trainer()
|
||||||
|
with self.assertRaises(ValueError) as ctx:
|
||||||
|
plugin.post_trainer_create(cfg, trainer)
|
||||||
|
msg = str(ctx.exception)
|
||||||
|
assert "no-op trainer" in msg
|
||||||
|
assert "init_communicator" in msg
|
||||||
|
assert "http_update_weights" in msg
|
||||||
|
|
||||||
|
def test_unprobed_caps_raises_with_probe_failure_message(self):
|
||||||
|
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
||||||
|
|
||||||
|
plugin = NemoGymPlugin()
|
||||||
|
# Plugin._vllm_caps left as default-None: the post_trainer_create
|
||||||
|
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
|
||||||
|
# probed=False, so the error path should mention probing.
|
||||||
|
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
||||||
|
trainer = self._fake_trainer()
|
||||||
|
with self.assertRaises(ValueError) as ctx:
|
||||||
|
plugin.post_trainer_create(cfg, trainer)
|
||||||
|
assert "could not probe" in str(ctx.exception)
|
||||||
|
|
||||||
|
|
||||||
class TestNemoGymE2E(unittest.TestCase):
|
class TestNemoGymE2E(unittest.TestCase):
|
||||||
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
||||||
|
|
||||||
@@ -452,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
|
|||||||
trainer = self._make_mock_trainer()
|
trainer = self._make_mock_trainer()
|
||||||
producer._trainer = trainer
|
producer._trainer = trainer
|
||||||
|
|
||||||
# Mock the prompt iterator (returns a batch of 1 input)
|
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
|
||||||
producer._prompt_iter = iter(
|
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
|
||||||
[
|
# copies of each unique prompt — one entry per rollout.
|
||||||
[
|
_prompt_batch = [
|
||||||
{
|
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||||
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||||
}
|
|
||||||
]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
producer._prompt_dl = [
|
|
||||||
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
|
|
||||||
]
|
]
|
||||||
|
producer._prompt_iter = iter([_prompt_batch])
|
||||||
|
producer._prompt_dl = [_prompt_batch]
|
||||||
|
|
||||||
# Call produce
|
# Call produce
|
||||||
result = producer.produce(model=MagicMock(), global_step=1)
|
result = producer.produce(model=MagicMock(), global_step=1)
|
||||||
@@ -530,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
|
|||||||
producer._request_timeout = 30
|
producer._request_timeout = 30
|
||||||
producer._num_generations = 2
|
producer._num_generations = 2
|
||||||
producer._trainer = self._make_mock_trainer()
|
producer._trainer = self._make_mock_trainer()
|
||||||
producer._prompt_iter = iter(
|
# RepeatSampler pre-expands by num_generations=2.
|
||||||
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
_prompt_batch = [
|
||||||
)
|
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||||
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||||
|
]
|
||||||
|
producer._prompt_iter = iter([_prompt_batch])
|
||||||
|
producer._prompt_dl = [_prompt_batch]
|
||||||
|
|
||||||
result = producer.produce(model=MagicMock(), global_step=1)
|
result = producer.produce(model=MagicMock(), global_step=1)
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,51 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class TestPeftScatterMoELoRALayout:
|
||||||
|
"""CPU-only guards for PEFT target_parameters layout conversion."""
|
||||||
|
|
||||||
|
def test_peft_layout_keeps_a_and_reorders_b(self):
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
||||||
|
peft_lora_to_scattermoe,
|
||||||
|
)
|
||||||
|
|
||||||
|
E, r, K, N = 3, 2, 5, 7
|
||||||
|
scaling = 2.0
|
||||||
|
peft_A = torch.randn(E * r, K)
|
||||||
|
peft_B = torch.randn(N, E * r)
|
||||||
|
|
||||||
|
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||||
|
|
||||||
|
assert smoe_A is peft_A
|
||||||
|
assert smoe_A.shape == (E * r, K)
|
||||||
|
assert smoe_B.shape == (N, E * r)
|
||||||
|
|
||||||
|
A_r = peft_A.reshape(E, r, K)
|
||||||
|
B_r = peft_B.reshape(N, r, E)
|
||||||
|
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||||
|
|
||||||
|
for e in range(E):
|
||||||
|
A_e = smoe_A[e * r : (e + 1) * r, :]
|
||||||
|
B_e = smoe_B[:, e * r : (e + 1) * r]
|
||||||
|
torch.testing.assert_close(scaling * (B_e @ A_e), delta_peft[e])
|
||||||
|
|
||||||
|
def test_swapped_layout_fails_before_kernel_dispatch(self):
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
||||||
|
validate_scattermoe_lora_shapes,
|
||||||
|
)
|
||||||
|
|
||||||
|
E, r, K, N = 3, 2, 5, 7
|
||||||
|
expert_weights = torch.empty(E, K, N)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"):
|
||||||
|
validate_scattermoe_lora_shapes(
|
||||||
|
expert_weights=expert_weights,
|
||||||
|
lora_A=torch.empty(E * r, N),
|
||||||
|
lora_B=torch.empty(K, E * r),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 1. KernelsArgs: disable_mlp_kernel validator
|
# 1. KernelsArgs: disable_mlp_kernel validator
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
416
tests/kernels/test_gemma4_fused_rope.py
Normal file
416
tests/kernels/test_gemma4_fused_rope.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""
|
||||||
|
Correctness tests for the fused RMSNorm+RoPE Triton kernel.
|
||||||
|
|
||||||
|
Tests forward and backward against the reference Gemma4 implementation
|
||||||
|
(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
|
||||||
|
(head_dim=256) and global attention (head_dim=512) layer configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
# Skip entire module if no CUDA
|
||||||
|
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_norm_rope(x, weight, cos, sin, eps):
|
||||||
|
"""Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import (
|
||||||
|
Gemma4RMSNorm,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
D = x.shape[-1]
|
||||||
|
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
|
||||||
|
norm.weight.data.copy_(weight)
|
||||||
|
normed = norm(x)
|
||||||
|
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_norm_noscale(x, eps):
|
||||||
|
"""Reference: Gemma4RMSNorm with_scale=False."""
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
|
||||||
|
|
||||||
|
D = x.shape[-1]
|
||||||
|
norm = Gemma4RMSNorm(D, eps=eps, with_scale=False).to(x.device, x.dtype)
|
||||||
|
return norm(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
|
||||||
|
"""Reference: Gemma4RMSNorm over the full head_dim, then stock
|
||||||
|
``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with
|
||||||
|
the trailing columns passed through unchanged. Mirrors how Llama-style
|
||||||
|
partial rotary is layered on top of the stock RMSNorm + RoPE primitives.
|
||||||
|
"""
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import (
|
||||||
|
Gemma4RMSNorm,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
D = x.shape[-1]
|
||||||
|
n_rot = cos.shape[-1]
|
||||||
|
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
|
||||||
|
norm.weight.data.copy_(weight)
|
||||||
|
normed = norm(x)
|
||||||
|
if n_rot == D:
|
||||||
|
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
|
||||||
|
x_rot = normed[..., :n_rot]
|
||||||
|
x_pass = normed[..., n_rot:]
|
||||||
|
rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2)
|
||||||
|
return torch.cat([rotated, x_pass], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[
|
||||||
|
(2, 64, 32, 256), # sliding window layer shape
|
||||||
|
(2, 64, 4, 512), # global attention layer shape
|
||||||
|
(1, 128, 16, 256), # different batch/seq
|
||||||
|
(1, 1, 1, 8), # minimal size
|
||||||
|
],
|
||||||
|
ids=["sliding_256", "global_512", "varied", "minimal"],
|
||||||
|
)
|
||||||
|
def shapes(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
|
||||||
|
def dtype(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedRMSNormRoPEForward:
|
||||||
|
"""Forward pass correctness."""
|
||||||
|
|
||||||
|
def test_matches_reference(self, shapes, dtype):
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
B, S, H, D = shapes
|
||||||
|
eps = 1e-6
|
||||||
|
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
|
||||||
|
weight = torch.randn(D, device="cuda", dtype=dtype)
|
||||||
|
cos = torch.randn(B, S, D, device="cuda", dtype=dtype)
|
||||||
|
sin = torch.randn(B, S, D, device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
|
||||||
|
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
|
||||||
|
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"
|
||||||
|
|
||||||
|
def test_output_shape(self, shapes):
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
B, S, H, D = shapes
|
||||||
|
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
y = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == x.dtype
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedRMSNormRoPEBackward:
|
||||||
|
"""Backward pass correctness via gradient comparison."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"B,S,H,D",
|
||||||
|
[(2, 64, 32, 256), (2, 64, 4, 512)],
|
||||||
|
ids=["sliding_256", "global_512"],
|
||||||
|
)
|
||||||
|
def test_x_grad_matches_reference(self, B, S, H, D):
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import (
|
||||||
|
Gemma4RMSNorm,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
eps = 1e-6
|
||||||
|
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Reference backward
|
||||||
|
x_ref = torch.randn(
|
||||||
|
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
|
||||||
|
)
|
||||||
|
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
|
||||||
|
norm_ref.weight.data.copy_(weight_init)
|
||||||
|
y_ref = apply_rotary_pos_emb(norm_ref(x_ref), cos, sin, unsqueeze_dim=2)
|
||||||
|
y_ref.sum().backward()
|
||||||
|
|
||||||
|
# Fused backward
|
||||||
|
x_fused = x_ref.data.clone().requires_grad_(True)
|
||||||
|
w_fused = weight_init.clone().requires_grad_(True)
|
||||||
|
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
|
||||||
|
y_fused.sum().backward()
|
||||||
|
|
||||||
|
cos_sim_x = torch.nn.functional.cosine_similarity(
|
||||||
|
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"B,S,H,D",
|
||||||
|
[(2, 64, 32, 256), (2, 64, 4, 512)],
|
||||||
|
ids=["sliding_256", "global_512"],
|
||||||
|
)
|
||||||
|
def test_weight_grad_matches_reference(self, B, S, H, D):
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import (
|
||||||
|
Gemma4RMSNorm,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
eps = 1e-6
|
||||||
|
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Reference
|
||||||
|
x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
|
||||||
|
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
|
||||||
|
apply_rotary_pos_emb(
|
||||||
|
norm_ref(x_ref), cos, sin, unsqueeze_dim=2
|
||||||
|
).sum().backward()
|
||||||
|
|
||||||
|
# Fused
|
||||||
|
w_fused = weight_init.clone().requires_grad_(True)
|
||||||
|
fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()
|
||||||
|
|
||||||
|
cos_sim_w = torch.nn.functional.cosine_similarity(
|
||||||
|
w_fused.grad.flatten().float(),
|
||||||
|
norm_ref.weight.grad.flatten().float(),
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
assert cos_sim_w > 0.995, (
|
||||||
|
f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_grad_flows(self):
|
||||||
|
"""Verify gradients are non-zero and finite."""
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
B, S, H, D = 1, 16, 4, 64
|
||||||
|
x = torch.randn(
|
||||||
|
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
|
||||||
|
)
|
||||||
|
w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
|
||||||
|
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
|
||||||
|
y.sum().backward()
|
||||||
|
|
||||||
|
assert x.grad is not None, "x.grad is None"
|
||||||
|
assert w.grad is not None, "w.grad is None"
|
||||||
|
assert x.grad.isfinite().all(), "x.grad has non-finite values"
|
||||||
|
assert w.grad.isfinite().all(), "w.grad has non-finite values"
|
||||||
|
assert x.grad.abs().sum() > 0, "x.grad is all zeros"
|
||||||
|
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedRMSNormRoPEPartialRotary:
|
||||||
|
"""Partial-rotary: cos/sin last dim is smaller than head_dim.
|
||||||
|
|
||||||
|
Compares against the original primitives (`Gemma4RMSNorm` +
|
||||||
|
`apply_rotary_pos_emb`) applied to the rotated slice with the trailing
|
||||||
|
columns passed through. Without the kernel fix this used to crash with
|
||||||
|
`RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"B,S,H,D,n_rot",
|
||||||
|
[
|
||||||
|
(2, 16, 4, 64, 32), # half rotary (Llama-style 0.5)
|
||||||
|
(2, 16, 4, 64, 16), # quarter rotary
|
||||||
|
(2, 32, 8, 128, 64), # half rotary, larger heads
|
||||||
|
(1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial
|
||||||
|
(1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path
|
||||||
|
],
|
||||||
|
ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
|
||||||
|
)
|
||||||
|
def test_forward_matches_reference(self, B, S, H, D, n_rot):
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
eps = 1e-6
|
||||||
|
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
|
||||||
|
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
|
||||||
|
|
||||||
|
assert y_fused.shape == y_ref.shape == (B, S, H, D)
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
assert cos_sim > 0.999, (
|
||||||
|
f"partial rotary forward cosine_sim={cos_sim:.6f} "
|
||||||
|
f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The pass-through tail must equal the reference RMSNorm output bit-
|
||||||
|
# for-bit (any deviation would mean the kernel is touching it with a
|
||||||
|
# spurious rotation, which is the original bug class).
|
||||||
|
torch.testing.assert_close(
|
||||||
|
y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"B,S,H,D,n_rot",
|
||||||
|
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
|
||||||
|
ids=["half_64", "quarter_256"],
|
||||||
|
)
|
||||||
|
def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
eps = 1e-6
|
||||||
|
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||||
|
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Reference backward via the original primitives
|
||||||
|
x_ref = x_data.clone().requires_grad_(True)
|
||||||
|
w_ref = weight_init.clone()
|
||||||
|
y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps)
|
||||||
|
y_ref.sum().backward()
|
||||||
|
|
||||||
|
# Fused backward
|
||||||
|
x_fused = x_data.clone().requires_grad_(True)
|
||||||
|
w_fused = weight_init.clone().requires_grad_(True)
|
||||||
|
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
|
||||||
|
y_fused.sum().backward()
|
||||||
|
|
||||||
|
cos_sim_x = torch.nn.functional.cosine_similarity(
|
||||||
|
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"B,S,H,D,n_rot",
|
||||||
|
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
|
||||||
|
ids=["half_64", "quarter_256"],
|
||||||
|
)
|
||||||
|
def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
|
||||||
|
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
eps = 1e-6
|
||||||
|
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
||||||
|
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Reference: Gemma4RMSNorm whose .weight collects grads, then partial
|
||||||
|
# rotary applied to the rotated slice.
|
||||||
|
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
|
||||||
|
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
|
||||||
|
normed = norm_ref(x_data)
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
|
||||||
|
y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1)
|
||||||
|
y_ref.sum().backward()
|
||||||
|
|
||||||
|
w_fused = weight_init.clone().requires_grad_(True)
|
||||||
|
fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()
|
||||||
|
|
||||||
|
cos_sim_w = torch.nn.functional.cosine_similarity(
|
||||||
|
w_fused.grad.flatten().float(),
|
||||||
|
norm_ref.weight.grad.flatten().float(),
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
assert cos_sim_w > 0.995, (
|
||||||
|
f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_full_rotary_unchanged_when_n_rot_equals_d(self):
|
||||||
|
"""Regression: passing cos/sin with shape == head_dim must still
|
||||||
|
match the full-rotary reference (the partial-rotary code path must
|
||||||
|
not perturb the existing full-rotary output)."""
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
B, S, H, D = 2, 16, 4, 64
|
||||||
|
eps = 1e-6
|
||||||
|
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
|
||||||
|
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"
|
||||||
|
|
||||||
|
def test_validation_errors(self):
|
||||||
|
"""Wrapper rejects misshaped inputs cleanly (instead of a cryptic
|
||||||
|
Triton crash deeper in the kernel)."""
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
||||||
|
|
||||||
|
B, S, H, D = 1, 4, 2, 64
|
||||||
|
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
w = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# n_rot > head_dim
|
||||||
|
cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
|
||||||
|
with pytest.raises(ValueError, match="cannot exceed head_dim"):
|
||||||
|
fused_rms_norm_rope(x, w, cos_big, sin_big)
|
||||||
|
|
||||||
|
# cos/sin last-dim mismatch
|
||||||
|
cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
|
||||||
|
with pytest.raises(ValueError, match="same last dim"):
|
||||||
|
fused_rms_norm_rope(x, w, cos, sin)
|
||||||
|
|
||||||
|
# odd rotary dim
|
||||||
|
cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
|
||||||
|
sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
|
||||||
|
with pytest.raises(ValueError, match="must be even"):
|
||||||
|
fused_rms_norm_rope(x, w, cos_odd, sin_odd)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedRMSNormNoScale:
|
||||||
|
"""Tests for v_norm (RMSNorm without learnable scale)."""
|
||||||
|
|
||||||
|
def test_forward_matches_reference(self, shapes, dtype):
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
|
||||||
|
|
||||||
|
B, S, H, D = shapes
|
||||||
|
eps = 1e-6
|
||||||
|
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
y_ref = _reference_norm_noscale(x.clone(), eps)
|
||||||
|
y_fused = fused_rms_norm_noscale(x.clone(), eps=eps)
|
||||||
|
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"
|
||||||
|
|
||||||
|
def test_backward_flows(self):
|
||||||
|
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
|
||||||
|
|
||||||
|
x = torch.randn(
|
||||||
|
1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
|
||||||
|
)
|
||||||
|
y = fused_rms_norm_noscale(x, eps=1e-6)
|
||||||
|
y.sum().backward()
|
||||||
|
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.isfinite().all()
|
||||||
|
assert x.grad.abs().sum() > 0
|
||||||
219
tests/monkeypatch/test_gemma4_fused_attn.py
Normal file
219
tests/monkeypatch/test_gemma4_fused_attn.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""Tests for the Gemma 4 fused-attention monkey-patch.
|
||||||
|
|
||||||
|
These tests exercise the patched ``Gemma4TextAttention.forward`` against
|
||||||
|
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
|
||||||
|
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
|
||||||
|
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
|
||||||
|
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
|
||||||
|
exercised end-to-end (this is the bug originally documented in
|
||||||
|
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
||||||
|
|
||||||
|
The full-model forward also pins that the fused forward keeps accepting
|
||||||
|
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
|
||||||
|
installed transformers version — so any future signature drift on
|
||||||
|
upstream's side trips a clear failure here instead of a confusing
|
||||||
|
TypeError deep in a training run.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pytestmark = [
|
||||||
|
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
|
||||||
|
]
|
||||||
|
|
||||||
|
pytest.importorskip(
|
||||||
|
"transformers.models.gemma4",
|
||||||
|
reason="fused_attn patch only matters when Gemma 4 is available",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def restore_gemma4_attention():
|
||||||
|
"""Snapshot ``Gemma4TextAttention.forward`` and restore after the test
|
||||||
|
so the monkey-patch does not leak across the suite."""
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
||||||
|
|
||||||
|
saved = Gemma4TextAttention.forward
|
||||||
|
yield Gemma4TextAttention
|
||||||
|
Gemma4TextAttention.forward = saved
|
||||||
|
|
||||||
|
|
||||||
|
def _build_hybrid_config():
|
||||||
|
"""Tiny hybrid Gemma 4 config: one sliding layer + one full-attention
|
||||||
|
layer with proportional rope and partial_rotary_factor=0.25. This is
|
||||||
|
the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small
|
||||||
|
enough to fit on any GPU."""
|
||||||
|
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
|
||||||
|
|
||||||
|
cfg = Gemma4TextConfig(
|
||||||
|
vocab_size=128,
|
||||||
|
hidden_size=64,
|
||||||
|
intermediate_size=128,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
head_dim=32,
|
||||||
|
global_head_dim=64,
|
||||||
|
layer_types=["sliding_attention", "full_attention"],
|
||||||
|
sliding_window=64,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
hidden_size_per_layer_input=16,
|
||||||
|
vocab_size_per_layer_input=128,
|
||||||
|
rope_parameters={
|
||||||
|
"sliding_attention": {
|
||||||
|
"rope_type": "default",
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
},
|
||||||
|
"full_attention": {
|
||||||
|
"rope_type": "proportional",
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
cfg._attn_implementation = "sdpa"
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def _build_model(seed=0):
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
cfg = _build_hybrid_config()
|
||||||
|
return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedAttnSignature:
|
||||||
|
"""The fused forward must accept the same call shape as
|
||||||
|
``Gemma4TextDecoderLayer`` produces in the installed transformers
|
||||||
|
version. Any signature drift surfaces here as a TypeError."""
|
||||||
|
|
||||||
|
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
|
||||||
|
"""Run a model forward that exercises the real
|
||||||
|
``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
|
||||||
|
the fused patch installed."""
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
|
patch_gemma4_fused_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
m = _build_model()
|
||||||
|
ids = torch.randint(0, 128, (2, 16), device="cuda")
|
||||||
|
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
|
||||||
|
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
with torch.no_grad():
|
||||||
|
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
|
||||||
|
|
||||||
|
assert out.shape == (2, 16, 64)
|
||||||
|
assert torch.isfinite(out).all()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedAttnPerLayerCorrectness:
|
||||||
|
"""Compare the patched attention layer to the stock implementation
|
||||||
|
on a single forward call. This isolates the fused kernel correctness
|
||||||
|
from cross-layer numerical drift."""
|
||||||
|
|
||||||
|
def _run_attention(self, model, layer_idx, hidden_states, position_ids):
|
||||||
|
"""Call ``Gemma4TextAttention.forward`` (whatever is currently
|
||||||
|
installed) for one layer and return the output."""
|
||||||
|
attn = model.layers[layer_idx].self_attn
|
||||||
|
layer_type = model.config.layer_types[layer_idx]
|
||||||
|
cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type)
|
||||||
|
out, _ = attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
position_embeddings=(cos, sin),
|
||||||
|
attention_mask=None,
|
||||||
|
shared_kv_states={},
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"layer_idx",
|
||||||
|
[0, 1],
|
||||||
|
ids=["sliding_head32", "global_head64_proportional"],
|
||||||
|
)
|
||||||
|
def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx):
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
|
patch_gemma4_fused_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
m = _build_model(seed=1)
|
||||||
|
hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
|
||||||
|
pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref = self._run_attention(m, layer_idx, hs, pos)
|
||||||
|
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
with torch.no_grad():
|
||||||
|
got = self._run_attention(m, layer_idx, hs, pos)
|
||||||
|
|
||||||
|
assert got.shape == ref.shape
|
||||||
|
assert torch.isfinite(got).all()
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
ref.flatten().float(), got.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
assert cos_sim > 0.999, (
|
||||||
|
f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}"
|
||||||
|
)
|
||||||
|
# bf16 precision: a few millis of absolute drift per element is
|
||||||
|
# acceptable for a Q/K/V projection pipeline. Anything larger is
|
||||||
|
# a real bug.
|
||||||
|
torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedAttnFullModel:
|
||||||
|
"""End-to-end model forward + backward through both layer types."""
|
||||||
|
|
||||||
|
def test_full_forward_matches_stock(self, restore_gemma4_attention):
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
|
patch_gemma4_fused_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
m = _build_model(seed=2)
|
||||||
|
ids = torch.randint(0, 128, (2, 32), device="cuda")
|
||||||
|
mask = torch.ones(2, 32, dtype=torch.long, device="cuda")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
|
||||||
|
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
with torch.no_grad():
|
||||||
|
got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
|
||||||
|
|
||||||
|
assert got.shape == ref.shape
|
||||||
|
assert torch.isfinite(got).all()
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
ref.flatten().float(), got.flatten().float(), dim=0
|
||||||
|
)
|
||||||
|
# End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16
|
||||||
|
# accumulates a small amount of numerical drift; we just want to
|
||||||
|
# pin that the two paths are computing the same function.
|
||||||
|
assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}"
|
||||||
|
|
||||||
|
def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention):
|
||||||
|
"""Gradients must propagate through the fused RMSNorm+RoPE kernels
|
||||||
|
for both the sliding and proportional-rope layers."""
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
|
patch_gemma4_fused_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
m = _build_model(seed=3).train()
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
|
||||||
|
ids = torch.randint(0, 128, (2, 16), device="cuda")
|
||||||
|
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
|
||||||
|
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
|
||||||
|
out.sum().backward()
|
||||||
|
|
||||||
|
# Both layers must accumulate gradients on q_norm.weight and
|
||||||
|
# k_norm.weight — that proves the fused kernel ran the backward.
|
||||||
|
for i, layer in enumerate(m.layers[:2]):
|
||||||
|
attn = layer.self_attn
|
||||||
|
assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad"
|
||||||
|
assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad"
|
||||||
|
assert attn.q_norm.weight.grad.isfinite().all()
|
||||||
|
assert attn.k_norm.weight.grad.isfinite().all()
|
||||||
|
assert attn.q_norm.weight.grad.abs().sum() > 0
|
||||||
|
assert attn.k_norm.weight.grad.abs().sum() > 0
|
||||||
343
tests/monkeypatch/test_gemma4_hybrid_mask.py
Normal file
343
tests/monkeypatch/test_gemma4_hybrid_mask.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
"""Tests for the Gemma 4 hybrid-attention mask fix.
|
||||||
|
|
||||||
|
These tests pin the single critical behavior: after installing the patch,
|
||||||
|
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
|
||||||
|
the underlying mask builder regardless of what the caller's config says.
|
||||||
|
This is what keeps full-attention (head_dim=512) global layers from
|
||||||
|
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
|
||||||
|
the 2D FA2 mask that would be built from the model-level config.
|
||||||
|
|
||||||
|
The tests use a mocked ``create_causal_mask`` so they don't have to load
|
||||||
|
a real 26B Gemma 4 model or even have access to its weights. What matters
|
||||||
|
for the bug fix is which config is handed to the mask factory, not the
|
||||||
|
factory's actual output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytest.importorskip(
|
||||||
|
"transformers.models.gemma4",
|
||||||
|
reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def restore_gemma4_module():
|
||||||
|
"""Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
|
||||||
|
each test so patch state doesn't leak across the suite."""
|
||||||
|
from transformers.models.gemma4 import modeling_gemma4
|
||||||
|
|
||||||
|
saved = modeling_gemma4.create_causal_mask
|
||||||
|
yield modeling_gemma4
|
||||||
|
modeling_gemma4.create_causal_mask = saved
|
||||||
|
# Reset the module-level flag so the next test can re-install cleanly.
|
||||||
|
from axolotl.monkeypatch import gemma4_hybrid_mask
|
||||||
|
|
||||||
|
gemma4_hybrid_mask._PATCH_APPLIED = False
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_replaces_create_causal_mask(restore_gemma4_module):
|
||||||
|
modeling_gemma4 = restore_gemma4_module
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
original = modeling_gemma4.create_causal_mask
|
||||||
|
assert patch_gemma4_hybrid_mask() is True
|
||||||
|
|
||||||
|
assert modeling_gemma4.create_causal_mask is not original
|
||||||
|
assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
|
||||||
|
"patched wrapper must expose the original reference for teardown"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_is_idempotent(restore_gemma4_module):
|
||||||
|
modeling_gemma4 = restore_gemma4_module
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
wrapper_first = modeling_gemma4.create_causal_mask
|
||||||
|
|
||||||
|
# Second call must not re-wrap the already-wrapped function (which
|
||||||
|
# would leak the original reference through a chain of wrappers).
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
wrapper_second = modeling_gemma4.create_causal_mask
|
||||||
|
|
||||||
|
assert wrapper_first is wrapper_second
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
|
||||||
|
"""Core invariant: when the patched wrapper is called with a config
|
||||||
|
that says ``flash_attention_2``, the underlying mask factory receives
|
||||||
|
a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.
|
||||||
|
|
||||||
|
Without this, the full-attention global layers get a 2D FA2 mask and
|
||||||
|
crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
|
||||||
|
"""
|
||||||
|
modeling_gemma4 = restore_gemma4_module
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
# Swap in a mock BEFORE installing the patch so the wrapper captures
|
||||||
|
# it as the "original". The mock records every call so we can inspect
|
||||||
|
# what config got passed through.
|
||||||
|
mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
|
||||||
|
modeling_gemma4.create_causal_mask = mock_factory
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
|
||||||
|
# Caller-supplied config says FA2 (that's the model-level setting).
|
||||||
|
caller_config = SimpleNamespace(
|
||||||
|
_attn_implementation="flash_attention_2",
|
||||||
|
head_dim=512,
|
||||||
|
some_other_attr="preserved",
|
||||||
|
)
|
||||||
|
result = modeling_gemma4.create_causal_mask(
|
||||||
|
caller_config,
|
||||||
|
inputs_embeds=None,
|
||||||
|
attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
position_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wrapper returned whatever the mock returned — no transformation of
|
||||||
|
# the result itself.
|
||||||
|
assert result == "mask_4d"
|
||||||
|
|
||||||
|
# The mock was called exactly once with a config whose
|
||||||
|
# ``_attn_implementation`` is sdpa, NOT the caller's fa2.
|
||||||
|
assert mock_factory.call_count == 1
|
||||||
|
(passed_config, *_), passed_kwargs = mock_factory.call_args
|
||||||
|
assert passed_config._attn_implementation == "sdpa"
|
||||||
|
|
||||||
|
# The wrapper must NOT mutate the caller's config in place — other
|
||||||
|
# mask builders (e.g. create_sliding_window_causal_mask) read from
|
||||||
|
# the same config and must still see fa2.
|
||||||
|
assert caller_config._attn_implementation == "flash_attention_2"
|
||||||
|
|
||||||
|
# Other attributes on the config must be preserved so the underlying
|
||||||
|
# factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
|
||||||
|
assert passed_config.head_dim == 512
|
||||||
|
assert passed_config.some_other_attr == "preserved"
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
|
||||||
|
"""The wrapper must forward positional + keyword args to the original
|
||||||
|
unchanged, so transformers' own call-site in Gemma4TextModel.forward
|
||||||
|
keeps working across minor transformers-version signature drift."""
|
||||||
|
modeling_gemma4 = restore_gemma4_module
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
mock_factory = MagicMock(return_value="mask")
|
||||||
|
modeling_gemma4.create_causal_mask = mock_factory
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
|
||||||
|
caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
|
||||||
|
modeling_gemma4.create_causal_mask(
|
||||||
|
caller_config,
|
||||||
|
"positional_arg",
|
||||||
|
inputs_embeds="embeds",
|
||||||
|
attention_mask="mask_2d",
|
||||||
|
past_key_values="cache",
|
||||||
|
position_ids="positions",
|
||||||
|
or_mask_function="or_fn",
|
||||||
|
)
|
||||||
|
|
||||||
|
args, kwargs = mock_factory.call_args
|
||||||
|
# First positional (after config override) is preserved.
|
||||||
|
assert args[1] == "positional_arg"
|
||||||
|
# All kwargs are forwarded untouched.
|
||||||
|
assert kwargs["inputs_embeds"] == "embeds"
|
||||||
|
assert kwargs["attention_mask"] == "mask_2d"
|
||||||
|
assert kwargs["past_key_values"] == "cache"
|
||||||
|
assert kwargs["position_ids"] == "positions"
|
||||||
|
assert kwargs["or_mask_function"] == "or_fn"
|
||||||
|
|
||||||
|
|
||||||
|
def test_unpatch_restores_original(restore_gemma4_module):
|
||||||
|
modeling_gemma4 = restore_gemma4_module
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import (
|
||||||
|
patch_gemma4_hybrid_mask,
|
||||||
|
unpatch_gemma4_hybrid_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
sentinel = MagicMock(name="original")
|
||||||
|
modeling_gemma4.create_causal_mask = sentinel
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
assert modeling_gemma4.create_causal_mask is not sentinel
|
||||||
|
|
||||||
|
unpatch_gemma4_hybrid_mask()
|
||||||
|
assert modeling_gemma4.create_causal_mask is sentinel
|
||||||
|
|
||||||
|
|
||||||
|
def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
# Should be a no-op, no exception.
|
||||||
|
unpatch_gemma4_hybrid_mask()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
|
||||||
|
"""Only ``create_causal_mask`` is overridden — the sliding-window
|
||||||
|
factory must remain bound to its original to preserve FA2 masks for
|
||||||
|
the sliding-attention layers. If we accidentally patch both, the
|
||||||
|
sliding layers get SDPA format and lose the FA2 speedup."""
|
||||||
|
modeling_gemma4 = restore_gemma4_module
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"):
|
||||||
|
pytest.skip("transformers version has no create_sliding_window_causal_mask")
|
||||||
|
|
||||||
|
sliding_before = modeling_gemma4.create_sliding_window_causal_mask
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
sliding_after = modeling_gemma4.create_sliding_window_causal_mask
|
||||||
|
assert sliding_after is sliding_before
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
|
||||||
|
#
|
||||||
|
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
|
||||||
|
# model with 2 layers (one sliding, one full_attention), apply the hybrid
|
||||||
|
# attention path end-to-end, and run a forward pass with a padded
|
||||||
|
# attention_mask at a long-ish seq len. The invariant we're pinning is that
|
||||||
|
# the full_attention layer does not crash with the
|
||||||
|
# "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
|
||||||
|
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
|
||||||
|
# tokens in the FSDP2 training run.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tiny_gemma4_text_model():
|
||||||
|
"""Return a tiny randomly-initialized Gemma4TextModel with mixed layers."""
|
||||||
|
import torch
|
||||||
|
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
|
||||||
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
|
||||||
|
|
||||||
|
cfg = Gemma4TextConfig(
|
||||||
|
vocab_size=128,
|
||||||
|
hidden_size=64,
|
||||||
|
intermediate_size=128,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
head_dim=32,
|
||||||
|
layer_types=["sliding_attention", "full_attention"],
|
||||||
|
sliding_window=64,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
hidden_size_per_layer_input=16,
|
||||||
|
vocab_size_per_layer_input=128,
|
||||||
|
)
|
||||||
|
# Caller-supplied attn impl simulates the pilot config (fa2 at model
|
||||||
|
# level). The hybrid patch is what makes this survive long context.
|
||||||
|
cfg._attn_implementation = "sdpa" # start safe; the test toggles fa2 later
|
||||||
|
torch.manual_seed(42)
|
||||||
|
model = Gemma4TextModel(cfg).eval()
|
||||||
|
return model, cfg
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_hybrid_attn_inline(model, cfg):
|
||||||
|
"""Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
|
||||||
|
to a model, without needing a full PatchManager / pydantic cfg."""
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
for layer_idx, layer in enumerate(model.layers):
|
||||||
|
if cfg.layer_types[layer_idx] != "sliding_attention":
|
||||||
|
attn = getattr(layer, "self_attn", None)
|
||||||
|
if attn is not None and hasattr(attn, "config"):
|
||||||
|
sdpa_cfg = copy.copy(attn.config)
|
||||||
|
sdpa_cfg._attn_implementation = "sdpa"
|
||||||
|
attn.config = sdpa_cfg
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
|
||||||
|
|
||||||
|
def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
|
||||||
|
"""End-to-end invariant: with the hybrid attn patch applied, a tiny
|
||||||
|
Gemma4TextModel runs a forward at long context (1024 tokens) with
|
||||||
|
real padding in the attention mask, producing the expected output
|
||||||
|
shape. This exercises the actual code path that crashed the pilot
|
||||||
|
without needing a real 26B checkpoint or CUDA."""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model, cfg = _build_tiny_gemma4_text_model()
|
||||||
|
_apply_hybrid_attn_inline(model, cfg)
|
||||||
|
|
||||||
|
B, S = 2, 1024
|
||||||
|
input_ids = torch.randint(0, cfg.vocab_size, (B, S))
|
||||||
|
attn_mask = torch.ones(B, S, dtype=torch.long)
|
||||||
|
# Pad positions in the second row. Without padding, SDPA falls back to
|
||||||
|
# ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
|
||||||
|
# mask to exercise the actual bug site.
|
||||||
|
attn_mask[1, S // 2 :] = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(input_ids=input_ids, attention_mask=attn_mask)
|
||||||
|
|
||||||
|
assert out.last_hidden_state.shape == (B, S, cfg.hidden_size)
|
||||||
|
assert torch.isfinite(out.last_hidden_state).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_create_causal_mask_returns_4d_for_real_config(
|
||||||
|
restore_gemma4_module,
|
||||||
|
):
|
||||||
|
"""Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
|
||||||
|
and verify the returned mask is a 4D tensor — which is the shape the
|
||||||
|
SDPA-patched global layers need. Without the patch and with a
|
||||||
|
caller-supplied FA2 config this would return a 2D mask and the layer
|
||||||
|
would crash at long context."""
|
||||||
|
import torch
|
||||||
|
from transformers.cache_utils import DynamicCache
|
||||||
|
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||||
|
|
||||||
|
patch_gemma4_hybrid_mask()
|
||||||
|
modeling_gemma4 = restore_gemma4_module
|
||||||
|
|
||||||
|
cfg = Gemma4TextConfig(
|
||||||
|
vocab_size=128,
|
||||||
|
hidden_size=64,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
head_dim=32,
|
||||||
|
layer_types=["sliding_attention", "full_attention"],
|
||||||
|
sliding_window=64,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
hidden_size_per_layer_input=16,
|
||||||
|
vocab_size_per_layer_input=128,
|
||||||
|
)
|
||||||
|
# Simulate the pilot: caller says flash_attention_2, but global layers
|
||||||
|
# were switched to SDPA per-layer. Without the patch, create_causal_mask
|
||||||
|
# would return an FA2 2D mask here and the SDPA layer would crash.
|
||||||
|
cfg._attn_implementation = "flash_attention_2"
|
||||||
|
|
||||||
|
B, S = 2, 1024
|
||||||
|
inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32)
|
||||||
|
attention_mask = torch.ones((B, S), dtype=torch.long)
|
||||||
|
attention_mask[1, S // 2 :] = 0 # force the 4D materialized path
|
||||||
|
position_ids = torch.arange(S).unsqueeze(0).expand(B, -1)
|
||||||
|
past_key_values = DynamicCache(config=cfg)
|
||||||
|
|
||||||
|
mask = modeling_gemma4.create_causal_mask(
|
||||||
|
config=cfg,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mask is not None
|
||||||
|
assert isinstance(mask, torch.Tensor)
|
||||||
|
assert mask.dim() == 4, (
|
||||||
|
f"expected a 4D SDPA-format mask, got {mask.dim()}D "
|
||||||
|
f"shape={tuple(mask.shape)}. The full_attention global layers need "
|
||||||
|
"this shape or they crash at long context."
|
||||||
|
)
|
||||||
|
assert mask.shape[0] == B
|
||||||
|
assert mask.shape[-1] == S
|
||||||
|
assert mask.shape[-2] == S
|
||||||
|
|
||||||
|
# Caller's config must be untouched — other code paths still read it.
|
||||||
|
assert cfg._attn_implementation == "flash_attention_2"
|
||||||
@@ -916,6 +916,235 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.debug(f"Final labels: {labels}")
|
LOG.debug(f"Final labels: {labels}")
|
||||||
LOG.debug(f"Final input_ids: {input_ids}")
|
LOG.debug(f"Final input_ids: {input_ids}")
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_content_parts_training(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
LOG.info("Testing with content as list of parts with per-part training")
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=chat_template_jinja
|
||||||
|
),
|
||||||
|
message_property_mappings={"role": "role", "content": "content"},
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset where assistant content is a list of parts with per-part training
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "You are an AI assistant."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is 2+2?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Let me think...", "train": False},
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict({"messages": [conversation]})
|
||||||
|
res = strategy.tokenize_prompt(dataset[0])
|
||||||
|
turns = strategy.get_conversation_thread(dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Find the assistant turn (last turn)
|
||||||
|
assistant_turn_idx = len(turns) - 1
|
||||||
|
start_idx, end_idx = strategy.find_turn(
|
||||||
|
turns=turns, turn_idx=assistant_turn_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
assert start_idx != -1 and end_idx != -1, (
|
||||||
|
"Could not find assistant turn boundaries"
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded = tokenizer.decode(input_ids[start_idx:end_idx])
|
||||||
|
LOG.debug(f"Assistant turn decoded: {decoded}")
|
||||||
|
|
||||||
|
# Tokenize each part separately to find their boundaries
|
||||||
|
part1_text = "Let me think..."
|
||||||
|
part2_text = "The answer is 4."
|
||||||
|
|
||||||
|
# Verify the concatenated content is in the decoded output
|
||||||
|
assert part1_text in decoded, (
|
||||||
|
f"Part 1 '{part1_text}' not found in decoded: {decoded}"
|
||||||
|
)
|
||||||
|
assert part2_text in decoded, (
|
||||||
|
f"Part 2 '{part2_text}' not found in decoded: {decoded}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that part1 tokens (train=False) are masked
|
||||||
|
# and part2 tokens (train=True) are labeled
|
||||||
|
turn_labels = labels[start_idx:end_idx]
|
||||||
|
|
||||||
|
# Find where part2 starts in the token sequence
|
||||||
|
part1_tokens = tokenizer(part1_text, add_special_tokens=False)["input_ids"]
|
||||||
|
part2_tokens = tokenizer(part2_text, add_special_tokens=False)["input_ids"]
|
||||||
|
|
||||||
|
# The first part should be masked (all IGNORE_TOKEN_ID)
|
||||||
|
# Due to token boundary alignment, check that at least the interior tokens
|
||||||
|
# of part1 are masked
|
||||||
|
assert any(label == IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected some masked labels for train=False part, but got {turn_labels}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The second part should be trained (not IGNORE_TOKEN_ID)
|
||||||
|
assert any(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected some trained labels for train=True part, but got {turn_labels}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# More precise check: first N tokens should be masked, last M tokens should be trained
|
||||||
|
# where N ~ len(part1_tokens) and M ~ len(part2_tokens)
|
||||||
|
# Allow for token boundary effects at the boundary
|
||||||
|
num_masked = sum(1 for label in turn_labels if label == IGNORE_TOKEN_ID)
|
||||||
|
num_trained = sum(1 for label in turn_labels if label != IGNORE_TOKEN_ID)
|
||||||
|
|
||||||
|
LOG.debug(f"Turn labels: {turn_labels}")
|
||||||
|
LOG.debug(f"Masked tokens: {num_masked}, Trained tokens: {num_trained}")
|
||||||
|
LOG.debug(
|
||||||
|
f"Part1 tokens: {len(part1_tokens)}, Part2 tokens: {len(part2_tokens)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The number of masked tokens should be roughly the size of part1
|
||||||
|
# and the number of trained tokens should be roughly the size of part2
|
||||||
|
assert num_masked > 0, "Expected masked tokens for the train=False part"
|
||||||
|
assert num_trained > 0, "Expected trained tokens for the train=True part"
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_content_parts_with_weight(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
LOG.info("Testing with content parts using weight field")
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=chat_template_jinja
|
||||||
|
),
|
||||||
|
message_property_mappings={"role": "role", "content": "content"},
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset using weight instead of train
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Thinking step by step: ", "weight": 0},
|
||||||
|
{"type": "text", "text": "Hello! How can I help?", "weight": 1},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict({"messages": [conversation]})
|
||||||
|
res = strategy.tokenize_prompt(dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
|
||||||
|
# There should be both masked and trained labels
|
||||||
|
has_masked = any(label == IGNORE_TOKEN_ID for label in labels)
|
||||||
|
has_trained = any(label != IGNORE_TOKEN_ID for label in labels)
|
||||||
|
assert has_masked, "Expected masked tokens (weight=0 part + user turn)"
|
||||||
|
assert has_trained, "Expected trained tokens (weight=1 part)"
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_content_parts_string_passthrough(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
LOG.info("Testing that string content still works alongside list content")
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=chat_template_jinja
|
||||||
|
),
|
||||||
|
message_property_mappings={"role": "role", "content": "content"},
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# All list content in the conversation
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is 2+2?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict({"messages": [conversation]})
|
||||||
|
res = strategy.tokenize_prompt(dataset[0])
|
||||||
|
|
||||||
|
# Should tokenize without errors
|
||||||
|
assert "input_ids" in res
|
||||||
|
assert "labels" in res
|
||||||
|
assert len(res["input_ids"]) > 0
|
||||||
|
|
||||||
def test_get_chat_template_variables(
|
def test_get_chat_template_variables(
|
||||||
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
|
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
):
|
):
|
||||||
@@ -1428,3 +1657,250 @@ class TestChatTemplateToolCalling:
|
|||||||
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
f"Assistant turn {i} should be unmasked"
|
f"Assistant turn {i} should be unmasked"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatTemplateReasoningContent:
|
||||||
|
"""
|
||||||
|
Test class for reasoning_content with content parts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_reasoning_content_with_content_parts(self, qwen3_tokenizer):
|
||||||
|
"""Test that reasoning_content as string + content as list parts works correctly.
|
||||||
|
Content training_detail offsets should align with content-only boundaries."""
|
||||||
|
LOG.info("Testing reasoning_content with content parts on qwen3")
|
||||||
|
|
||||||
|
tokenizer = deepcopy(qwen3_tokenizer)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template("qwen3"),
|
||||||
|
message_property_mappings={
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
"reasoning_content": "reasoning_content",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# reasoning_content is a plain string, content is list with per-part training
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "What is 2+2?"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_content": "Step 1: 2+2=4",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict({"messages": [conversation]})
|
||||||
|
res = strategy.tokenize_prompt(dataset[0])
|
||||||
|
turns = strategy.get_conversation_thread(dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Find the assistant turn
|
||||||
|
assistant_idx = 1
|
||||||
|
start_idx, end_idx = strategy.find_turn(
|
||||||
|
turns=turns, turn_idx=assistant_idx, content_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert start_idx != -1 and end_idx != -1, (
|
||||||
|
"Could not find assistant content boundaries"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The content-only span should contain "The answer is 4." but NOT "Step 1: 2+2=4"
|
||||||
|
decoded_span = tokenizer.decode(input_ids[start_idx:end_idx])
|
||||||
|
assert "The answer is 4." in decoded_span, (
|
||||||
|
f"Content not found in span: {decoded_span}"
|
||||||
|
)
|
||||||
|
assert "Step 1" not in decoded_span, (
|
||||||
|
f"Reasoning should not be in content-only span: {decoded_span}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that content tokens are trained
|
||||||
|
content_labels = labels[start_idx:end_idx]
|
||||||
|
assert any(label != IGNORE_TOKEN_ID for label in content_labels), (
|
||||||
|
f"Expected trained labels in content span, got {content_labels}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_reasoning_content_per_part_masking(self, qwen3_tokenizer):
|
||||||
|
"""Test masking incorrect reasoning while training on self-correction.
|
||||||
|
This is the core use case: mask out wrong thoughts, train on corrections."""
|
||||||
|
LOG.info("Testing reasoning_content per-part masking on qwen3")
|
||||||
|
|
||||||
|
tokenizer = deepcopy(qwen3_tokenizer)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template("qwen3"),
|
||||||
|
message_property_mappings={
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
"reasoning_content": "reasoning_content",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reasoning has wrong step (masked) then self-correction (trained)
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "What is 2+2?"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_content": [
|
||||||
|
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": False},
|
||||||
|
{"type": "text", "text": " Wait no, 2+2=4.", "train": True},
|
||||||
|
],
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict({"messages": [conversation]})
|
||||||
|
res = strategy.tokenize_prompt(dataset[0])
|
||||||
|
turns = strategy.get_conversation_thread(dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Find reasoning boundaries
|
||||||
|
reasoning_start, reasoning_end = strategy.find_turn(
|
||||||
|
turns=turns, turn_idx=1, reasoning_only=True
|
||||||
|
)
|
||||||
|
assert reasoning_start != -1 and reasoning_end != -1, (
|
||||||
|
"Could not find reasoning boundaries"
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded_reasoning = tokenizer.decode(input_ids[reasoning_start:reasoning_end])
|
||||||
|
LOG.debug(f"Reasoning span: {decoded_reasoning!r}")
|
||||||
|
assert "2+2=5" in decoded_reasoning, (
|
||||||
|
f"Wrong step not in reasoning span: {decoded_reasoning}"
|
||||||
|
)
|
||||||
|
assert "2+2=4" in decoded_reasoning, (
|
||||||
|
f"Correction not in reasoning span: {decoded_reasoning}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify reasoning labels have both masked and trained tokens
|
||||||
|
reasoning_labels = labels[reasoning_start:reasoning_end]
|
||||||
|
reasoning_ids = input_ids[reasoning_start:reasoning_end]
|
||||||
|
|
||||||
|
# Decode only the trained tokens — should be exactly the self-correction
|
||||||
|
trained_ids = [
|
||||||
|
tid
|
||||||
|
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
|
||||||
|
if lab != IGNORE_TOKEN_ID
|
||||||
|
]
|
||||||
|
trained_text = tokenizer.decode(trained_ids)
|
||||||
|
assert trained_text.strip() == "Wait no, 2+2=4.", (
|
||||||
|
f"Expected trained reasoning to be 'Wait no, 2+2=4.', got: {trained_text!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode only the masked tokens — should be exactly the incorrect step
|
||||||
|
masked_ids = [
|
||||||
|
tid
|
||||||
|
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
|
||||||
|
if lab == IGNORE_TOKEN_ID
|
||||||
|
]
|
||||||
|
masked_text = tokenizer.decode(masked_ids)
|
||||||
|
assert masked_text.strip() == "Hmm maybe 2+2=5.", (
|
||||||
|
f"Expected masked reasoning to be 'Hmm maybe 2+2=5.', got: {masked_text!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find content boundaries
|
||||||
|
content_start, content_end = strategy.find_turn(
|
||||||
|
turns=turns, turn_idx=1, content_only=True
|
||||||
|
)
|
||||||
|
assert content_start != -1 and content_end != -1, (
|
||||||
|
"Could not find content boundaries"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Content should be fully trained — decode trained tokens to verify
|
||||||
|
content_labels = labels[content_start:content_end]
|
||||||
|
content_ids = input_ids[content_start:content_end]
|
||||||
|
content_trained_ids = [
|
||||||
|
tid
|
||||||
|
for tid, lab in zip(content_ids, content_labels, strict=True)
|
||||||
|
if lab != IGNORE_TOKEN_ID
|
||||||
|
]
|
||||||
|
content_trained_text = tokenizer.decode(content_trained_ids)
|
||||||
|
assert "The answer is 4." in content_trained_text, (
|
||||||
|
f"Expected 'The answer is 4.' in trained content tokens, "
|
||||||
|
f"got: {content_trained_text!r}"
|
||||||
|
)
|
||||||
|
assert all(label != IGNORE_TOKEN_ID for label in content_labels), (
|
||||||
|
f"Expected all content labels trained, got {content_labels}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_reasoning_content_as_list_no_training_flags(self, qwen3_tokenizer):
|
||||||
|
"""Test that reasoning_content as list without training flags still works."""
|
||||||
|
LOG.info("Testing reasoning_content as list without training flags on qwen3")
|
||||||
|
|
||||||
|
tokenizer = deepcopy(qwen3_tokenizer)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template("qwen3"),
|
||||||
|
message_property_mappings={
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
"reasoning_content": "reasoning_content",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both as lists, no per-part training flags
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "What is 2+2?"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_content": [
|
||||||
|
{"type": "text", "text": "Step 1: addition."},
|
||||||
|
{"type": "text", "text": " Step 2: 2+2=4."},
|
||||||
|
],
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The answer is 4."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict({"messages": [conversation]})
|
||||||
|
res = strategy.tokenize_prompt(dataset[0])
|
||||||
|
|
||||||
|
# Should tokenize without errors
|
||||||
|
assert "input_ids" in res
|
||||||
|
assert "labels" in res
|
||||||
|
assert len(res["input_ids"]) > 0
|
||||||
|
|
||||||
|
# Verify the full output contains both reasoning and content
|
||||||
|
full_text = tokenizer.decode(res["input_ids"])
|
||||||
|
assert "Step 1: addition." in full_text
|
||||||
|
assert "Step 2: 2+2=4." in full_text
|
||||||
|
assert "The answer is 4." in full_text
|
||||||
|
|||||||
@@ -65,47 +65,57 @@ def test_singleton_instance(telemetry_manager_class):
|
|||||||
assert telemetry_manager_class.get_instance() is first
|
assert telemetry_manager_class.get_instance() is first
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_enabled_by_default(telemetry_manager_class):
|
class TestTelemetryOptOut:
|
||||||
"""Test that telemetry is enabled by default (opt-out)"""
|
"""
|
||||||
with (
|
Telemetry is opt-out: enabled by default, disabled by AXOLOTL_DO_NOT_TRACK
|
||||||
patch.dict(os.environ, {"RANK": "0"}, clear=True),
|
or DO_NOT_TRACK. Each env var is checked independently — setting either one
|
||||||
patch("time.sleep"),
|
to a truthy value ("1" or "true") disables telemetry.
|
||||||
patch("logging.Logger.info"),
|
|
||||||
|
The parametrized table below is the source of truth for expected behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# AXOLOTL_DO_NOT_TRACK DO_NOT_TRACK expected
|
||||||
|
@pytest.mark.parametrize("axolotl_dnt, dnt, expected", [
|
||||||
|
# --- Neither var set: telemetry ON ---
|
||||||
|
(None, None, True),
|
||||||
|
|
||||||
|
# --- Only AXOLOTL_DO_NOT_TRACK set ---
|
||||||
|
("0", None, True), # explicit opt-in
|
||||||
|
("false", None, True), # explicit opt-in
|
||||||
|
("1", None, False), # opt-out
|
||||||
|
("true", None, False), # opt-out
|
||||||
|
(" 1 ", None, False), # whitespace-padded opt-out
|
||||||
|
|
||||||
|
# --- Only DO_NOT_TRACK set (was broken before fix) ---
|
||||||
|
(None, "0", True), # explicit opt-in
|
||||||
|
(None, "false", True), # explicit opt-in
|
||||||
|
(None, "1", False), # opt-out
|
||||||
|
(None, "true", False), # opt-out
|
||||||
|
|
||||||
|
# --- Both set: either truthy → disabled ---
|
||||||
|
("0", "1", False), # DO_NOT_TRACK wins
|
||||||
|
("1", "0", False), # AXOLOTL_DO_NOT_TRACK wins
|
||||||
|
("1", "1", False), # both opt-out
|
||||||
|
("0", "0", True), # both opt-in
|
||||||
|
])
|
||||||
|
# fmt: on
|
||||||
|
def test_do_not_track_env_vars(
|
||||||
|
self, telemetry_manager_class, axolotl_dnt, dnt, expected
|
||||||
):
|
):
|
||||||
manager = telemetry_manager_class()
|
env = {"RANK": "0"}
|
||||||
assert manager.enabled
|
if axolotl_dnt is not None:
|
||||||
|
env["AXOLOTL_DO_NOT_TRACK"] = axolotl_dnt
|
||||||
|
if dnt is not None:
|
||||||
|
env["DO_NOT_TRACK"] = dnt
|
||||||
|
|
||||||
|
with (
|
||||||
def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
|
patch.dict(os.environ, env, clear=True),
|
||||||
"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
|
patch("time.sleep"),
|
||||||
with (
|
patch("logging.Logger.info"),
|
||||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
|
):
|
||||||
patch("time.sleep"),
|
manager = telemetry_manager_class()
|
||||||
):
|
assert manager.enabled is expected
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert manager.enabled
|
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
|
||||||
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
|
||||||
with (
|
|
||||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
|
|
||||||
patch("time.sleep"),
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert not manager.enabled
|
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
|
||||||
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
|
||||||
with (
|
|
||||||
patch.dict(
|
|
||||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
|
|
||||||
),
|
|
||||||
patch("time.sleep"),
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert not manager.enabled
|
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
||||||
|
|||||||
63
tests/utils/callbacks/test_skip_eval_on_resume.py
Normal file
63
tests/utils/callbacks/test_skip_eval_on_resume.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Tests for SkipEvalOnResumeCallback."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||||
|
|
||||||
|
from axolotl.utils.callbacks import SkipEvalOnResumeCallback
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkipEvalOnResumeCallback:
|
||||||
|
"""Tests for skipping redundant evaluation on checkpoint resume."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_state(global_step: int) -> TrainerState:
|
||||||
|
state = MagicMock(spec=TrainerState)
|
||||||
|
state.global_step = global_step
|
||||||
|
return state
|
||||||
|
|
||||||
|
def test_suppresses_eval_at_resume_step(self):
|
||||||
|
cb = SkipEvalOnResumeCallback()
|
||||||
|
args = MagicMock(spec=TrainingArguments)
|
||||||
|
state = self._make_state(20)
|
||||||
|
control = TrainerControl(should_evaluate=False)
|
||||||
|
|
||||||
|
# Simulate on_train_begin at checkpoint-20
|
||||||
|
cb.on_train_begin(args, state, control)
|
||||||
|
|
||||||
|
# Trainer sets should_evaluate = True for step 20
|
||||||
|
control.should_evaluate = True
|
||||||
|
result = cb.on_step_end(args, state, control)
|
||||||
|
|
||||||
|
assert result.should_evaluate is False
|
||||||
|
|
||||||
|
def test_allows_eval_after_resume_step(self):
|
||||||
|
cb = SkipEvalOnResumeCallback()
|
||||||
|
args = MagicMock(spec=TrainingArguments)
|
||||||
|
state = self._make_state(20)
|
||||||
|
control = TrainerControl(should_evaluate=False)
|
||||||
|
|
||||||
|
cb.on_train_begin(args, state, control)
|
||||||
|
|
||||||
|
# Advance past the resume point
|
||||||
|
state.global_step = 30
|
||||||
|
control.should_evaluate = True
|
||||||
|
result = cb.on_step_end(args, state, control)
|
||||||
|
|
||||||
|
assert result.should_evaluate is True
|
||||||
|
|
||||||
|
def test_noop_on_fresh_run(self):
|
||||||
|
cb = SkipEvalOnResumeCallback()
|
||||||
|
args = MagicMock(spec=TrainingArguments)
|
||||||
|
state = self._make_state(0)
|
||||||
|
control = TrainerControl(should_evaluate=False)
|
||||||
|
|
||||||
|
# Fresh run: global_step starts at 0
|
||||||
|
cb.on_train_begin(args, state, control)
|
||||||
|
|
||||||
|
# Even if eval triggers at step 0 (unlikely but defensive)
|
||||||
|
state.global_step = 10
|
||||||
|
control.should_evaluate = True
|
||||||
|
result = cb.on_step_end(args, state, control)
|
||||||
|
|
||||||
|
assert result.should_evaluate is True
|
||||||
292
tests/utils/data/test_rl.py
Normal file
292
tests/utils/data/test_rl.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for RL data utility functions (excess_length_strategy support).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.utils.data.rl import (
|
||||||
|
_drop_long_sequences,
|
||||||
|
_raise_on_long_sequences,
|
||||||
|
_truncate_long_sequences_rl,
|
||||||
|
)
|
||||||
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTokenizer:
|
||||||
|
"""Simple whitespace tokenizer for testing length calculations."""
|
||||||
|
|
||||||
|
def __call__(self, text, add_special_tokens=True): # noqa: ARG002
|
||||||
|
tokens = text.split()
|
||||||
|
return {"input_ids": list(range(len(tokens)))}
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=True): # noqa: ARG002
|
||||||
|
# Each token id maps to a placeholder word; length is what matters.
|
||||||
|
return " ".join(f"w{i}" for i in range(len(token_ids)))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dpo_sample(prompt_len: int, chosen_len: int, rejected_len: int):
|
||||||
|
"""Create a DPO sample with specified word counts."""
|
||||||
|
return {
|
||||||
|
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
|
||||||
|
"chosen": " ".join(f"c{i}" for i in range(chosen_len)),
|
||||||
|
"rejected": " ".join(f"r{i}" for i in range(rejected_len)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_kto_sample(prompt_len: int, completion_len: int):
|
||||||
|
"""Create a KTO sample with specified word counts."""
|
||||||
|
return {
|
||||||
|
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
|
||||||
|
"completion": " ".join(f"c{i}" for i in range(completion_len)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDropLongSequences(unittest.TestCase):
|
||||||
|
"""Tests for the existing _drop_long_sequences filter function."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = _FakeTokenizer()
|
||||||
|
|
||||||
|
def test_dpo_keeps_short_samples(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_dpo_drops_long_chosen(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_dpo_drops_long_rejected(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=2, rejected_len=10)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_kto_keeps_short_samples(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=3, completion_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_kto_drops_long_completion(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=5, completion_len=10)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_grpo_always_keeps(self):
|
||||||
|
sample = {"prompt": "a " * 100}
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.GRPO, self.tokenizer, sequence_len=5
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_dpo_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_drop_long_sequences({"prompt": "hi"}, RLType.DPO, self.tokenizer, 10)
|
||||||
|
|
||||||
|
def test_kto_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_drop_long_sequences({"prompt": "hi"}, RLType.KTO, self.tokenizer, 10)
|
||||||
|
|
||||||
|
def test_ipo_uses_dpo_logic(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.IPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_orpo_uses_dpo_logic(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.ORPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_boundary_length_kept(self):
|
||||||
|
"""Samples exactly at sequence_len should be kept."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRaiseOnLongSequences(unittest.TestCase):
|
||||||
|
"""Tests for _raise_on_long_sequences (excess_length_strategy='raise')."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = _FakeTokenizer()
|
||||||
|
|
||||||
|
def test_short_sample_passes(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _raise_on_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_long_sample_raises_valueerror(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
|
||||||
|
with self.assertRaises(ValueError, msg="excess_length_strategy"):
|
||||||
|
_raise_on_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_kto_long_raises(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=5, completion_len=10)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_raise_on_long_sequences(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_grpo_never_raises(self):
|
||||||
|
sample = {"prompt": "a " * 100}
|
||||||
|
result = _raise_on_long_sequences(
|
||||||
|
sample, RLType.GRPO, self.tokenizer, sequence_len=5
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncateLongSequencesRL(unittest.TestCase):
|
||||||
|
"""Tests for _truncate_long_sequences_rl (excess_length_strategy='truncate')."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = _FakeTokenizer()
|
||||||
|
|
||||||
|
def test_dpo_short_sample_unchanged(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["chosen"], sample["chosen"])
|
||||||
|
self.assertEqual(result["rejected"], sample["rejected"])
|
||||||
|
|
||||||
|
def test_dpo_truncates_chosen(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
# max_response_len = 10 - 5 = 5, chosen had 10 words -> truncated to 5
|
||||||
|
chosen_tokens = self.tokenizer(result["chosen"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
self.assertEqual(len(chosen_tokens), 5)
|
||||||
|
|
||||||
|
def test_dpo_truncates_rejected(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=3, rejected_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
rejected_tokens = self.tokenizer(result["rejected"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
self.assertEqual(len(rejected_tokens), 5)
|
||||||
|
|
||||||
|
def test_dpo_truncates_both(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
chosen_len = len(
|
||||||
|
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
rejected_len = len(
|
||||||
|
self.tokenizer(result["rejected"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
self.assertEqual(chosen_len, 5)
|
||||||
|
self.assertEqual(rejected_len, 5)
|
||||||
|
|
||||||
|
def test_dpo_prompt_unchanged(self):
|
||||||
|
"""Prompt text should never be modified."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["prompt"], sample["prompt"])
|
||||||
|
|
||||||
|
def test_dpo_prompt_exceeds_limit_returns_unchanged(self):
|
||||||
|
"""When prompt alone exceeds sequence_len, sample is returned as-is."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=15, chosen_len=3, rejected_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result, sample)
|
||||||
|
|
||||||
|
def test_kto_truncates_completion(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=5, completion_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
completion_len = len(
|
||||||
|
self.tokenizer(result["completion"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
self.assertEqual(completion_len, 5)
|
||||||
|
|
||||||
|
def test_kto_short_sample_unchanged(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=3, completion_len=2)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["completion"], sample["completion"])
|
||||||
|
|
||||||
|
def test_kto_prompt_exceeds_limit_returns_unchanged(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=15, completion_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result, sample)
|
||||||
|
|
||||||
|
def test_grpo_unchanged(self):
|
||||||
|
sample = {"prompt": "a " * 100}
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.GRPO, self.tokenizer, sequence_len=5
|
||||||
|
)
|
||||||
|
self.assertEqual(result, sample)
|
||||||
|
|
||||||
|
def test_ipo_uses_dpo_logic(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.IPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
chosen_len = len(
|
||||||
|
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
self.assertEqual(chosen_len, 5)
|
||||||
|
|
||||||
|
def test_does_not_mutate_original(self):
|
||||||
|
"""Verify immutability — original sample dict is not modified."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
|
||||||
|
original_chosen = sample["chosen"]
|
||||||
|
original_rejected = sample["rejected"]
|
||||||
|
_truncate_long_sequences_rl(sample, RLType.DPO, self.tokenizer, sequence_len=10)
|
||||||
|
self.assertEqual(sample["chosen"], original_chosen)
|
||||||
|
self.assertEqual(sample["rejected"], original_rejected)
|
||||||
|
|
||||||
|
def test_dpo_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_truncate_long_sequences_rl(
|
||||||
|
{"prompt": "hi"}, RLType.DPO, self.tokenizer, 10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_kto_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_truncate_long_sequences_rl(
|
||||||
|
{"prompt": "hi"}, RLType.KTO, self.tokenizer, 10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_boundary_no_truncation_needed(self):
|
||||||
|
"""Samples exactly at sequence_len should not be modified."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["chosen"], sample["chosen"])
|
||||||
|
self.assertEqual(result["rejected"], sample["rejected"])
|
||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
import math
|
import math
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -490,7 +491,8 @@ class TestEfficientMerge:
|
|||||||
out_features = 4
|
out_features = 4
|
||||||
alpha = 4
|
alpha = 4
|
||||||
|
|
||||||
base = torch.randn(num_experts, in_features, out_features)
|
# PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in)
|
||||||
|
base = torch.randn(num_experts, out_features, in_features)
|
||||||
lora_a = torch.randn(r * num_experts, in_features)
|
lora_a = torch.randn(r * num_experts, in_features)
|
||||||
lora_b = torch.randn(out_features, r * num_experts)
|
lora_b = torch.randn(out_features, r * num_experts)
|
||||||
|
|
||||||
@@ -506,7 +508,7 @@ class TestEfficientMerge:
|
|||||||
scale = alpha / r
|
scale = alpha / r
|
||||||
wa = lora_a.reshape(num_experts, r, in_features)
|
wa = lora_a.reshape(num_experts, r, in_features)
|
||||||
wb = lora_b.reshape(out_features, r, num_experts)
|
wb = lora_b.reshape(out_features, r, num_experts)
|
||||||
manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
|
manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
|
||||||
for e in range(num_experts):
|
for e in range(num_experts):
|
||||||
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
|
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
|
||||||
f"Expert {e} mismatch"
|
f"Expert {e} mismatch"
|
||||||
@@ -773,8 +775,8 @@ class TestEfficientMerge:
|
|||||||
"v_proj should be unchanged (no LoRA weights for it)"
|
"v_proj should be unchanged (no LoRA weights for it)"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dora_missing_magnitude_falls_back(self):
|
def test_dora_missing_magnitude_raises(self):
|
||||||
"""DoRA without magnitude vector falls back to standard LoRA merge."""
|
"""DoRA with missing magnitude vector raises an explicit error."""
|
||||||
hidden = 16
|
hidden = 16
|
||||||
r = 4
|
r = 4
|
||||||
alpha = 8
|
alpha = 8
|
||||||
@@ -791,11 +793,13 @@ class TestEfficientMerge:
|
|||||||
}
|
}
|
||||||
|
|
||||||
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
|
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
|
||||||
merged, was_merged = _merge_tensor_with_lora(
|
with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
|
||||||
base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
|
_merge_tensor_with_lora(
|
||||||
)
|
base,
|
||||||
assert was_merged
|
"layer.proj.weight",
|
||||||
# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
|
lora_state,
|
||||||
# which produces a result different from plain W + scale * B @ A.
|
scale,
|
||||||
# Just verify it was merged (not unchanged).
|
config,
|
||||||
assert not torch.equal(merged, base)
|
"cpu",
|
||||||
|
use_dora=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ Covers:
|
|||||||
- save_strategy: 'best' requires metric_for_best_model
|
- save_strategy: 'best' requires metric_for_best_model
|
||||||
- streaming=True with val_set_size > 0 is rejected
|
- streaming=True with val_set_size > 0 is rejected
|
||||||
- lora_target_modules with invalid regex patterns is rejected
|
- lora_target_modules with invalid regex patterns is rejected
|
||||||
|
- GRPO: generation batch size must be divisible by num_generations,
|
||||||
|
num_generations >= 2, and effective_gbs >= num_generations * world_size
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator:
|
|||||||
)
|
)
|
||||||
with pytest.raises(ValueError, match="invalid regex pattern"):
|
with pytest.raises(ValueError, match="invalid regex pattern"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGRPOBatchSizeValidator:
|
||||||
|
"""GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.
|
||||||
|
|
||||||
|
These call the @model_validator(mode="before") classmethod directly on a
|
||||||
|
plain dict — same input shape it receives during full Pydantic validation,
|
||||||
|
just without dragging in unrelated fields (datasets / model loading / etc.)
|
||||||
|
that aren't relevant to what's under test. The validator is registered on
|
||||||
|
``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
|
||||||
|
same code path ``axolotl train`` exercises.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check(data):
|
||||||
|
from axolotl.utils.schemas.validation import RLValidationMixin
|
||||||
|
|
||||||
|
return RLValidationMixin.check_grpo_batch_size_divisibility(data)
|
||||||
|
|
||||||
|
def test_divisible_passes(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"trl": {"num_generations": 4},
|
||||||
|
}
|
||||||
|
# Should return data unchanged (no exception)
|
||||||
|
out = self._check(data)
|
||||||
|
assert out["trl"]["num_generations"] == 4
|
||||||
|
|
||||||
|
def test_non_divisible_raises(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"trl": {"num_generations": 4},
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError, match="num_generations"):
|
||||||
|
self._check(data)
|
||||||
|
|
||||||
|
def test_non_divisible_error_includes_fix_hint(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 3,
|
||||||
|
"trl": {"num_generations": 4},
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
|
||||||
|
self._check(data)
|
||||||
|
|
||||||
|
def test_num_generations_one_raises(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"trl": {"num_generations": 1},
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError, match=r"num_generations >= 2"):
|
||||||
|
self._check(data)
|
||||||
|
|
||||||
|
def test_explicit_generation_batch_size_divisible_passes(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"trl": {"num_generations": 4, "generation_batch_size": 8},
|
||||||
|
}
|
||||||
|
out = self._check(data)
|
||||||
|
assert out["trl"]["generation_batch_size"] == 8
|
||||||
|
|
||||||
|
def test_explicit_generation_batch_size_non_divisible_raises(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"trl": {"num_generations": 4, "generation_batch_size": 6},
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError, match="trl.generation_batch_size"):
|
||||||
|
self._check(data)
|
||||||
|
|
||||||
|
def test_non_grpo_skips_check(self):
|
||||||
|
# Anything other than rl=grpo should pass through untouched, even
|
||||||
|
# with non-divisible batch sizes — they're irrelevant to other RL
|
||||||
|
# methods that don't use group-relative advantages.
|
||||||
|
data = {
|
||||||
|
"rl": "dpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 3,
|
||||||
|
"trl": {"num_generations": 4},
|
||||||
|
}
|
||||||
|
assert self._check(data) is data
|
||||||
|
|
||||||
|
def test_no_rl_set_skips_check(self):
|
||||||
|
data = {
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 3,
|
||||||
|
}
|
||||||
|
assert self._check(data) is data
|
||||||
|
|
||||||
|
def test_grpo_without_num_generations_skips_check(self):
|
||||||
|
# If num_generations isn't set, TRL uses its own default — we don't
|
||||||
|
# have enough info to validate, so the validator must short-circuit
|
||||||
|
# rather than guess.
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 3,
|
||||||
|
"trl": {},
|
||||||
|
}
|
||||||
|
out = self._check(data)
|
||||||
|
assert out["rl"] == "grpo"
|
||||||
|
|
||||||
|
def test_multi_rank_group_size_check(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4, # gbs=4
|
||||||
|
"world_size": 2, # need gbs >= 4*2 = 8
|
||||||
|
"trl": {"num_generations": 4},
|
||||||
|
}
|
||||||
|
with pytest.raises(ValueError, match=r"world_size=2"):
|
||||||
|
self._check(data)
|
||||||
|
|
||||||
|
def test_multi_rank_group_size_satisfied(self):
|
||||||
|
data = {
|
||||||
|
"rl": "grpo",
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 8, # gbs=8 >= 4*2
|
||||||
|
"world_size": 2,
|
||||||
|
"trl": {"num_generations": 4},
|
||||||
|
}
|
||||||
|
out = self._check(data)
|
||||||
|
assert out["gradient_accumulation_steps"] == 8
|
||||||
|
|||||||
Reference in New Issue
Block a user