Compare commits

**weight-sca...scattermoe** (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | `936149380f` | |

**.github/workflows/base.yml** (vendored, 16 changes)

```diff
@@ -30,6 +30,14 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.8.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
+            platforms: "linux/amd64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -152,6 +160,14 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.8.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-uv-base"
+            platforms: "linux/amd64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
```

**.github/workflows/main.yml** (vendored, 12 changes)

```diff
@@ -18,6 +18,12 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.8.0
+            axolotl_extras:
+            platforms: "linux/amd64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
@@ -180,6 +186,12 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.8.0
+            axolotl_extras:
+            platforms: "linux/amd64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
```

**.github/workflows/multi-gpu-e2e.yml** (vendored, 6 changes)

```diff
@@ -33,6 +33,12 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.8.0
+            axolotl_extras: fbgemm-gpu
+            num_gpus: 2
           # - cuda: 129
           #   cuda_version: 12.9.1
           #   python_version: "3.12"
```

**.github/workflows/nightlies.yml** (vendored, 10 changes)

```diff
@@ -15,6 +15,11 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.8.0
+            axolotl_extras:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
@@ -62,6 +67,11 @@ jobs:
     strategy:
       matrix:
         include:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.8.0
+            axolotl_extras:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
```

**.github/workflows/tests-nightly.yml** (vendored, 2 changes)

```diff
@@ -44,7 +44,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
-        pytorch_version: ["2.9.1", "2.10.0"]
+        pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
     timeout-minutes: 20

     steps:
```

**.github/workflows/tests.yml** (vendored, 31 changes)

```diff
@@ -68,11 +68,13 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.12", "3.14"]
-        pytorch_version: ["2.9.1", "2.10.0"]
-        exclude:
-          - python_version: "3.14"
-            pytorch_version: "2.9.1"
+        python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
+        pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
+        # exclude:
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.8.0"
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.9.1"
     timeout-minutes: 20

     steps:
@@ -162,11 +164,13 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.12", "3.14"]
-        pytorch_version: ["2.9.1", "2.10.0"]
-        exclude:
-          - python_version: "3.14"
-            pytorch_version: "2.9.1"
+        python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
+        pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
+        # exclude:
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.8.0"
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.9.1"
     timeout-minutes: 30

     steps:
@@ -325,6 +329,13 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.8.0
+            num_gpus: 1
+            gpu_type: "B200"
+            axolotl_extras: fbgemm-gpu
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
```

```diff
@@ -11,7 +11,7 @@ repos:
       - id: no-commit-to-branch
         args: ['--branch', 'main']
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.15.8
+    rev: v0.15.4
    hooks:
      - id: ruff
        args: [--fix]
```

**AGENTS.md** (94 changes, file deleted)

@@ -1,94 +0,0 @@

# Axolotl

Fine-tuning framework for LLMs. Config-driven: every training run is defined by a single YAML file.

## Tech Stack

Python, PyTorch, HuggingFace Transformers, TRL, PEFT (LoRA/QLoRA), DeepSpeed, FSDP, vLLM (for GRPO generation).

## Commands

```bash
axolotl train config.yaml              # Train (single or multi-GPU, auto-detected)
axolotl preprocess config.yaml         # Tokenize dataset and validate config
axolotl preprocess config.yaml --debug # Inspect tokenized samples and label masking
axolotl inference config.yaml          # Interactive inference
axolotl merge-lora config.yaml         # Merge LoRA adapter into base model
axolotl vllm-serve config.yaml         # Start vLLM server for GRPO/EBFT training
axolotl fetch examples                 # Download example configs
```

## Training Methods

| Method | Config Key | When to Use |
|--------|-----------|-------------|
| SFT | *(default)* | Input-output pairs, instruction tuning |
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
| KTO | `rl: kto` | Unpaired binary preference labels |
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
| EBFT | `rl: ebft` | Feature-matching rewards from internal representations |

Agent-specific references:
- [docs/agents/sft.md](docs/agents/sft.md) — supervised fine-tuning
- [docs/agents/preference_tuning.md](docs/agents/preference_tuning.md) — DPO, IPO, KTO, ORPO, SimPO
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining

## Config Pattern

All training is config-driven. A YAML file specifies model, adapter, dataset(s), and hyperparameters:

```yaml
base_model: meta-llama/Llama-3.1-8B-Instruct
adapter: lora # or qlora, or omit for full fine-tune
datasets:
  - path: my_dataset
    type: chat_template # prompt strategy (see docs/dataset-formats/)
output_dir: ./outputs/lora-out
```

Config schema: `src/axolotl/utils/schemas/config.py` (AxolotlInputConfig).

## Project Structure

```
src/axolotl/
  cli/                 # CLI entry points (train, preprocess, inference, merge_lora, vllm_serve)
  core/
    builders/          # TrainerBuilder classes (causal.py for SFT, rl.py for RLHF)
    trainers/          # Trainer classes, mixins (optimizer, scheduler, packing)
      dpo/             # DPO trainer and config
      grpo/            # GRPO trainer and sampler
  loaders/             # Model, tokenizer, adapter, processor loading
  prompt_strategies/   # Dataset format handlers (chat_template, alpaca, dpo/, kto/, orpo/)
  utils/schemas/       # Pydantic config schemas (config, model, training, peft, trl, fsdp)
  integrations/        # Plugins (liger, cut_cross_entropy, swanlab, nemo_gym)
  monkeypatch/         # Runtime patches for HF transformers

examples/              # Example YAML configs by model (llama-3/, qwen2/, mistral/, ebft/)
deepspeed_configs/     # DeepSpeed JSON configs (zero2, zero3)
docs/                  # Quarto documentation site
```

## Code Conventions

- Config-driven: features are toggled via YAML, not code changes
- Prompt strategies: `src/axolotl/prompt_strategies/` — each `type:` value maps to a function
- Plugin system: `plugins:` list in config loads integration modules
- Trainer mixins: `core/trainers/mixins/` for composable trainer behaviors
- Schemas: all config validation via Pydantic in `utils/schemas/`

## Key Documentation

- [Getting Started](docs/getting-started.qmd) — quickstart tutorial
- [Choosing a Method](docs/choosing_method.qmd) — SFT vs DPO vs GRPO decision guide
- [Config Reference](docs/config-reference.qmd) — all config options
- [Dataset Formats](docs/dataset-formats/) — chat_template, alpaca, input_output, completion
- [RLHF](docs/rlhf.qmd) — DPO, KTO, ORPO, GRPO, EBFT configs and dataset formats
- [GRPO Deep Dive](docs/grpo.qmd) — async training, custom rewards, scaling
- [vLLM Serving](docs/vllm_serving.qmd) — vLLM setup for GRPO/EBFT
- [Multi-GPU](docs/multi-gpu.qmd) — FSDP and DeepSpeed
- [Training Stability](docs/training_stability.qmd) — debugging loss, NaN, OOM
- [Debugging](docs/debugging.qmd) — VSCode setup, Docker debugging

```diff
@@ -87,7 +87,7 @@ Features:
 - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
 - Python 3.11
-- PyTorch ≥2.9.1
+- PyTorch ≥2.8.0

 ### Google Colab
```

```diff
@@ -238,7 +238,6 @@ website:
       - section: "Getting Started"
         contents:
           - docs/getting-started.qmd
-          - docs/choosing_method.qmd
           - docs/installation.qmd
           - docs/inference.qmd
       - section: "Model Guides"
@@ -303,9 +302,6 @@ website:
         contents:
           - docs/multimodal.qmd
           - docs/rlhf.qmd
-          - docs/grpo.qmd
-          - docs/ebft.qmd
-          - docs/vllm_serving.qmd
           - docs/reward_modelling.qmd
           - docs/lr_groups.qmd
           - docs/lora_optims.qmd
@@ -338,7 +334,6 @@ website:
       - section: "Troubleshooting"
         contents:
           - docs/faq.qmd
-          - docs/training_stability.qmd
           - docs/debugging.qmd
           - docs/nccl.qmd
```

**cicd/cicd.sh** (12 changes)

```diff
@@ -4,17 +4,7 @@ set -e
 python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

-set -o pipefail
-for i in 1 2 3; do
-  if curl --silent --show-error --fail -L \
-    https://axolotl-ci.b-cdn.net/hf-cache.tar.zst \
-    | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1; then
-    echo "HF cache extracted successfully"
-    break
-  fi
-  echo "Attempt $i failed, cleaning up and retrying in 15s..."
-  rm -rf "${HF_HOME}/hub/"*
-  sleep 15
-done
+curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
 # hf download "NousResearch/Meta-Llama-3-8B"
 # hf download "NousResearch/Meta-Llama-3-8B-Instruct"
 # hf download "microsoft/Phi-4-reasoning"
```

```diff
@@ -22,7 +22,6 @@ RUN apt update && \
     chmod 700 ~/.ssh && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
     printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
-    printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \
     chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
     chmod +x /root/cloud-entrypoint.sh && \
     echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
```

```diff
@@ -36,22 +36,22 @@ RUN uv pip install packaging setuptools wheel psutil \
     && uv pip install awscli pydantic

 RUN if [ "$TARGETARCH" = "amd64" ]; then \
-        MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
+        uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
+        uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
     fi

 # Map Python version (e.g., 3.12 -> cp312)
 RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
     # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
     TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
-    LINUX_TAG="manylinux_" && \
     # Map architecture
     case "$TARGETARCH" in \
-        amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
-        arm64) ARCH_TAG="2_34_aarch64" ;; \
+        amd64) ARCH_TAG="x86_64" ;; \
+        arm64) ARCH_TAG="aarch64" ;; \
        *) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
     esac && \
     WHL_VERSION="v0.7.16" && \
-    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
+    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
     wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
     uv pip install --no-cache-dir "${WHL_FILE}" && \
     rm "${WHL_FILE}"
```

@@ -1,71 +0,0 @@

# GRPO — Agent Reference

Online RL with verifiable reward functions. For full config reference, async features, and scaling, see [grpo.qmd](../grpo.qmd). For vLLM setup, see [vllm_serving.qmd](../vllm_serving.qmd).

## Architecture

```
Terminal 1 (GPU 0)                    Terminal 2 (GPU 1)
┌──────────────────────┐              ┌──────────────────────────────────┐
│ vLLM Server          │     HTTP     │ Trainer                          │
│ Serves base model    │◄────────────►│ 1. Send prompts to vLLM          │
│ + LoRA adapter       │  /generate   │ 2. Score completions (rewards)   │
│                      │  /set_lora   │ 3. Compute advantages            │
│ Punica kernels for   │              │ 4. PPO-clip gradient update      │
│ LoRA inference       │              │ 5. Sync LoRA weights to vLLM     │
└──────────────────────┘              └──────────────────────────────────┘
```

## Components Required

1. A YAML config with `rl: grpo`
2. A reward module (Python file with reward functions)
3. A running vLLM server (`axolotl vllm-serve config.yaml`)

## Reward Function Signature

```python
def my_reward(completions, **kwargs) -> list[float]:
    # completions[i][0]["content"] = text of i-th completion
    # **kwargs contains dataset columns not removed by transform
    return [score_for_each_completion]
```
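
For illustration, a complete reward function in this shape might look like the following sketch (the module name, function name, and scoring rule here are hypothetical, not taken from the diff):

```python
# rewards.py — hypothetical reward module referenced via `reward_funcs`
def brevity_reward(completions, **kwargs) -> list[float]:
    """Score 1.0 for completions at or under 512 characters, decaying beyond."""
    scores = []
    for completion in completions:
        text = completion[0]["content"]  # text of this completion, per the signature above
        scores.append(1.0 if len(text) <= 512 else 512 / len(text))
    return scores
```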

Multiple rewards: `reward_funcs: [r1, r2]` with `reward_weights: [1.0, 0.5]`.

## Key Async Features

| Feature | Config | Purpose |
|---------|--------|---------|
| Async prefetch | `async_prefetch: true` | Overlap generation with training |
| LoRA sync | `vllm_lora_sync: true` | Fast adapter sync via filesystem |
| Streaming scoring | `streaming_partial_batch: true` | Score one group at a time |
| Zero-adv skip | `skip_zero_advantage_batches: true` | Skip batches with no learning signal |
| Replay buffer | `replay_buffer_size: 100` | Cache high-signal groups |
| IS correction | `vllm_importance_sampling_correction: true` | Fix off-policy distribution shift |

## Health Checks

- `rewards/*/mean` > 0.15 within 20 steps (else: test reward function standalone)
- `reward_std` > 0 on most steps (else: no learning signal)
- `entropy` 0.05-0.5 (< 0.01 = mode collapse)
- `grad_norm` 0.001-1.0 (> 10 = unstable, 0.0 = zero-advantage skip)

See [training_stability.qmd](../training_stability.qmd) for detailed diagnostics.

## File Map

```
src/axolotl/
  cli/train.py                 # Entry point
  cli/vllm_serve.py            # Entry point for vLLM server
  core/trainers/grpo/
    trainer.py                 # AxolotlGRPOTrainer
    sampler.py                 # Sampling utilities
  core/builders/rl.py          # HFRLTrainerBuilder — routes rl type → trainer
scripts/vllm_serve_lora.py     # vLLM serve script with LoRA sync support
  utils/schemas/trl.py         # TRL config schema (all trl: options)

docs/grpo.qmd                  # Full user docs: async, rewards, scaling, config reference
docs/vllm_serving.qmd          # vLLM server modes, LoRA sync, weight sync
```

@@ -1,121 +0,0 @@

# Preference Learning (RLHF) — Agent Reference

Reference for DPO, IPO, KTO, ORPO, and SimPO. For config templates and dataset format examples, see [rlhf.qmd](../rlhf.qmd). For GRPO, see [grpo.qmd](../grpo.qmd). For EBFT, see [ebft.qmd](../ebft.qmd).

## Method Overview

| Method | Data Requirement | Key Idea | Best For |
|--------|-----------------|----------|----------|
| **DPO** | Paired (chosen + rejected) | Implicit reward via preference pairs | General alignment, most common |
| **IPO** | Paired (chosen + rejected) | DPO with different loss (avoids overfitting) | When DPO overfits |
| **KTO** | Unpaired (completion + binary label) | Kahneman-Tversky loss, no pairs needed | When you only have thumbs-up/down |
| **ORPO** | Paired (chosen + rejected) | Combined SFT + preference, no ref model | Single-stage alignment, saves VRAM |
| **SimPO** | Paired (chosen + rejected) | Length-normalized, no ref model | Simple setup, length-robust |

Default: start with DPO. All methods require `sample_packing: false`.

## Architecture

```
┌──────────────┐   ┌───────────────┐   ┌───────────────┐
│ Policy Model │   │ Reference     │   │ Preference    │
│ (trainable)  │   │ Model (frozen)│   │ Dataset       │
└──────┬───────┘   └──────┬────────┘   └──────┬────────┘
       └──────────┬───────┘                   │
                  v                           │
    Forward pass on chosen + rejected <───────┘
                  │
    Preference Loss (DPO/IPO/KTO/...)
                  │
    Backprop + Update

Exception: ORPO and SimPO do NOT use a reference model (~50% less VRAM).
```

No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference data.

## Method Selection

1. Paired preference data (chosen + rejected)?
   - Default → `rl: dpo`
   - Overfitting → `rl: ipo`
   - VRAM-limited → `rl: orpo` (no ref model)
   - Length-sensitive → `rl: simpo` (no ref model)
2. Only binary labels (good/bad)? → `rl: kto`
3. Single-stage training (no separate SFT)? → `rl: orpo`

| | DPO | IPO | KTO | ORPO | SimPO |
|---|---|---|---|---|---|
| **Reference model** | Yes | Yes | Yes | No | No |
| **VRAM overhead** | ~2x model | ~2x model | ~2x model | ~1x model | ~1x model |
| **TRL trainer class** | DPOTrainer | DPOTrainer | KTOTrainer | ORPOTrainer | CPOTrainer |

## Prompt Strategy Resolution

The `type` field resolves to a Python function:

```
type: "chatml.intel"
  → axolotl.prompt_strategies.dpo.chatml.intel(cfg, **kwargs)
  → returns transform_fn(sample) → {"prompt", "chosen", "rejected"}

type: "chat_template.default"
  → axolotl.prompt_strategies.dpo.chat_template.default(cfg, dataset_idx, **kwargs)

type: {"field_prompt": "prompt", ...} (dict)
  → axolotl.prompt_strategies.dpo.user_defined.default(...)
```

Module base: `axolotl.prompt_strategies.{rl_method}` — replace `dpo` with `kto` or `orpo`.
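
As a rough sketch (illustrative only; the real resolution logic lives in axolotl's prompt strategy loader, not in this helper), the lookup amounts to:

```python
# Hypothetical sketch of `type:` string resolution; not the actual loader code.
import importlib

def resolve_strategy(type_str: str, rl_method: str = "dpo"):
    """Map e.g. "chatml.intel" → axolotl.prompt_strategies.dpo.chatml.intel."""
    module_name, _, func_name = type_str.partition(".")
    module = importlib.import_module(
        f"axolotl.prompt_strategies.{rl_method}.{module_name}"
    )
    return getattr(module, func_name or "default")
```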

## Healthy Training Indicators

| Metric | Healthy Range | Problem |
|--------|--------------|---------|
| `train/loss` | Decreasing, 0.3-0.7 | Flat or increasing = broken data or too high LR |
| `rewards/chosen` | Increasing | Flat = model not learning preferences |
| `rewards/rejected` | Decreasing | Increasing = model prefers wrong responses |
| `rewards/margins` | Positive and increasing | Negative = prefers rejected over chosen |
| `rewards/accuracies` | > 0.5, toward 0.7+ | < 0.5 = worse than random |
| `logps/rejected` | Decreasing | Increasing = reward hacking |
| `grad_norm` | 0.01 - 10.0 | > 100 = exploding gradients |

Method-specific: DPO/IPO watch `rewards/margins`; KTO loss is noisier; ORPO monitor SFT + odds ratio components; SimPO check length-normalized reward separation.

## Known Issues

| Issue | Fix |
|-------|-----|
| Sample packing crash | Set `sample_packing: false` (required for all preference methods) |
| KTO `KeyError: 'label'` | Ensure dataset has boolean `label` column |
| ORPO/KTO `KeyError` during tokenization | Add `remove_unused_columns: false` |
| ORPO template not applied | ORPO requires explicit `chat_template` setting |
| OOM with ref model (DPO/IPO/KTO) | Use LoRA/QLoRA, or switch to ORPO/SimPO (no ref model) |
| IPO + label_smoothing | Do not set `dpo_label_smoothing` when `rl: ipo` |

Full troubleshooting: [training_stability.qmd](../training_stability.qmd)
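
For the KTO `label` issue above, a well-formed sample row (illustrative values, following TRL's unpaired KTO format of prompt, completion, and a boolean label) looks like:

```python
# Build one JSONL row for KTO: an unpaired completion plus a boolean `label`.
import json

row = {
    "prompt": "Summarize: the quick brown fox jumps over the lazy dog.",
    "completion": "A fox jumps over a dog.",
    "label": True,  # thumbs-up; use False for thumbs-down
}
print(json.dumps(row))  # one such line per sample in the dataset file
```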

## File Map

```
src/axolotl/
  core/trainers/dpo/          # DPO trainer, args, strategy
  core/builders/rl.py         # HFRLTrainerBuilder — routes rl type → trainer class
  core/training_args.py       # AxolotlKTOConfig, AxolotlORPOConfig, AxolotlCPOConfig
  prompt_strategies/
    dpo/                      # DPO/IPO/SimPO dataset strategies
      chat_template.py        # chat_template.default, chat_template.argilla_chat
      chatml.py               # chatml.default/intel/icr/argilla_chat/prompt_pairs/ultra
      llama3.py               # llama3 variants (same subtypes as chatml)
      user_defined.py         # Custom field mapping
      passthrough.py          # No transform
    kto/                      # KTO dataset strategies (chatml, llama3, user_defined)
    orpo/                     # ORPO dataset strategies (chat_template.argilla)
  utils/schemas/enums.py      # RLType enum (dpo, ipo, kto, orpo, simpo, grpo, gdpo, ebft)
  utils/schemas/config.py     # All rl/dpo/kto/orpo/simpo config fields

docs/rlhf.qmd                 # Full user docs: all dataset formats, config templates
docs/choosing_method.qmd      # SFT vs DPO vs GRPO decision guide
examples/qwen2/dpo.yaml       # DPO example
examples/llama-3/qlora-1b-kto.yaml # KTO example
```

@@ -1,75 +0,0 @@

# Pretraining / Continual Pretraining — Agent Reference

Train on raw text with no input masking. Two approaches depending on dataset size.

## When to Use

- Continual pretraining on domain-specific corpora
- Adapting a base model to a new language or domain before fine-tuning
- Pretraining-style data where the entire text is the training signal

## Choosing an Approach

| | Non-streaming (`type: completion`) | Streaming (`pretraining_dataset`) |
|---|---|---|
| **Dataset size** | Fits in memory | Too large to fit in memory |
| **Tokenization** | Pre-tokenized before training | On-demand during training |
| **Config key** | `datasets:` | `pretraining_dataset:` |
| **Long text handling** | Splits texts exceeding `sequence_len` | Concatenates into fixed-length sequences |
| **Benefit** | Can preprocess on CPU, transfer to GPU | Start training immediately, no preprocessing |

## Non-Streaming: `type: completion`

For smaller datasets that fit in memory. Pre-tokenizes the entire dataset.

```yaml
datasets:
  - path: my_corpus
    type: completion
    # field: text # Column name (default: "text")
```

## Streaming: `pretraining_dataset`

For large corpora. Streams data on-demand without loading everything into memory.

```yaml
pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    type: pretrain
    text_column: text
    split: train

max_steps: 1000 # Required — axolotl can't infer dataset size
streaming_multipack_buffer_size: 10000 # Buffer for sample packing
pretrain_multipack_attn: true # Prevent cross-attention between packed samples
```

`max_steps` is required for streaming — one step = `sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus` tokens.
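
As a quick sanity check on that formula (illustrative numbers, not taken from the config above):

```python
# Tokens consumed per optimizer step for a hypothetical streaming run.
sequence_len = 2048
micro_batch_size = 2
gradient_accumulation_steps = 4
num_gpus = 8

tokens_per_step = sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus
print(tokens_per_step)         # 131072
print(tokens_per_step * 1000)  # 131072000 — max_steps: 1000 covers ~131M tokens
```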

Full streaming docs: [streaming.qmd](../streaming.qmd)

## Dataset Format

```json
{"text": "The complete document text goes here."}
```

## Key Settings

- `sample_packing: true` + `pad_to_sequence_len: true` — pack documents into fixed-length sequences
- `flash_attention: true` — required for sample packing
- No adapter — typically full fine-tune for pretraining
- `train_on_inputs: true` — default for completion (all tokens trained on)

## File Map

```
src/axolotl/
  prompt_strategies/completion.py # Non-streaming: completion prompt strategy (no masking)
  utils/data/sft.py               # Non-streaming: dataset loading and processing
  utils/data/streaming.py         # Streaming: encode_streaming(), wrap_streaming_dataset()
  utils/schemas/config.py         # Config fields: pretraining_dataset, pretrain_multipack_attn, etc.

examples/streaming/pretrain.yaml  # Full streaming pretraining example config
```

@@ -1,48 +0,0 @@

# Reward Modelling — Agent Reference

Train models to score responses for use as reward signals in RL. For full docs, see [reward_modelling.qmd](../reward_modelling.qmd).

## Types

### Outcome Reward Models (ORM)

Train a classifier to predict preference over entire interactions. Uses `AutoModelForSequenceClassification`.

```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
chat_template: gemma
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    type: bradley_terry.chat_template
```

Dataset format: `{"system": "...", "input": "...", "chosen": "...", "rejected": "..."}`

### Process Reward Models (PRM)

Train a token classifier to score each reasoning step. Uses `AutoModelForTokenClassification`.

```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2
process_reward_model: true
datasets:
  - path: trl-lib/math_shepherd
    type: stepwise_supervised
```

Dataset format: see [stepwise_supervised.qmd](../dataset-formats/stepwise_supervised.qmd).

## File Map

```
src/axolotl/
  core/builders/causal.py                  # Handles reward_model flag in trainer builder
  prompt_strategies/bradley_terry/         # Bradley-Terry prompt strategies
  prompt_strategies/stepwise_supervised.py # PRM dataset strategy
  utils/schemas/config.py                  # reward_model, process_reward_model config fields
```

@@ -1,115 +0,0 @@

# SFT — Agent Reference

Supervised fine-tuning pipeline reference. For config templates and dataset format examples, see [getting-started.qmd](../getting-started.qmd) and [dataset-formats/](../dataset-formats/).

## Architecture

```
YAML Config → axolotl train config.yaml

1. Load base model (+ quantization if QLoRA/8-bit)
2. Apply adapter layers (LoRA/QLoRA) if configured
3. Load + tokenize dataset(s)
   - Apply prompt template (chat_template / alpaca / custom)
   - Mask inputs (train_on_inputs: false)
   - Pack samples into sequences (sample_packing: true)
4. Training loop (HuggingFace Trainer)
   - forward → loss → backward → optimizer step → lr scheduler step
5. Save model / adapter weights + tokenizer

Multi-GPU: FSDP or DeepSpeed shards model across GPUs automatically.
```

## Components Required

1. A YAML config — model, dataset(s), adapter settings, hyperparameters
2. A dataset — HuggingFace Hub, local JSONL/JSON/Parquet, or S3/GCS path
3. (Optional) A custom prompt strategy — for non-standard dataset formats

No external server processes needed (unlike GRPO which requires vLLM).

## Dataset Format Decision Tree

```
Is your data in chat/message format?
├─ YES: OpenAI message format (role/content)?
│   ├─ YES ──────────────────────> type: chat_template (recommended)
│   └─ NO (custom field names) ──> type: chat_template + message_property_mappings
└─ NO: Instruction/response pairs?
    ├─ YES ──> type: alpaca (instruction, input, output)
    └─ NO: Raw text?
        ├─ YES with segments ─────> type: input_output (template-free masking)
        └─ YES continuous ────────> type: completion (pretraining-style)
```

Full format specs: [dataset-formats/](../dataset-formats/)
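
For the recommended `chat_template` path, one JSONL row in OpenAI message format might look like this (illustrative sample data; `messages` is the commonly used field name):

```python
# Build one hypothetical JSONL row for `type: chat_template`.
import json

row = {
    "messages": [
        {"role": "user", "content": "What does sample packing do?"},
        {"role": "assistant", "content": "It concatenates short samples into one fixed-length sequence."},
    ]
}
print(json.dumps(row))  # one such line per sample in data.jsonl
```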

## Model Size to Adapter Choice

| Model Size | LoRA | QLoRA (4-bit) | Full Fine-Tune | VRAM (approx) |
|-----------|------|---------------|----------------|---------------|
| 1-3B | Preferred | Low-budget option | Single GPU OK | 8-16 GB (LoRA) |
| 7-8B | Preferred | Good balance | Needs multi-GPU | 16-24 GB (LoRA) |
| 13-14B | Preferred | Good balance | Multi-GPU required | 24-40 GB (LoRA) |
| 30-70B | LoRA or QLoRA | Preferred for single GPU | Multi-node | 40-80 GB (QLoRA) |

## Hyperparameter Ranges

| Parameter | LoRA | QLoRA | Full FT |
|-----------|------|-------|---------|
| `learning_rate` | 1e-4 to 3e-4 | 1e-4 to 3e-4 | 1e-5 to 5e-5 |
| `lora_r` | 16-64 | 16-64 | N/A |
| `lora_alpha` | 1-2x `lora_r` | 1-2x `lora_r` | N/A |
| `micro_batch_size` | 2-8 | 2-4 | 1-2 |
| `gradient_accumulation_steps` | 2-8 | 4-16 | 4-16 |
| `num_epochs` | 1-3 | 1-3 | 1-3 |
| `optimizer` | `adamw_8bit` | `adamw_bnb_8bit` | `adamw_torch_fused` |

Effective batch = micro_batch * grad_accum * num_gpus. Lower LR for larger models.
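
For example (illustrative numbers, not from the tables above):

```python
# Effective batch size for a hypothetical LoRA run on 4 GPUs.
micro_batch_size, gradient_accumulation_steps, num_gpus = 4, 4, 4
print(micro_batch_size * gradient_accumulation_steps * num_gpus)  # 64 samples per optimizer step
```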

## Healthy Training Indicators

| Metric | Healthy | Problem |
|--------|---------|---------|
| `train_loss` | Decreasing, starting ~2-4 for chat models | Flat or increasing from step 1 — data or LR issue |
| `eval_loss` | Decreasing, tracks train_loss | Increasing while train_loss decreases — overfitting |
| `grad_norm` | 0.1-10, relatively stable | Spikes >100 — instability. 0.0 — frozen weights |
| `learning_rate` | Follows scheduler curve | Flat or NaN — config issue |

Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss goes to 0 quickly (overfitting), eval_loss diverging (reduce epochs, add regularization). See [training_stability.qmd](../training_stability.qmd).

## Known Issues

| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
| Missing chat template error | Set `chat_template: chatml` explicitly |
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
| Tokenizer pad token / infinite loss | Set `special_tokens: pad_token: "<\|end_of_text\|>"` |
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |

Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)

## File Map

```
src/axolotl/
  cli/train.py              # Entry point for `axolotl train`
  cli/preprocess.py         # Entry point for `axolotl preprocess`
  core/builders/causal.py   # HFCausalTrainerBuilder — wires config → SFT trainer
  core/trainers/base.py     # AxolotlTrainer — base trainer class
  core/trainers/mixins/     # Packing, optimizer, scheduler, checkpoints
  prompt_strategies/        # Format handlers: chat_template, alpaca, completion, input_output
  utils/schemas/config.py   # AxolotlInputConfig — main config schema
  utils/schemas/datasets.py # SFTDataset, DatasetConfig
  utils/schemas/peft.py     # LoraConfig — LoRA parameters
  integrations/liger/       # Liger kernel plugin

examples/llama-3/           # LoRA, QLoRA, full FT example configs
docs/getting-started.qmd    # Quickstart with config templates
docs/optimizations.qmd      # Flash attention, gradient checkpointing, sample packing
docs/multi-gpu.qmd          # FSDP and DeepSpeed setup
```

@@ -1,206 +0,0 @@

---
title: "Which Fine-Tuning Method Should I Use?"
description: "A decision guide for choosing the right fine-tuning method, adapter, and hardware configuration in Axolotl."
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
execute:
  enabled: false
---

## Overview {#sec-overview}

Axolotl supports four broad categories of fine-tuning, each suited to different data types, objectives, and resource constraints.

| Method | What It Does | Data You Need |
|--------|-------------|---------------|
| **Supervised Fine-Tuning (SFT)** | Teaches the model to produce specific outputs given inputs | Input-output pairs (instructions, conversations, completions) |
| **Preference Learning (DPO/KTO/ORPO)** | Steers the model toward preferred outputs and away from dispreferred ones | Chosen/rejected response pairs (DPO, ORPO) or binary labels (KTO) |
| **Reinforcement Learning (GRPO)** | Optimizes the model against a reward signal through online generation | A reward function (code or model-based) and a prompt dataset |
| **Reward Modeling** | Trains a model to score responses, for use as a reward signal in RL | Preference pairs ranked by quality |

Each method is configured through a YAML file with `rl: <method>` (or omitted for SFT). All methods support LoRA, QLoRA, and full fine-tuning unless otherwise noted.

## Decision Tree {#sec-decision-tree}

Use the following flowchart to choose your method. Start at the top and follow the path that matches your situation.

```
Do you have a reward function (code-based or model-based)?
├── YES
│   └── Use GRPO (rl: grpo)
│       The model generates its own completions and learns from reward scores.
│       Best for: math, code, reasoning, tasks with verifiable answers.
│       See: rlhf.qmd#grpo
│
└── NO
    │
    Do you have preference pairs (chosen vs. rejected responses)?
    ├── YES
    │   │
    │   Are they paired (same prompt, one chosen, one rejected)?
    │   ├── YES → Use DPO (rl: dpo)
    │   │         Direct optimization without a separate reward model.
    │   │         See: rlhf.qmd#dpo
    │   │
    │   └── NO (only binary good/bad labels)
    │       └── Use KTO (rl: kto)
    │           Works with unpaired preference data.
    │           See: rlhf.qmd#kto
    │
    └── NO
        │
        Do you have input-output examples?
        ├── YES → Use SFT
        │         The simplest and most common method.
        │         See: getting-started.qmd
        │
        └── NO
            └── You need to create training data first.
                Consider generating preference pairs with an LLM judge,
                or writing a reward function for GRPO.
```

::: {.callout-tip}
**When in doubt, start with SFT.** It is the most straightforward method and works well for most tasks. You can always move to preference learning or RL later to further refine behavior.
:::

### Method Comparison at a Glance

| Criterion | SFT | DPO | KTO | GRPO |
|-----------|-----|-----|-----|------|
| Data complexity | Low (input-output pairs) | Medium (preference pairs) | Medium (binary labels) | Low (prompts + reward code) |
| Compute cost | Low | Medium | Medium | High (requires vLLM server) |
| Learning signal | Supervised | Contrastive | Contrastive | Online reward |
| Online generation | No | No | No | Yes |
| Reward model needed | No | No | No | No (uses reward functions) |
| Best for | Task adaptation, instruction following | Safety, style alignment | Unpaired preference data | Reasoning, math, code |

::: {.callout-note}
**ORPO** is an alternative to DPO that combines SFT and preference optimization in a single training stage, removing the need for a separate SFT step. Configure with `rl: orpo`. See [rlhf.qmd](rlhf.qmd) for details.
:::

## Adapter Selection {#sec-adapter-selection}

Once you have chosen a method, decide how to apply the parameter updates. The three main options trade off VRAM usage against model quality.

### QLoRA

- **How it works**: The base model is loaded in 4-bit (NF4) quantization. Small low-rank adapter matrices are trained in higher precision on top.
- **VRAM savings**: Roughly 4x reduction in model memory compared to full fine-tuning.
- **Quality**: Slight degradation due to quantization noise, but often negligible for task-specific fine-tuning.
- **When to use**: When your GPU cannot fit the model in full precision, or when you want fast experimentation.

```yaml
adapter: qlora
load_in_4bit: true
lora_r: 32
lora_alpha: 64
lora_target_linear: true
```

### LoRA

- **How it works**: The base model is loaded at full precision (or 8-bit). Low-rank adapter matrices are trained alongside.
- **VRAM savings**: Roughly 2-3x reduction compared to full fine-tuning (model weights are frozen, only adapters + optimizer states for adapters are stored).
- **Quality**: Very close to full fine-tuning for most tasks, especially with higher rank values.
- **When to use**: When you have enough VRAM for the base model but not for full optimizer states.

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
```

::: {.callout-tip}
For GRPO training, LoRA is strongly recommended. The vLLM server needs to sync weights from the trainer, and LoRA sync (`trl.vllm_lora_sync: true`) is far more efficient than syncing full merged weights. See [vLLM Serving](vllm_serving.qmd) for details.
:::

### Full Fine-Tuning

- **How it works**: All model parameters are updated during training. No adapters.
- **VRAM savings**: None. Requires memory for model weights, gradients, and optimizer states (roughly 4x model size in bf16 with AdamW).
- **Quality**: Highest potential quality, especially for large distribution shifts.
- **When to use**: When you have ample GPU memory or multi-GPU setups, and need maximum performance. Also required for pre-training.

```yaml
# No adapter or load_in_* lines needed
micro_batch_size: 1
gradient_accumulation_steps: 16
```

### Quick Comparison

| | QLoRA | LoRA | Full |
|---|---|---|---|
| Trainable params | ~0.1-1% | ~0.1-1% | 100% |
| Model memory | ~25% of full | ~50-100% of full | 100% |
| Optimizer memory | Tiny (adapters only) | Tiny (adapters only) | 2x model size (AdamW) |
| Training speed | Slower (dequantization overhead) | Baseline | Faster per-step (no adapter overhead) |
| Inference | Merge or serve with adapter | Merge or serve with adapter | Direct |
| Multi-GPU required? | Rarely | For 13B+ models | For 7B+ models |

## Hardware Mapping {#sec-hardware-mapping}

The tables below provide approximate GPU memory requirements. Actual usage depends on context length, batch size, and optimizer choice.

### SFT / Preference Learning

| Model Size | QLoRA (4-bit) | LoRA (bf16) | Full (bf16 + AdamW) |
|------------|--------------|-------------|---------------------|
| 1-3B | 6-8 GB | 8-12 GB | 24-32 GB |
| 7-8B | 10-14 GB | 16-24 GB | 60-80 GB |
| 13-14B | 16-20 GB | 28-40 GB | 120+ GB |
| 30-34B | 24-32 GB | 64-80 GB | 2-4x 80 GB |
| 70-72B | 40-48 GB | 2x 80 GB | 4-8x 80 GB |

::: {.callout-important}
These estimates assume a short context length (512-2048 tokens) and micro_batch_size of 1-2. Longer sequences and larger batches increase memory significantly due to activations. Use [gradient checkpointing](gradient_checkpointing.qmd) to reduce activation memory at the cost of ~30% slower training.
:::

### GRPO (RL Training)

GRPO requires additional GPU(s) for the vLLM generation server. Plan for at least two GPUs: one for training, one for vLLM.

| Model Size | Training GPU (LoRA, bf16) | vLLM GPU | Total GPUs |
|------------|--------------------------|----------|------------|
| 0.5-3B | 1x 24 GB | 1x 24 GB | 2x 24 GB |
| 7-8B | 1x 80 GB | 1x 80 GB | 2x 80 GB |
| 13-14B | 1-2x 80 GB | 1-2x 80 GB | 2-4x 80 GB |
| 30-72B | 2-4x 80 GB (FSDP/DeepSpeed) | 2-4x 80 GB (tensor parallel) | 4-8x 80 GB |

::: {.callout-tip}
For single-GPU GRPO, use `vllm_mode: colocate` with `vllm_enable_sleep_mode: true`. The vLLM engine shares the GPU and offloads VRAM when not generating. This works for smaller models (up to ~3B on a 24 GB GPU) but is slower than the two-GPU server mode.
:::

### Multi-GPU Threshold

You need multi-GPU training when:

- **Full fine-tuning** of models 7B+ (use FSDP or DeepSpeed ZeRO)
- **LoRA** of models 30B+ (or 13B+ with long contexts)
- **GRPO** almost always (separate vLLM server), unless using colocate mode

See [Multi-GPU Training](multi-gpu.qmd) for FSDP and DeepSpeed configuration.

## Quick Links {#sec-quick-links}

| Method | Config Key | Documentation | Example Config |
|--------|-----------|---------------|----------------|
| SFT | *(default, no `rl:` key)* | [Getting Started](getting-started.qmd) | `examples/llama-3/lora-1b.yml` |
| DPO | `rl: dpo` | [RLHF - DPO](rlhf.qmd#dpo) | See rlhf.qmd |
| KTO | `rl: kto` | [RLHF - KTO](rlhf.qmd#kto) | See rlhf.qmd |
| ORPO | `rl: orpo` | [RLHF - ORPO](rlhf.qmd#orpo) | See rlhf.qmd |
| GRPO | `rl: grpo` | [RLHF - GRPO](rlhf.qmd#grpo), [vLLM Serving](vllm_serving.qmd) | See rlhf.qmd |
| Reward Modeling | `rl: reward_trainer` | [Reward Modelling](reward_modelling.qmd) | See reward_modelling.qmd |

### Related Guides

- [Configuration Reference](config-reference.qmd) -- Full list of all config options
- [Dataset Formats](dataset-formats) -- How to structure your training data
- [Optimizations](optimizations.qmd) -- Flash attention, gradient checkpointing, mixed precision
- [Multi-GPU Training](multi-gpu.qmd) -- FSDP and DeepSpeed setup
- [vLLM Serving](vllm_serving.qmd) -- Setting up vLLM for GRPO training

@@ -22,47 +22,90 @@ For `pretraining_dataset:` specifically, please refer to the [Pre-training secti

## Pre-training

Pre-training trains on raw text corpora with no input masking. The dataset format is simple:
When aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire dataset before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch at a time into memory.

A sample format for a pre-training dataset is as follows:

```json
{"text": "first row"}
{"text": "second row"}
...
```

Axolotl supports two approaches:
It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.

### Streaming (large datasets)
Axolotl supports loading from a Hugging Face hub repo or from local files.

For large corpora that don't fit in memory, use `pretraining_dataset` with [streaming](../streaming.qmd). Data is tokenized on-demand during training.
### Pre-training from Hugging Face hub datasets

As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:

```yaml
pretraining_dataset: hf_org/name
```

### Pre-training from local dataset files

Given a few corpus files: `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:

```yaml
pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    type: pretrain
    text_column: text
    split: train
  - path: json
    data_files:
      - A.jsonl
      - B.jsonl
      - C.jsonl
```

::: {.callout-important}
Streaming requires `max_steps` in your config — Axolotl cannot infer the dataset size. One step = `sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus` tokens.
:::
While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `Webdataset`) that are supported by [`Dataset.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files)

See [Streaming Datasets](../streaming.qmd) for full configuration details.
### Pre-training without streaming

### Non-streaming (smaller datasets)
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This means the entire dataset is pre-tokenized, instead of tokenized on-demand as in streaming.

For datasets that fit in memory, use `type: completion` under `datasets:`. The entire dataset is pre-tokenized before training, which can be done on a CPU-only machine.
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.

From Hugging Face:

```yaml
datasets:
  - path: my_corpus
  - path: hf_org/name
    type: completion
```

::: {.callout-note}
With `completion`, texts exceeding `sequence_len` are split into multiple samples automatically.
From local files:

```yaml
datasets:
  - path: A.jsonl
    type: completion

  - path: B.jsonl
    type: completion
```

::: {.callout-important}
For `completion` only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts. If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!
:::

### Pre-training dataset configuration tips

#### Setting max_steps

When using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.

Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.

One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.

#### Group_by_length

It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.

### Reference

Please see docs [here](pretraining.qmd).

## Supervised fine-tuning (SFT)

Supervised fine-tuning is the process of training models to respond to an instruction or chat input.

@@ -4,9 +4,29 @@ description: Data format for a pre-training completion task.
order: 1
---

::: {.callout-note}
Pre-training documentation has been consolidated:
For pretraining, there is no prompt template or roles. The only required field is `text`:

```{.json filename="data.jsonl"}
{"text": "first row"}
{"text": "second row"}
...
```

:::{.callout-note}

### Streaming is recommended for large datasets

Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:

```{.yaml filename="config.yaml"}
pretraining_dataset:
  - name:
    path:
    split:
    text_column: # column in dataset with the data, usually `text`
    type: pretrain
    trust_remote_code:
    skip: # number of rows of data to skip over from the beginning
```

- **Streaming pretraining** (large datasets): See [Streaming Datasets](../streaming.qmd#pretraining-with-streaming)
- **Non-streaming pretraining** (`type: completion`): See [Dataset Formats](index.qmd#pre-training)
:::

````diff
@@ -6,10 +6,6 @@ description: How to debug Axolotl
 This document provides some tips and tricks for debugging Axolotl. It also provides an example configuration for debugging with VSCode. A good debugging setup is essential to understanding how Axolotl code works behind the scenes.

-::: {.callout-tip}
-For training-specific debugging (loss spikes, NaN gradients, OOM errors, RL training stability), see [Training Stability & Debugging](training_stability.qmd).
-:::
-
 ## Table of Contents

 - [General Tips](#general-tips)
@@ -89,7 +85,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.

 The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.

-For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 axolotl train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
+For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.

 ```json
 // .vscode/launch.json
@@ -246,6 +242,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
 </div>
 <br>

-[^1]: The VSCode config uses `accelerate.commands.launch` as the Python module entry point, which is what `axolotl train` invokes under the hood.
+[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.

 [^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
````

**docs/ebft.qmd** (556 changes, file deleted)

@@ -1,556 +0,0 @@

---
title: "EBFT Training"
description: "Energy-Based Fine-Tuning uses feature-matching rewards from internal representations to train language models without external reward functions."
order: 9
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the **internal feature representations** of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.

Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)

### How EBFT Differs from Other RL Methods

| Method | Reward Signal | Requires | Best For |
|--------|--------------|----------|----------|
| **GRPO** | External reward function(s) | Custom reward code or reward model | Tasks with verifiable answers (math, code) |
| **DPO** | Preference pairs (chosen vs rejected) | Paired preference data | Alignment with human preferences |
| **EBFT** | Feature similarity to ground truth | Ground-truth completions | Any task with reference outputs |

EBFT's key advantage is that it needs only ground-truth completions -- no reward engineering, no preference annotation, and no reward model training. The model's own internal representations serve as the reward signal. This makes it particularly effective for:

- Code generation (match features of known-good solutions)
- Instruction following with reference outputs
- Continual pretraining on unstructured text (strided mode)
- Multi-turn dialogue with reference conversations

### Reward Formulation

The EBFT reward for each generated completion is:

```
reward = alignment_coef * cosine_similarity(gen_features, gt_features)
       - diversity_coef * mean_pairwise_similarity(gen_features)
```

- **Alignment**: How closely the generated output's internal representations match the ground truth. Higher is better.
- **Diversity**: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.
- **CFM loss** (Cross-Feature Matching): Tracks `||mean(gen_features) - gt_features||^2` as a diagnostic. This is the quantity that EBFT ultimately minimizes.
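
In PyTorch terms, the reward and the CFM diagnostic could be sketched as follows (an illustrative sketch only, not the actual implementation; tensor shapes are assumptions):

```python
# Hypothetical sketch of the EBFT reward formulation above.
import torch
import torch.nn.functional as F

def ebft_reward(gen_features, gt_features, alignment_coef=1.0, diversity_coef=1.0):
    # gen_features: (num_generations, dim); gt_features: (dim,)
    alignment = F.cosine_similarity(gen_features, gt_features.unsqueeze(0), dim=-1)
    pairwise = F.cosine_similarity(
        gen_features.unsqueeze(1), gen_features.unsqueeze(0), dim=-1
    )
    n = gen_features.size(0)
    # Mean pairwise similarity among generations, excluding self-similarity.
    off_diag = pairwise[~torch.eye(n, dtype=torch.bool)].view(n, n - 1)
    reward = alignment_coef * alignment - diversity_coef * off_diag.mean(dim=1)
    cfm_loss = (gen_features.mean(dim=0) - gt_features).pow(2).sum()  # diagnostic
    return reward, cfm_loss
```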
|
||||
## Modes
|
||||
|
||||
EBFT supports three operational modes, each suited to different use cases.
|
||||
|
||||
### Structured Mode (Sync)
|
||||
|
||||
Uses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.
|
||||
|
||||
```
|
||||
GPU 0: vLLM Server (generates completions, receives weight syncs)
|
||||
GPU 1: Trainer (feature extraction, reward computation, GRPO training)
|
||||
```
|
||||
|
||||
**When to use**: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.
|
||||
|
||||
### Structured Mode (Async)
|
||||
|
||||
Same architecture as sync, but overlaps generation of the next batch with training on the current batch. Faster throughput at the cost of slightly stale weights during generation.
|
||||
|
||||
**When to use**: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by `vllm_sync_interval`).
|
||||
|
||||
### Strided Mode
|
||||
|
||||
Runs entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.
|
||||
|
||||
```
|
||||
Single GPU: Base model + LoRA adapter
|
||||
- Strided block-parallel generation (flex_attention)
|
||||
- Feature extraction via disable_adapter()
|
||||
- No vLLM needed
|
||||
```
|
||||
|
||||
**When to use**: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.

## Quick Start

### Structured Mode

This minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.

**Step 1**: Create a config file `ebft_quickstart.yaml`:

```yaml
base_model: Qwen/Qwen2-0.5B-Instruct

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  alignment_coef: 1.0
  diversity_coef: 1.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: false
  scale_rewards: true
  loss_type: grpo

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 1024

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

# Standard training settings (see getting-started.qmd for details)
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```

**Step 2**: Start vLLM on GPU 0:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve ebft_quickstart.yaml
```

**Step 3**: Wait approximately 30 seconds for vLLM to initialize, then start training on GPU 1:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train ebft_quickstart.yaml
```

::: {.callout-important}
The effective batch size (`micro_batch_size * gradient_accumulation_steps`) must be divisible by `num_generations`. For example, with `num_generations: 4`, valid effective batch sizes are 4, 8, 12, etc. (the quickstart's `2 * 4 = 8` qualifies).
:::

### Dataset Format

Structured mode datasets must produce two fields after the transform:

- `prompt`: Either a string or a list of chat messages (`[{"role": "user", "content": "..."}]`)
- `ground_truth`: A string containing the reference completion

Example raw dataset row:

```json
{
  "input": "Write a function to compute fibonacci numbers.",
  "output": "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)"
}
```

The `ebft_opencode.transform` converts this to the required `{prompt, ground_truth}` format automatically.

## Feature Extraction

EBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.

### Feature Layers

The `feature_layers` parameter specifies which layers to extract, as fractions of total model depth:

```yaml
ebft:
  feature_layers: [0.25, 0.5, 0.75]  # Quarter, middle, three-quarter depth
```

For a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size `len(feature_layers) * hidden_dim`.

::: {.callout-tip}
Using multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default `[0.25, 0.5, 0.75]` works well across model sizes.
:::
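
As a concrete illustration, here is a minimal sketch of how fractional depths could resolve to layer indices. The helper name and the rounding rule (simple truncation) are assumptions for illustration, not axolotl's guaranteed behavior:

```python
# Hypothetical helper -- the exact rounding axolotl uses is an assumption.
def resolve_feature_layers(fractions: list[float], num_hidden_layers: int) -> list[int]:
    """Map fractional depths (e.g. 0.25) to concrete layer indices."""
    return [int(f * num_hidden_layers) for f in fractions]

print(resolve_feature_layers([0.25, 0.5, 0.75], 32))  # [8, 16, 24]
```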

### Embed Methods

The `embed_method` controls how per-token hidden states are pooled into a single vector per sequence:

| Method | Description | Output Shape | Notes |
|--------|-------------|-------------|-------|
| `last_token` | Hidden state at the last non-padding token | `(B, D)` | Default. Good for autoregressive models where the last token summarizes the sequence. |
| `mean_pooling` | Mean of all non-padding token states | `(B, D)` | Considers the entire sequence equally. |
| `completion_mean` | Mean over completion tokens only (excludes prompt) | `(B, D)` | Focuses reward signal on generated content. Requires prompt length information. |
| `concat` | Concatenation of states at 25%, 50%, 75% positions | `(B, 3*D)` | Captures positional structure. Higher dimensional. |

```yaml
ebft:
  embed_method: completion_mean  # Focus on completion features
```
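
For intuition, the two simplest pooling methods might look like the following sketch (an illustration, not axolotl's actual code):

```python
import torch

def pool_hidden_states(hidden: torch.Tensor, mask: torch.Tensor, method: str) -> torch.Tensor:
    """hidden: (B, T, D) activations from one layer; mask: (B, T), 1 for real tokens."""
    if method == "last_token":
        last = mask.sum(dim=1) - 1                          # index of last non-pad token
        return hidden[torch.arange(hidden.size(0)), last]   # (B, D)
    if method == "mean_pooling":
        m = mask.unsqueeze(-1).to(hidden.dtype)             # (B, T, 1)
        return (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)
    raise ValueError(f"unknown embed_method: {method}")
```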

### SVD Whitening

Whitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.

```yaml
ebft:
  use_whitening: true
```

When whitening is enabled, the reward computation applies a whitening matrix `W = U @ diag(1/S) @ U^T` derived from the SVD of the generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.

::: {.callout-note}
Singular values scale with `sqrt(batch_size)`, so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (`n_samples_per_prompt` or `num_generations`) is fixed during training.
:::
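
A minimal sketch of what that transform could look like is shown below. Whether the matrix is built from the left or right singular vectors of the embedding matrix is an implementation detail of the trainer; this sketch whitens the feature dimensions using the right singular vectors:

```python
import torch

def svd_whitening(gen: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
    """gen: (N, D) generated embeddings; gt: (M, D) ground-truth embeddings.

    Sketch only: builds a whitening matrix from the SVD of the generated
    embeddings and applies the same transform to both sets.
    """
    _, s, vh = torch.linalg.svd(gen, full_matrices=False)
    w = vh.transpose(-2, -1) @ torch.diag(1.0 / (s + eps)) @ vh  # (D, D)
    return gen @ w, gt @ w
```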

### Alignment and Diversity Coefficients

The two reward components are weighted by coefficients:

```yaml
ebft:
  alignment_coef: 1.0  # Weight for cosine similarity with ground truth
  diversity_coef: 1.0  # Weight for pairwise similarity penalty
```

Both values are scaled by 2 internally (per paper equation 7). The final reward per sample is:

```
reward_j = 2 * alignment_coef * cos(gen_j, gt)
         - 2 * diversity_coef * (1/(n-1)) * sum_{j' != j} dot(gen_j, gen_j')
```

Setting `diversity_coef: 0.0` disables the diversity penalty entirely, which may be appropriate when `num_generations` is small (e.g., 2).
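
In code, the per-sample reward could be computed roughly as follows (a sketch that assumes the embeddings have already been pooled, and whitened if enabled):

```python
import torch
import torch.nn.functional as F

def ebft_rewards(gen: torch.Tensor, gt: torch.Tensor,
                 alignment_coef: float = 1.0, diversity_coef: float = 1.0) -> torch.Tensor:
    """gen: (n, D) embeddings of n generations for one prompt; gt: (D,) ground truth."""
    n = gen.size(0)
    align = F.cosine_similarity(gen, gt.unsqueeze(0), dim=-1)  # (n,)
    dots = gen @ gen.T                                          # pairwise dot products
    off_diag = dots.sum(dim=1) - dots.diagonal()                # sum over j' != j
    diversity = off_diag / max(n - 1, 1)
    return 2 * alignment_coef * align - 2 * diversity_coef * diversity
```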

## Strided Mode

Strided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places **anchor points** at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.

### How Block-Parallel Generation Works

Given a document of length `S` tokens:

1. **Anchor placement**: Starting at position `anchor_offset`, place anchors every `stride` tokens. Each anchor defines a block.
2. **Context window**: Each block sees `context_length` tokens of preceding context from the original document.
3. **Generation**: At each anchor, generate `generate_max_len` tokens autoregressively, conditioned only on the context window.
4. **Parallelism**: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.

```
Document: [tok0, tok1, ..., tok_S]
              |          |          |
          anchor_0   anchor_1   anchor_2
              |          |          |
          [ctx][gen] [ctx][gen] [ctx][gen]
```

The attention mask ensures:

- Prompt tokens use standard causal attention
- Each generated block attends to its own context window and its own preceding generated tokens
- Blocks do not attend to each other's generated tokens

When `flex_attention` is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. Otherwise, a dense 4D attention mask is used as a fallback.
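
A sketch of the anchor-planning step (the names mirror the config parameters above; the exact boundary handling is an assumption):

```python
def plan_blocks(seq_len: int, anchor_offset: int, stride: int,
                context_length: int, generate_max_len: int) -> list[dict]:
    """Place anchors every `stride` tokens and derive each block's spans."""
    blocks = []
    pos = anchor_offset
    while pos + generate_max_len <= seq_len:
        blocks.append({
            "anchor": pos,
            "context": (max(0, pos - context_length), pos),  # preceding document tokens
            "rollout": (pos, pos + generate_max_len),         # tokens generated at this anchor
        })
        pos += stride
    return blocks

# e.g. plan_blocks(seq_len=64, anchor_offset=16, stride=8, context_length=8, generate_max_len=8)
```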

### Strided Mode Configuration

```yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided
  stride: 8                  # Tokens between anchor points
  context_length: 8          # Context window per block
  generate_max_len: 8        # Tokens to generate per block
  n_samples_per_prompt: 4    # Independent rollouts per document
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0               # RL policy gradient loss weight
  ce_coef: 0.03              # Cross-entropy loss on GT tokens
  advantage_estimator: rloo  # rloo, group_norm, or reinforce
  min_completion_prefix: 8   # Skip anchors in prompt region

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true

bf16: auto
flex_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true  # Required with flex_attention
```

Run with a single command (no vLLM needed):

```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

### Advantage Estimators

Strided mode supports three advantage estimation methods:

| Estimator | Formula | Requirements |
|-----------|---------|-------------|
| `rloo` | Leave-one-out baseline: `reward_j - mean(rewards_{-j})` | `n_samples_per_prompt >= 2` |
| `group_norm` | Group normalization: `(reward_j - mean) / std` | `n_samples_per_prompt >= 2` |
| `reinforce` | Raw reward as advantage (no baseline) | Works with `n_samples_per_prompt = 1` |

::: {.callout-warning}
When `n_samples_per_prompt: 1`, the trainer automatically falls back to `reinforce` and disables the diversity penalty (which requires multiple samples).
:::
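
The three estimators reduce to a few lines each; this sketch illustrates the arithmetic from the table above (not the trainer's actual implementation):

```python
import torch

def compute_advantages(rewards: torch.Tensor, estimator: str = "rloo") -> torch.Tensor:
    """rewards: (n,) rewards for the n rollouts of one document."""
    n = rewards.numel()
    if estimator == "rloo" and n >= 2:
        baseline = (rewards.sum() - rewards) / (n - 1)  # mean of the *other* samples
        return rewards - baseline
    if estimator == "group_norm" and n >= 2:
        return (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    return rewards  # reinforce: raw reward, no baseline (also the n == 1 fallback)
```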

### Strided Mode Constraints

- **`flex_attention: true`** is strongly recommended. Without it, dense 4D masks consume significantly more memory.
- **`torch_compile: true`** must NOT be set. `flex_attention` compiles its own kernels internally; adding `torch_compile` causes conflicts and OOM.
- **Gradient checkpointing** must use `use_reentrant: true`. Non-reentrant checkpointing causes `CheckpointError` with `flex_attention` block masks.
- **`activation_offloading`** is incompatible with `flex_attention`.

### Cross-Entropy Loss

Strided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:

```yaml
ebft:
  ce_coef: 0.03  # Small CE coefficient
  rl_coef: 1.0   # RL loss coefficient
```

The total loss is `rl_coef * rl_loss + ce_coef * ce_loss`. For structured mode, `ce_coef` is typically `0.0`, since vLLM generation provides sufficient learning signal.

## Dataset Formats

EBFT provides several built-in dataset transforms in `src/axolotl/prompt_strategies/ebft/`.

### Built-In Transforms

| Transform | Input Format | Output Fields | Use Case |
|-----------|-------------|---------------|----------|
| `ebft_opencode.transform` | `{input, output}` | `{prompt, ground_truth}` | OpenCodeInstruct, structured QA |
| `ebft_strided_structured.transform` | `{input, output}` | `{input_ids, labels, prompt_length}` | Strided mode with structured data |
| `ebft_strided_chat.transform` | `{messages: [...]}` | `{input_ids, labels, prompt_length}` | Strided mode with chat data |
| `ebft_chat_multiturn.transform` | `{messages: [...]}` | `{prompt, ground_truth, remaining_turns}` | Multi-turn: first-turn target |
| `ebft_chat_multiturn.transform_last_turn` | `{messages: [...]}` | `{prompt, ground_truth}` | Multi-turn: last-turn target |
| `ebft_chat_multiturn.transform_all_turns` | `{messages: [...]}` | `{prompt[], ground_truth[]}` | Multi-turn: one example per turn |
| `ebft_reasoning.transform` | `{messages: [...]}` (with `<think>`) | `{prompt, ground_truth}` | Reasoning/thinking datasets |

### Structured Mode Datasets

For structured (sync/async) mode, the transform must produce `prompt` and `ground_truth` fields:

```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]
```

### Multi-Turn Datasets

Multi-turn transforms extract conversation data for sequential rollout. The `transform` variant targets the first assistant turn, while `transform_last_turn` targets the final turn:

```yaml
datasets:
  - path: your/multiturn-dataset
    type: ebft_chat_multiturn.transform
```

When `remaining_turns` is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.

### Strided Mode Datasets

Strided transforms tokenize the full document and produce `input_ids`, `labels`, and `prompt_length`:

```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]
```

### Custom Transforms

To use your own dataset format, write a transform function:

```python
def transform(cfg, **kwargs):
    def transform_fn(example, tokenizer=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "ground_truth": example["answer"],
        }
    return transform_fn, {"remove_columns": "__all__"}
```

The `"__all__"` sentinel removes all original dataset columns after the mapping step. Reference this transform in your config:

```yaml
datasets:
  - path: your/dataset
    type: your_module.transform
```

## Configuration Reference

### Common Parameters (All Modes)

These parameters are set under the `ebft:` key in the YAML config.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `mode` | `"structured"` or `"strided"` | `"structured"` | EBFT operating mode |
| `feature_layers` | `list[float]` | `[0.25, 0.5, 0.75]` | Fractional layer depths for feature extraction |
| `embed_method` | `string` | `"last_token"` | Pooling method: `last_token`, `mean_pooling`, `completion_mean`, or `concat` |
| `use_whitening` | `bool` | `false` | Apply SVD whitening to feature embeddings before reward computation |
| `alignment_coef` | `float` | `1.0` | Weight for alignment reward (cosine similarity with ground truth) |
| `diversity_coef` | `float` | `1.0` | Weight for diversity penalty (pairwise dot product between samples) |
| `ce_coef` | `float` | `0.0` | Cross-entropy loss coefficient on ground-truth tokens |
| `adaptive_max_tokens` | `bool` | `true` | Dynamically set vLLM `max_tokens` based on ground-truth length (structured mode) |
| `gt_length_multiplier` | `float` | `1.5` | Multiplier for ground-truth token count when computing adaptive max tokens (min 0.1) |

### Strided Mode Parameters

These additional parameters apply only when `mode: strided`.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `stride` | `int` | `8` | Number of tokens between anchor points (must be >= 1) |
| `context_length` | `int` | `8` | Context window size for each generated block (must be >= 1) |
| `generate_max_len` | `int` | `8` | Number of tokens to generate per block (must be >= 1) |
| `n_samples_per_prompt` | `int` | `4` | Number of independent rollouts per document (must be >= 1) |
| `temperature` | `float` | `0.6` | Sampling temperature for strided generation |
| `top_p` | `float` | `1.0` | Top-p nucleus sampling threshold |
| `rl_coef` | `float` | `1.0` | RL policy gradient loss coefficient |
| `advantage_estimator` | `string` | `"rloo"` | Advantage estimation method: `rloo`, `group_norm`, or `reinforce` |
| `min_completion_prefix` | `int` | `0` | Minimum tokens into the completion span before placing anchors |

### Structured Mode TRL Parameters

These are set under the `trl:` key and control the GRPO training loop.

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `num_generations` | `int` | -- | Number of completions generated per prompt |
| `max_completion_length` | `int` | -- | Maximum tokens per generated completion |
| `temperature` | `float` | `0.7` | Sampling temperature for vLLM generation |
| `use_vllm` | `bool` | -- | Enable vLLM generation backend |
| `vllm_lora_sync` | `bool` | `false` | Sync LoRA adapters via filesystem (recommended) |
| `vllm_sync_interval` | `int` | `1` | Steps between weight syncs to vLLM |
| `use_data_producer` | `bool` | -- | Required for sync mode with LoRA sync |
| `async_prefetch` | `bool` | `false` | Enable async generation (overlaps with training) |
| `streaming_partial_batch` | `bool` | `false` | Score groups incrementally (async mode) |
| `skip_zero_advantage_batches` | `bool` | `false` | Skip micro-batches where all advantages are zero |
| `scale_rewards` | `bool` | -- | Normalize rewards within each prompt group |
| `loss_type` | `string` | `"grpo"` | Loss type for policy optimization |
| `epsilon` | `float` | `0.2` | Clipping parameter for importance sampling |

### Stop Tokens

vLLM needs explicit stop token IDs for generation. Common configurations:

```yaml
trl:
  generation_kwargs:
    stop_token_ids: [151645, 151643]  # Qwen: <|im_end|>, <|endoftext|>
```

### Multi-Turn Chat Settings

For multi-turn conversations with Qwen3.5, disable thinking mode to prevent `<think>` tags in completions:

```yaml
trl:
  chat_template_kwargs:
    enable_thinking: false
```

## Monitoring

### Key Metrics

EBFT logs several custom metrics to wandb and the training console. Here is what to watch for:

| Metric | Healthy Range | Interpretation |
|--------|--------------|----------------|
| `ebft/alignment` | 0.3 -- 0.9, trending upward | Cosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference. |
| `ebft/diversity` | 0.01 -- 0.1 | Mean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse. |
| `ebft/cfm_loss` | Below 10, trending downward | Cross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability. |
| `ebft/reward` | Trending upward (may start negative) | Combined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment. |
| `grad_norm` | 0.1 -- 3.0 | Gradient magnitude. Values of 0.0 indicate zero-advantage skip (normal). Values above 10 suggest instability. |
| `entropy` | 0.05 -- 0.5 | Policy entropy. Values below 0.01 suggest mode collapse. |
| `IS ratio min` | Above 0.1 | Importance sampling ratio minimum. Near-zero values mean the policy is too far off-policy; increase `vllm_sync_interval`. |

### Console Log Example

During training, you will see periodic EBFT reward logs:

```
ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^
```

The arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.

### Troubleshooting

| Symptom | Likely Cause | Fix |
|---------|-------------|-----|
| `alignment` stays below 0.1 | Feature layers not capturing useful information | Try different `feature_layers` or `embed_method` |
| `diversity` exceeds 1.0 | Mode collapse -- generations are too similar | Increase `diversity_coef` or `temperature` |
| `reward` stuck at -1.0 | Diversity penalty dominates alignment | Reduce `diversity_coef` or increase `alignment_coef` |
| `grad_norm` consistently 0.0 | All micro-batches have zero advantage | Increase `num_generations` or check data quality |
| `CheckpointError` in strided mode | Incompatible gradient checkpointing settings | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| OOM during training | Logits tensor too large | Reduce `sequence_len` or `micro_batch_size`; strided mode uses a chunked lm_head to mitigate this |
| vLLM 500 errors | `truncate_prompt_tokens` not supported | Ensure you are using `axolotl vllm-serve` (not `trl vllm-serve`) |

### Feature Network Memory

In PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the `disable_adapter()` context manager. This saves an entire model copy in VRAM (approximately 1--16 GB depending on model size). For non-PEFT training, a separate frozen deepcopy is created.

::: {.callout-note}
The `disable_adapter()` approach relies on an invariant: `merge_adapter()` is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.
:::

## Examples

Complete example configurations are available in `examples/ebft/`:

| Config | Model | Mode | Description |
|--------|-------|------|-------------|
| `llama-1b-ebft-strided-structured.yaml` | Llama 3.2 1B | Strided | Single-GPU strided training on code data |
| `qwen3-4b-ebft-structured.yaml` | Qwen3 4B | Structured (sync) | Two-GPU structured training |
| `qwen3-4b-ebft-structured-async.yaml` | Qwen3 4B | Structured (async) | Two-GPU async training with prefetch |
| `qwen3-8b-ebft-structured.yaml` | Qwen3 8B | Structured (sync) | Two-GPU structured training for a larger model |
| `qwen35-4b-ebft-structured.yaml` | Qwen3.5 4B | Structured (sync) | Two-GPU with Qwen3.5 |
| `qwen35-4b-ebft-structured-async.yaml` | Qwen3.5 4B | Structured (async) | Two-GPU async with Qwen3.5 |
| `qwen35-9b-ebft-structured.yaml` | Qwen3.5 9B | Structured (sync) | Two-GPU structured for a 9B model |

@@ -170,26 +170,17 @@ More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).

## Next Steps {#sec-next-steps}

Now that you have the basics, explore these guides based on what you want to do:
Now that you have the basics, you might want to:

**Choose your path:**
- Try different model architectures
- Experiment with hyperparameters
- Use more advanced training methods
- Scale up to larger models

- [Choosing a Fine-Tuning Method](choosing_method.qmd) — SFT vs LoRA vs QLoRA vs GRPO vs DPO, with hardware recommendations
Check our other guides for details on these topics:

**Core guides:**

- [Dataset Loading](dataset_loading.qmd) — Loading datasets from various sources
- [Dataset Formats](dataset-formats) — Working with different data formats
- [Optimizations](optimizations.qmd) — Flash attention, gradient checkpointing, sample packing
- [Training Stability & Debugging](training_stability.qmd) — Monitoring metrics, fixing NaN, OOM debugging

**Advanced training methods:**

- [RLHF / Preference Learning](rlhf.qmd) — DPO, KTO, GRPO, EBFT
- [GRPO Training](grpo.qmd) — RL with custom rewards and vLLM generation
- [vLLM Serving](vllm_serving.qmd) — Setting up vLLM for GRPO

**Scaling up:**

- [Multi-GPU Training](multi-gpu.qmd) — DeepSpeed, FSDP, DDP
- [Multi-Node Training](multi-node.qmd) — Distributed training across machines
- [Configuration Guide](config-reference.qmd) - Full configuration options
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

611
docs/grpo.qmd
@@ -1,611 +0,0 @@

---
title: "GRPO Training"
description: "Group Relative Policy Optimization — a reinforcement learning method for training language models with verifiable reward functions."
order: 8
---

## Overview

Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).

Use GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.

Axolotl's GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.


## Architecture

GRPO training uses a two-process architecture: a vLLM server for fast generation and a trainer process for scoring and gradient updates.

```
Terminal 1 (GPU 0)                   Terminal 2 (GPU 1)
┌──────────────────────┐             ┌──────────────────────────────────┐
│ vLLM Server          │             │ Trainer                          │
│                      │    HTTP     │                                  │
│ Serves base model    │◄───────────►│ Background thread:               │
│ + LoRA adapter       │  /generate  │   Send prompts to vLLM           │
│                      │  /set_lora  │   Pad & collate completions      │
│ Punica kernels for   │             │                                  │
│ LoRA inference       │             │ Main thread:                     │
│                      │             │   Score completions (rewards)    │
└──────────────────────┘             │   Compute policy log-probs       │
                                     │   Calculate advantages           │
                                     │   PPO-clip gradient update       │
                                     │   Sync LoRA weights to vLLM      │
                                     └──────────────────────────────────┘
```

**Data flow for each training step:**

1. The background thread sends prompts to vLLM, which generates `num_generations` completions per prompt.
2. The main thread scores completions using your reward functions.
3. Advantages are computed within each prompt group (group-relative normalization).
4. Policy log-probabilities are computed by running a forward pass on the training model.
5. The PPO-clip loss is computed and gradients are applied.
6. Periodically, LoRA adapter weights are synced back to vLLM so future generations reflect the updated policy.

With async prefetch enabled, step 1 for the *next* batch runs concurrently with steps 2-6 for the *current* batch.
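
Step 3, the group-relative part, is the heart of GRPO. A sketch of the normalization (with `scale_rewards: true`; an illustration, not the trainer's exact code):

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
    """rewards: (B,) flat rewards; every consecutive num_generations entries
    belong to one prompt group."""
    grouped = rewards.view(-1, num_generations)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    # A zero-variance group yields all-zero advantages -- no learning signal
    return ((grouped - mean) / (std + 1e-4)).view(-1)
```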


## Quick Start

A GRPO training run requires three components: a YAML config, a reward module (Python file), and a running vLLM server.

### 1. Write a reward module

Create a file called `rewards.py` in your working directory:

```python
# rewards.py
import re


def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    """Check if the completion contains the correct numerical answer."""
    rewards = []
    for completion, correct in zip(completions, answer):
        text = completion[0]["content"]
        # Extract the last number from the completion
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
        predicted = numbers[-1] if numbers else ""
        rewards.append(1.0 if predicted == str(correct) else 0.0)
    return rewards


def format_reward(completions, **kwargs) -> list[float]:
    """Reward completions that use a structured thinking format."""
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        has_think = "<think>" in text and "</think>" in text
        has_answer = "<answer>" in text and "</answer>" in text
        rewards.append(1.0 if has_think and has_answer else 0.0)
    return rewards


def prompt_transform(cfg, *args, **kwargs):
    """Convert GSM8K dataset rows into chat prompts."""
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [
                {"role": "system", "content": "Solve the math problem. Show your reasoning in <think> tags and your final numerical answer in <answer> tags."},
                {"role": "user", "content": example["question"]},
            ],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}
```

### 2. Write the config

Create `config.yaml`:

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

rl: grpo
chat_template: tokenizer_default

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85
  dtype: auto
  max_model_len: 2048

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  use_vllm: true
  use_data_producer: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
  vllm_lora_sync: true
  num_generations: 8
  max_completion_length: 512
  temperature: 0.7
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.5

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.prompt_transform
    split: train

skip_prepare_dataset: true
val_set_size: 0.0
sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 200
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10

bf16: true
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

output_dir: ./grpo-output
logging_steps: 1
```

### 3. Start vLLM and train

```bash
# Terminal 1: Start vLLM server on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Wait 30-90 seconds for model loading and CUDA graph capture

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

:::{.callout-tip}
Use `tmux` or separate terminal sessions to manage the two processes. The vLLM server must remain running for the entire training duration.
:::


## Custom Reward Functions

### Function signature

TRL calls reward functions with this signature:

```python
def my_reward(completions, **kwargs) -> list[float]:
```

- `completions` is a list of single-element lists, where each element is a dict `{"role": "assistant", "content": "..."}`. So `completions[i][0]["content"]` gives you the text of the i-th completion.
- `**kwargs` contains all dataset columns that were *not* removed by the dataset transform. This is how you pass ground-truth answers, metadata, or any other information to your reward function.
- Return a `list[float]` with the same length as `completions`. You may return `None` for individual elements to exclude them from aggregation.

### Example: accuracy reward with answer extraction

```python
import re


def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    rewards = []
    for completion, correct_answer in zip(completions, answer):
        text = completion[0]["content"]
        # Extract answer from <answer>...</answer> tags
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        predicted = match.group(1).strip() if match else ""
        rewards.append(1.0 if predicted == str(correct_answer) else 0.0)
    return rewards
```

### Example: length penalty

```python
def length_penalty(completions, **kwargs) -> list[float]:
    """Penalize very short or very long completions."""
    rewards = []
    for completion in completions:
        length = len(completion[0]["content"])
        if length < 50:
            rewards.append(-0.5)
        elif length > 2000:
            rewards.append(-0.2)
        else:
            rewards.append(0.0)
    return rewards
```

### Multiple rewards and weighting

You can combine multiple reward functions with different weights:

```yaml
trl:
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
    - rewards.length_penalty
  reward_weights:
    - 1.0  # accuracy is most important
    - 0.5  # format compliance
    - 0.1  # mild length preference
```

Rewards are combined by the `multi_objective_aggregation` strategy:

- `sum_then_normalize` (default): weights and sums all rewards first, then normalizes across the group.
- `normalize_then_sum` (GDPO): normalizes each reward independently, then sums. This prevents one reward from dominating and is recommended when using multiple reward functions with different scales.

```yaml
trl:
  multi_objective_aggregation: normalize_then_sum
```
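
The difference between the two strategies is easiest to see in code. A sketch (group-relative normalization stands in for whatever normalization the trainer actually applies):

```python
import torch

def aggregate(rewards: torch.Tensor, weights: torch.Tensor, strategy: str) -> torch.Tensor:
    """rewards: (num_funcs, B) per-function rewards for one prompt group."""
    def norm(x):
        return (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-4)
    if strategy == "sum_then_normalize":   # GRPO default
        return norm((weights.unsqueeze(-1) * rewards).sum(dim=0))
    if strategy == "normalize_then_sum":   # GDPO
        return (weights.unsqueeze(-1) * norm(rewards)).sum(dim=0)
    raise ValueError(strategy)
```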

### Dataset transforms

The dataset transform converts raw HuggingFace dataset rows into chat-format prompts:

```python
def prompt_transform(cfg, *args, **kwargs):
    def map_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": example["question"]},
            ],
            # Keep 'answer' column for the reward function
            "answer": example["answer"],
        }
    # Remove columns consumed by the transform; keep columns needed by rewards
    return map_fn, {"remove_columns": ["question"]}
```

The transform returns a tuple of `(map_function, kwargs_dict)`. The `remove_columns` entry in the kwargs dict removes columns that are no longer needed. Columns that your reward functions reference via `**kwargs` (like `answer`) must *not* be removed.

:::{.callout-warning}
The reward module must be importable from the directory where you run `axolotl train`. If your reward file is `rewards.py`, the import path is `rewards.accuracy_reward`. If it is inside a package `my_rewards/scoring.py`, use `my_rewards.scoring.accuracy_reward`.
:::

### Reward models (neural network rewards)

Instead of a Python function, you can pass a HuggingFace model path as a reward function. TRL will load it as a reward model and use its scalar output as the reward:

```yaml
trl:
  reward_funcs:
    - OpenAssistant/reward-model-deberta-v3-large-v2
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.3
```

### Using math_verify

The `math_verify` library provides robust mathematical answer verification but uses `signal.alarm()` internally, which only works in the main thread. If you use `math_verify` in a reward function, set `reward_num_workers` to use subprocess workers:

```yaml
trl:
  reward_num_workers: 4
```

Each worker runs in its own subprocess with its own main thread, so `signal.alarm()` works correctly.


## vLLM Setup

GRPO requires a running vLLM server for generation. For a complete guide on server modes, LoRA sync, weight synchronization, and restart procedures, see [vLLM Serving](vllm_serving.qmd).

The minimal setup:

```yaml
vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85

trl:
  use_vllm: true
  vllm_lora_sync: true     # Recommended with LoRA — faster sync, no NCCL contention
  vllm_sync_interval: 5    # Sync weights every 5 steps
```

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml  # GPU 0: vLLM
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml       # GPU 1: training
```

:::{.callout-warning}
vLLM must be restarted between experiments — stale weight syncs corrupt server state. See [Restart Requirements](vllm_serving.qmd#sec-restart).
:::


## Async Training Features

Async GRPO overlaps generation and training to reduce wall-clock time. While the model trains on the current batch, the next batch is already being generated by vLLM.

### Enabling async prefetch

```yaml
trl:
  use_data_producer: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
```

- `use_data_producer: true` enables the data producer protocol (required for all async features).
- `async_prefetch: true` runs generation in a background thread.
- `prefetch_depth` controls how many batches to prefetch ahead (1 is usually sufficient).
- `vllm_sync_interval` controls how often LoRA weights are synced to vLLM (every N optimizer steps). Lower values mean fresher generations but more sync overhead.

:::{.callout-tip}
Because the background thread generates with slightly stale model weights, async mode benefits from importance sampling correction (see next section). Enable `vllm_importance_sampling_correction: true` when using `async_prefetch: true`.
:::

### Streaming partial batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This reduces peak memory during scoring and enables finer-grained zero-advantage skipping.

```yaml
trl:
  streaming_partial_batch: true
  streaming_min_groups: 1
```

`streaming_min_groups` controls the minimum number of prompt groups scored per chunk. Setting it to 1 gives maximum granularity.

### Zero-advantage batch skipping

When all advantages in a micro-batch are zero (every completion in the group got the same reward), there is no learning signal. This feature skips the forward/backward pass entirely for such micro-batches.

```yaml
trl:
  skip_zero_advantage_batches: true  # default
```

This is enabled by default and logged as `skipped_zero_adv_batches` in training metrics. It is a safety net, not a major optimization -- it only saves significant time when the model cannot solve any prompts in the batch.
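
Conceptually, the check is a one-liner; a sketch of the skip condition:

```python
import torch

def should_skip(advantages: torch.Tensor, enabled: bool = True) -> bool:
    """True when a micro-batch carries no gradient signal and can be skipped."""
    return enabled and bool(torch.all(advantages == 0))
```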

### Replay buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and replaces zero-signal groups in later batches. This improves data utilization when many prompts yield no reward variance.

```yaml
trl:
  replay_buffer_size: 100
  replay_recompute_logps: true
```

:::{.callout-warning}
When `replay_recompute_logps: false`, replayed data uses stale log-probabilities, which creates an IS mismatch. Keep the default `true` unless you have a specific reason to disable it.
:::

### Deferred re-rolling

Prompts where the model gets zero reward for all generations are buffered and re-injected into later batches, when the model may have improved enough to produce useful completions.

```yaml
trl:
  reroll_start_fraction: 0.5  # Start re-rolling after 50% of training
  reroll_max_groups: 1        # Max groups to replace per batch
```

Set `reroll_start_fraction: 1.0` to disable. This is most useful for tasks where the model starts weak but steadily improves.

### Parallel reward workers

Reward functions that use `signal.alarm()` (like `math_verify`) only work in the main thread. Parallel reward workers run each function in its own subprocess:

```yaml
trl:
  reward_num_workers: 4
```

Work is sharded across workers by prompt group. For simple reward functions, a single worker is usually sufficient -- the overhead of IPC can exceed the computation time.


## Importance Sampling and Off-Policy Correction

When using async prefetch, completions are generated from a slightly older policy. IS correction adjusts the gradient to account for this mismatch.

```yaml
trl:
  vllm_importance_sampling_correction: true
  importance_sampling_level: token   # 'token' recommended (especially with Liger kernel)
  off_policy_mask_threshold: 0.5     # KL threshold — masks sequences that are too off-policy
```

Use `token`-level IS. Sequence-level IS has numerical issues with Liger's chunked computation. The `off_policy_mask_threshold` (OPSM) is a safety net that drops sequences where KL divergence exceeds the threshold — 0.5 is a reasonable starting point.
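
A sketch of what token-level correction amounts to. The function is hypothetical; only the option names (`vllm_importance_sampling_cap`, `token_truncate`, `token_mask`) come from the config reference below:

```python
import torch

def token_is_weights(policy_logps: torch.Tensor, vllm_logps: torch.Tensor,
                     cap: float = 2.0, mode: str = "token_truncate") -> torch.Tensor:
    """Per-token importance ratios between the training policy and the
    slightly stale vLLM sampling policy. Inputs are (B, T) log-probs."""
    ratio = torch.exp(policy_logps - vllm_logps)
    if mode == "token_truncate":
        return ratio.clamp(max=cap)       # truncate large ratios at the cap
    if mode == "token_mask":
        return ratio * (ratio <= cap)     # zero out tokens beyond the cap
    raise ValueError(mode)
```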

For detailed coverage of IS modes (`token_mask`, `token_truncate`, etc.), capping, and bias-corrected KL, see [vLLM Serving — IS Correction](vllm_serving.qmd#sec-weight-sync).


## Scaling

### FP8 training

FP8 quantization halves model VRAM usage with minimal impact on training quality. It does not significantly speed up computation for small models, but it allows larger models to fit in memory.

```yaml
fp8: true
torch_compile: true
```

:::{.callout-warning}
FP8 requires patching for zero-padding edge cases. The `act_quant_kernel` can produce NaN when its input is all zeros (padding positions). If you see NaN in grad norms, check whether your padding token embedding is non-zero.
:::

### FSDP (Fully Sharded Data Parallel)

FSDP distributes model parameters across multiple GPUs for training while vLLM runs on a separate GPU:

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

Launch with:

```bash
# GPU 0: vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# GPUs 0,1: Training (FSDP will use both visible GPUs)
CUDA_VISIBLE_DEVICES=0,1 axolotl train config.yaml
```

:::{.callout-warning}
`async_prefetch: true` can deadlock with FSDP because background threads perform unsynchronized FSDP collectives across ranks. With multi-GPU FSDP, only rank 0 generates in the background thread and results are broadcast to all ranks. If you still see hangs, set `async_prefetch: false`.
:::

### DeepSpeed ZeRO-3

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true  # Required -- non-reentrant causes CheckpointError with ZeRO-3
```

:::{.callout-note}
DeepSpeed ZeRO-3 requires `use_reentrant: true` for gradient checkpointing. This is the opposite of the FSDP recommendation. Non-reentrant checkpointing causes tensor metadata mismatches during recomputation with ZeRO-3's parameter partitioning.
:::

### Multi-GPU considerations

| Concern | Recommendation |
|---------|---------------|
| vLLM GPU allocation | Dedicate one or more GPUs to vLLM; do not share with trainer GPUs |
| Weight sync contention | Use `vllm_lora_sync: true` to avoid NCCL contention between training and vLLM |
| FSDP + async | Use `async_prefetch: false` or rely on rank-0-only background generation |
| DeepSpeed + gradient checkpoint | Must use `use_reentrant: true` |
| OOM during scoring | Reduce `micro_batch_size` or `num_generations`. The logits tensor scales with `batch_size * vocab_size` |


## Monitoring and Debugging

For detailed metric ranges, failure diagnosis, and OOM debugging, see [Training Stability & Debugging](training_stability.qmd).

Quick health checks during GRPO training:

- `rewards/*/mean` should be > 0.15 within 20 steps — if it stays at 0, test your reward function standalone
- `reward_std` should be > 0 on most steps — all-zero means no learning signal
- `entropy` in 0.05-0.5 — below 0.01 suggests mode collapse
- `grad_norm` in 0.001-1.0 — > 10 is unstable, 0.0 is expected when zero-advantage skip fires

:::{.callout-tip}
Pipe training output to a log file: `axolotl train config.yaml 2>&1 | tee /tmp/training.log`
:::


## Configuration Reference

All GRPO-specific options live under the `trl:` key in your config. Standard training options (`learning_rate`, `micro_batch_size`, etc.) are set at the top level as usual.

### Core GRPO

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | `"server"` or `"colocate"` | `null` | vLLM deployment mode |
| `vllm_server_host` | str | `"0.0.0.0"` | vLLM server hostname |
| `vllm_server_port` | int | `8000` | vLLM server port |
| `vllm_server_timeout` | int | `null` | Timeout (seconds) for vLLM responses |
| `num_generations` | int | `null` | Completions generated per prompt |
| `generation_batch_size` | int | `null` | Number of unique prompts per generation step |
| `max_completion_length` | int | `null` | Maximum tokens per completion |
| `beta` | float | `null` | KL penalty coefficient |
| `num_iterations` | int | `null` | Iterations per batch (μ in the GRPO paper) |
| `epsilon` | float | `null` | PPO clipping lower bound |
| `epsilon_high` | float | `null` | PPO clipping upper bound |
| `loss_type` | str | `null` | Loss formulation: `grpo`, `bnpo`, or `dr_grpo` |
| `scale_rewards` | bool | `true` | Normalize rewards by standard deviation |
| `mask_truncated_completions` | bool | `false` | Exclude truncated completions from loss |

### Reward functions

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `reward_funcs` | list[str] | `null` | Import paths to reward functions or HF model IDs |
| `reward_weights` | list[float] | `null` | Relative weights for each reward function |
| `multi_objective_aggregation` | str | `null` | `"sum_then_normalize"` (GRPO) or `"normalize_then_sum"` (GDPO) |
| `rollout_func` | str | `null` | Import path to custom rollout function for OpenEnv-style tasks |

### Generation parameters

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `temperature` | float | `null` | Sampling temperature |
| `top_p` | float | `null` | Nucleus sampling probability |
| `top_k` | int | `null` | Top-k sampling |
| `min_p` | float | `null` | Minimum probability threshold |
| `repetition_penalty` | float | `null` | Penalty for repeated tokens |
| `generation_kwargs` | dict | `null` | Additional vLLM SamplingParams (e.g., `stop_token_ids`) |
| `chat_template_kwargs` | dict | `null` | Chat template kwargs (e.g., `{enable_thinking: false}`) |
| `vllm_guided_decoding_regex` | str | `null` | Regex constraint for guided decoding |

### Async pipeline

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_data_producer` | bool | `false` | Enable data producer protocol (required for async features) |
| `async_prefetch` | bool | `false` | Generate next batch in background thread |
| `prefetch_depth` | int | `null` | Number of batches to prefetch ahead |
| `vllm_sync_interval` | int | `null` | Sync LoRA weights to vLLM every N steps |
| `vllm_lora_sync` | bool | `false` | Use filesystem LoRA sync instead of NCCL merge |
| `streaming_partial_batch` | bool | `null` | Score prompt groups incrementally |
| `streaming_min_groups` | int | `null` | Minimum groups per streaming chunk |
| `skip_zero_advantage_batches` | bool | `true` | Skip micro-batches with zero learning signal |
| `reward_num_workers` | int | `1` | Subprocess workers for reward computation |
| `vllm_enable_sleep_mode` | bool | `null` | Offload vLLM weights when idle (colocate mode) |

### Importance sampling

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `vllm_importance_sampling_correction` | bool | `null` | Enable IS correction for async distribution shift |
| `importance_sampling_level` | `"token"` or `"sequence"` | `null` | Granularity of IS ratios. Use `token` with Liger |
| `vllm_importance_sampling_mode` | str | `null` | `token_mask`, `token_truncate`, `sequence_mask`, or `sequence_truncate` |
| `vllm_importance_sampling_cap` | float | `null` | Cap C for IS ratio clipping/masking |
| `off_policy_mask_threshold` | float | `null` | KL threshold for off-policy sequence masking (OPSM) |
| `use_bias_correction_kl` | bool | `null` | Apply IS correction to the KL divergence term |

### Replay and re-roll

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `replay_buffer_size` | int | `0` | Max cached high-signal groups. 0 = disabled |
| `replay_recompute_logps` | bool | `true` | Recompute log-probs for replayed data with the current model |
| `reroll_start_fraction` | float | `1.0` | Start re-rolling failed prompts after this fraction of training. 1.0 = disabled |
| `reroll_max_groups` | int | `1` | Max prompt groups to replace with re-rolls per batch |

### Reference model

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `sync_ref_model` | bool | `false` | Periodically sync reference model with training model |
| `ref_model_mixup_alpha` | float | `0.9` | EMA coefficient for reference model sync |
| `ref_model_sync_steps` | int | `64` | Sync reference model every N steps |

### Logging

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `log_completions` | bool | `false` | Log sample completions to W&B |
| `num_completions_to_print` | int | `null` | Number of completions to print per step |
| `use_liger_loss` | bool | `null` | Use Liger fused kernel for GRPO loss (reduces VRAM) |

310
docs/rlhf.qmd
@@ -16,12 +16,8 @@ feedback. Various methods include, but not limited to:
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo) — see also the [GRPO deep dive](grpo.qmd) for async features, custom rewards, and scaling
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
- [Energy-Based Fine-Tuning (EBFT)](#ebft) — see also the [EBFT guide](ebft.qmd) for detailed mode comparisons and configuration
- [NeMo Gym Integration](#nemo-gym-integration)

For help choosing between these methods, see [Choosing a Fine-Tuning Method](choosing_method.qmd).


## RLHF using Axolotl
@@ -517,7 +513,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO

::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code). For a comprehensive guide covering async training, custom rewards, importance sampling, and scaling, see the [GRPO deep dive](grpo.qmd).
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
:::

In the latest GRPO implementation, `vLLM` is used to significantly speed up trajectory generation during training. In this example, we're using 4 GPUs - 2 for training and 2 for vLLM:
@@ -925,7 +921,7 @@ gradient_checkpointing_kwargs:
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 axolotl train config.yaml
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```

::: {.callout-important}
@@ -1041,306 +1037,6 @@ simpo_gamma: 0.5 # default in CPOTrainer

This method uses the same dataset format as [DPO](#dpo).

### EBFT {#ebft}

::: {.callout-tip}
For a detailed guide on EBFT modes, feature extraction, and configuration, see the [EBFT guide](ebft.qmd).
:::

EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.

Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)

**Key advantages:**

- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (code, translation, creative writing)
- Operates on model rollouts (not teacher forcing), reducing distribution shift

EBFT supports two modes:

- **Structured mode**: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO).
- **Strided mode**: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed.

#### Structured Mode

```yaml
base_model: Qwen/Qwen3-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]  # Extract features at 25%, 50%, 75% depth
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0  # Cosine similarity reward weight
  diversity_coef: 1.0  # Pairwise dot product penalty
  ce_coef: 0.0         # Cross-entropy on GT tokens (0 = off)

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true  # LoRA adapter sync (recommended)
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: true  # Set false for sync mode
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 2048

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
```

```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

#### Strided Mode

For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU.

```yaml
base_model: meta-llama/Llama-3.2-1B

rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03
  advantage_estimator: rloo

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

flash_attention: false
flex_attention: true  # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true  # Required for flex_attention
```

```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

::: {.callout-tip}
See `examples/ebft/` for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes.
:::
#### EBFT Configuration Reference

| Parameter | Default | Description |
|-----------|---------|-------------|
| `ebft.feature_layers` | `[0.25, 0.5, 0.75]` | Layer depths for feature extraction (fractional) |
| `ebft.embed_method` | `last_token` | Feature pooling: `last_token`, `mean_pooling`, `concat` |
| `ebft.use_whitening` | `false` | SVD whitening of feature dimensions |
| `ebft.alignment_coef` | `1.0` | Cosine similarity reward weight |
| `ebft.diversity_coef` | `1.0` | Pairwise dot product penalty weight |
| `ebft.ce_coef` | `0.0` | Cross-entropy loss on ground-truth tokens |
| `ebft.mode` | `structured` | `structured` (vLLM) or `strided` (no vLLM) |
| `ebft.stride` | — | Tokens between anchor points (strided mode) |
| `ebft.context_length` | — | Context window per block (strided mode) |
| `ebft.generate_max_len` | — | Tokens to generate per block (strided mode) |
| `ebft.n_samples_per_prompt` | — | Rollouts per document (strided mode) |
| `ebft.advantage_estimator` | `grpo` | `grpo` or `rloo` (strided mode) |
### NeMo Gym Integration

[NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both **single-turn** (call `/verify` after generation) and **multi-turn** (agent-based tool execution via `/run`).

#### Single-Turn (Simplest)

For environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls `/verify` directly on the resource server.

```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct

rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: false  # Colocate mode (single GPU)
  num_generations: 4
  max_completion_length: 128
  temperature: 0.9
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
  - path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    server_name: reasoning_gym

datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role
```

```bash
# Terminal 1: Start NeMo Gym resource server
cd ~/Gym && .venv/bin/ng_run \
  "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \
  "+skip_venv_if_present=true"

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

::: {.callout-note}
`nemo_gym_datasets.path` is relative to `nemo_gym_dir`. Don't use absolute paths or they will be double-joined.
:::
#### Multi-Turn with Async GRPO (Recommended)

For environments with tool-use (weather, search, databases). An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done.

```yaml
base_model: Qwen/Qwen3-0.6B

rl: grpo
chat_template: tokenizer_default

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]

trl:
  use_vllm: true
  vllm_mode: server
  vllm_server_host: localhost
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 5
  use_data_producer: true
  async_prefetch: true  # 3x speedup
  num_generations: 4
  max_completion_length: 512
  temperature: 0.8
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_env

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
  - path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    server_name: example_single_tool_call

datasets:
  - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

vllm:
  gpu_memory_utilization: 0.85
  max_model_len: 2048
```

Multi-turn requires three services running:

```bash
# Terminal 1: vLLM with LoRA + tool calling
VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \
  python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-0.6B --max-model-len 2048 \
    --gpu-memory-utilization 0.85 \
    --enable-lora --max-lora-rank 64 \
    --enable-auto-tool-choice --tool-call-parser hermes

# Terminal 2: NeMo Gym servers (resource + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
  "+config_paths=[configs/axolotl_tool_calling.yaml]" \
  "+skip_venv_if_present=true"

# Terminal 3: Training
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

::: {.callout-important}
Multi-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + `/verify`), a model server proxy (forwards to your vLLM), and an agent server (orchestrates `/run`). See the [NeMo Gym README](https://github.com/NVIDIA-NeMo/Gym) for agent config format.
:::
#### NeMo Gym Prerequisites

```bash
# Clone and set up NeMo Gym
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync

# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation
```

#### NeMo Gym Configuration Reference

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | — | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent `/run` |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_datasets` | list | required | Dataset configs with `path` and `server_name` |

#### Reward Functions

| Function | Mode | Description |
|----------|------|-------------|
| `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify` | Single-turn | Calls `/verify`, returns binary reward |
| `axolotl.integrations.nemo_gym.rewards.reward_env` | Multi-turn | Passthrough reward from agent `/run` |

### Using local dataset files

```yaml
@@ -1,443 +0,0 @@
---
title: "Training Stability & Debugging"
order: 15
description: "Guide to monitoring, debugging, and stabilizing training runs in axolotl"
---

This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.

## Monitoring Training

### Key Metrics for SFT

Every SFT run should be monitored through at least these four metrics:

| Metric | What It Tells You | Healthy Range |
|--------|-------------------|---------------|
| `train/loss` | How well the model fits training data | Decreasing; typically 0.5--2.0 for chat fine-tuning |
| `eval/loss` | Generalization performance | Tracks train loss with small gap; divergence signals overfitting |
| `grad_norm` | Gradient magnitude | 0.1--10.0; spikes above 100 indicate instability |
| `learning_rate` | Current LR from scheduler | Should follow expected schedule (warmup then decay) |

::: {.callout-tip}
## Set Up Logging Early
Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.

```yaml
wandb_project: my-project
wandb_run_id:  # optional, for resuming
logging_steps: 1
```
:::

### Key Metrics for RL (GRPO)

GRPO training logs a richer set of metrics. These are the critical ones:

| Metric | Healthy Range | Red Flag |
|--------|---------------|----------|
| `rewards/<name>/mean` | > 0.15 within 20 steps | Stays at 0 -- reward function is broken or task is too hard |
| `reward_std` | > 0 on most steps | Always 0 -- no learning signal (all completions get the same reward) |
| `frac_reward_zero_std` | < 0.8 | 1.0 on every step -- zero-advantage skip fires constantly, no gradient updates |
| `grad_norm` | 0.001--1.0 | 0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable |
| `entropy` | 0.05--0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| `kl` | 0.0--0.5 | > 2.0 suggests policy has diverged too far from reference |
| `sampling/sampling_logp_difference/mean` | < 0.1 | > 1.0 means policy has diverged far from vLLM server weights |
| `sampling/importance_sampling_ratio/min` | > 0.1 | Near 0 indicates stale off-policy data; decrease `vllm_sync_interval` |
| `clip_ratio/region_mean` | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| `completions/mean_length` | Task-dependent | Monotonically increasing to max length suggests reward hacking |
| `completions/clipped_ratio` | < 0.3 | > 0.8 means most completions hit `max_completion_length` -- increase it |

::: {.callout-note}
## EBFT-Specific Metrics
For EBFT training, also monitor `ebft/alignment` (should trend upward, healthy 0.3--0.9), `ebft/diversity` (healthy 0.01--0.1; > 1.0 indicates mode collapse), and `ebft/cfm_loss` (should trend downward, < 10).
:::

## SFT Stability

### Loss Plateau

**Symptom**: Loss stops decreasing early in training, well above expected values.

**Causes and fixes**:

- **Learning rate too low**: Increase by 2--5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
- **Insufficient warmup**: Set `warmup_steps` to 5--10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
- **Data quality**: Check that labels are correctly masked. Use `axolotl preprocess` and inspect tokenized samples to confirm only the target tokens are trainable.
- **Weight decay too high**: Default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.

### Loss Spikes

**Symptom**: Loss suddenly jumps by 2--10x then (possibly) recovers.

**Causes and fixes**:

- **Bad data samples**: A single malformed or extremely long example can cause a spike. Enable `sample_packing: false` temporarily and check if spikes correlate with specific batches.
- **Learning rate too high**: Reduce by 2--5x, or increase warmup.
- **Gradient accumulation mismatch**: Effective batch size = `micro_batch_size * gradient_accumulation_steps * num_gpus` (see the sketch after this list). Very large effective batch sizes amplify gradient noise.
- **Mixed precision issues**: With `bf16: true`, some operations can lose precision. If spikes are severe, try `fp32` for diagnosis.
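
As a quick sanity check, the effective batch size formula above can be evaluated directly; the numbers here are purely illustrative:

```python
# Illustrative arithmetic for the effective batch size formula above
micro_batch_size = 2
gradient_accumulation_steps = 8
num_gpus = 4

effective_batch_size = micro_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 64 sequences contribute to each optimizer step
```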

### Overfitting

**Symptom**: Train loss keeps decreasing but eval loss starts increasing.

**Fixes**:

- Increase `val_set_size` (e.g., 0.05) and monitor `eval/loss`.
- Reduce `num_epochs` or `max_steps`.
- Increase `weight_decay` (try 0.01--0.1).
- Use a smaller LoRA rank (`lora_r`). Typical values: 8--32.
- Increase dropout: `lora_dropout: 0.05`.

## RL/GRPO Stability

### Reward Never Increases

If `rewards/*/mean` stays at 0 for more than 20 steps:

1. **Test reward function standalone**: Run it outside training with known inputs to verify it returns nonzero values.

   ```bash
   cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
   ```

2. **Check dataset columns**: The reward function receives `**kwargs` containing dataset columns. Verify the columns it needs (e.g., `answer`) are not removed by the dataset transform.
3. **Check completion content**: Enable `log_completions: true` in the `trl:` config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.
4. **Verify vLLM is serving the right model**: Hit the vLLM health endpoint and confirm the model name matches your config.

### Entropy Collapse (Mode Collapse)

**Symptom**: `entropy` drops below 0.01; all completions become nearly identical.

**Fixes**:

- Increase `temperature` in generation kwargs (try 0.8--1.0).
- Reduce learning rate.
- Add a KL penalty term (`beta` parameter in GRPO config).
- Check that `num_generations` is sufficient (16+ gives better advantage estimates).

### IS Ratio Divergence

**Symptom**: `sampling/importance_sampling_ratio/min` drops near 0, or `sampling/sampling_logp_difference/mean` exceeds 1.0.

This means the policy has diverged significantly from the weights used by vLLM for generation. The importance sampling correction becomes unreliable.

**Fixes**:

- Decrease `vllm_sync_interval` (sync weights more often).
- Enable `off_policy_mask_threshold` (e.g., 0.5) to mask stale off-policy samples.
- Use `importance_sampling_level: token` for finer-grained correction.

### Gradient Norm Instability

**Symptom**: `grad_norm` oscillates wildly or exceeds 10.0 regularly.

**Fixes**:

- Enable gradient clipping: `max_grad_norm: 1.0` (default in most configs).
- Reduce learning rate.
- Increase `gradient_accumulation_steps` to smooth out noisy batches.
- Check for NaN issues (see next section).

## MoE Weight Scale Drift

**Symptom**: Model works on short prompts but loses coherence on long conversations — repeating itself, "philosophizing", or generating broken code. Particularly affects MoE models with recurrent/SSM components (e.g. DeltaNet linear attention).

**Root cause**: In MoE models trained with AdamW, rarely-activated experts accumulate smaller second-moment estimates. This gives them a disproportionately large effective learning rate, causing their weights to drift to higher variance than the group norm. In recurrent components like `conv1d` in DeltaNet layers, this amplifies short-range context and washes out long-range state.

**Detection**: Use `normalize_weight_scales` with `dry_run: true` to scan for anomalies without modifying weights:

```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
    dry_run: true
```

This logs any tensors matching the pattern whose standard deviation exceeds 1.3x the group median. Example output:

```
normalize_weight_scales [DRY RUN]: pattern 'linear_attn\.conv1d\.weight' —
3/30 tensors outside 1.3x threshold (median std=0.062733):
  layers.36.linear_attn.conv1d.weight: std=0.101870 (1.62x median)
  layers.37.linear_attn.conv1d.weight: std=0.102362 (1.63x median)
  layers.38.linear_attn.conv1d.weight: std=0.089227 (1.42x median)
```

Each rule accepts:

- `name_pattern`: regex matched against parameter names. All matching tensors form a group.
- `threshold`: flag tensors whose std deviates from the group median by more than this factor (default: 1.5).
- `dry_run`: when `true`, log anomalies without modifying weights (default: `false`).

Multiple rules can target different tensor patterns:

```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
  - name_pattern: 'experts\.gate_up_proj'
    threshold: 1.5
    dry_run: true  # just check these, don't fix
```

The transform runs after model loading but before adapter injection, so it modifies the base model weights directly.
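
The detection logic boils down to a few lines. This is a minimal sketch of the behavior described above (our illustration, not axolotl's implementation):

```python
import re
import torch

def find_scale_outliers(state_dict, name_pattern, threshold=1.5):
    """Flag tensors whose std deviates from the group median by more than threshold x."""
    pattern = re.compile(name_pattern)
    stds = {
        name: tensor.float().std().item()
        for name, tensor in state_dict.items()
        if pattern.search(name)
    }
    median = torch.tensor(list(stds.values())).median().item()
    return {
        name: (std, std / median)
        for name, std in stds.items()
        if std > threshold * median or std < median / threshold
    }
```

Presumably, with `dry_run: false`, flagged tensors are rescaled back toward the group median rather than merely logged.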

## NaN and Inf Handling

### Common Causes

| Cause | Where It Manifests | Detection |
|-------|-------------------|-----------|
| FP8 zero-scale division | Forward pass logits | `grad_norm: nan`, loss becomes NaN immediately |
| Gradient explosion | Backward pass | `grad_norm` spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause exp() overflow |

### FP8-Specific NaN Issues

FP8 quantization (`fp8: true`) can produce NaN when the activation quantization kernel divides by `max(abs(x)) / 448`. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.

**Fixes applied in axolotl**:

- The `act_quant_kernel` has a zero-guard: `s = tl.where(s == 0, 1.0, s)`.
- A safety net `nan_to_num(logits, nan=0.0)` is applied in `_get_per_token_logps_and_entropies`.
- Embedding padding is zero-padded for FP8 compatibility.
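
In plain PyTorch terms, the zero-guard amounts to the following (an illustrative equivalent, not the actual Triton kernel):

```python
import torch

def safe_fp8_scale(x: torch.Tensor) -> torch.Tensor:
    # Per-tensor FP8 scale: max(|x|) / 448. An all-zero input (e.g. padding
    # positions) would yield s == 0 and a divide-by-zero NaN downstream,
    # so substitute 1.0 exactly as the kernel's tl.where guard does.
    s = x.abs().amax() / 448.0
    return torch.where(s == 0, torch.ones_like(s), s)
```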

::: {.callout-important}
## After Modifying Triton Kernels
If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:

```bash
rm -rf ~/.triton/cache
```
:::

### General NaN Debugging Steps

1. **Enable anomaly detection** (slow, but pinpoints the source):

   ```python
   torch.autograd.set_detect_anomaly(True)
   ```

2. **Check grad_norm**: If it goes to NaN, the backward pass is the problem. If loss is NaN but grad_norm was fine on the previous step, the forward pass is the problem.
3. **Reduce to single GPU, single batch**: Eliminate distributed training variables.
4. **Inspect data**: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.

## OOM Debugging

Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:

### Step 1: Reduce Batch Size

The single highest-impact change. VRAM scales roughly linearly with batch size.

```yaml
micro_batch_size: 1  # Start here
gradient_accumulation_steps: 16  # Increase to maintain effective batch size
```

For GRPO specifically, the logits tensor for policy logprob computation can be very large: `batch_size * num_generations * seq_len * vocab_size` elements in bf16. For example, with `num_generations: 16` and `micro_batch_size: 8`, the logits tensor alone is:

```
8 * 16 * 2048 * 151936 * 2 bytes = ~75 GB (way too large)
```

Reduce `micro_batch_size` to 2--4 for GRPO.
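
A quick helper for running this arithmetic against your own settings (illustrative; the vocabulary size is the one from the example above):

```python
def grpo_logits_bytes(micro_batch_size, num_generations, seq_len, vocab_size,
                      bytes_per_elem=2):
    # bf16 logits: batch x generations x sequence length x vocabulary size
    return micro_batch_size * num_generations * seq_len * vocab_size * bytes_per_elem

gib = grpo_logits_bytes(8, 16, 2048, 151936) / 2**30
print(f"{gib:.0f} GiB")  # ~74 GiB, matching the "~75 GB" figure above
```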

### Step 2: Enable Gradient Checkpointing

Trades compute for memory by recomputing activations during the backward pass instead of storing them.

```yaml
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false  # Recommended default
```

::: {.callout-warning}
## Reentrant Checkpointing Exceptions
Some configurations require `use_reentrant: true`:

- DeepSpeed ZeRO-3 (non-reentrant causes `CheckpointError`)
- EBFT strided mode with flex_attention
:::

### Step 3: Use Quantization

Load the base model in reduced precision:

```yaml
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true

# 8-bit
load_in_8bit: true

# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
```

### Step 4: Reduce Sequence Length

```yaml
sequence_len: 1024  # Down from 2048 or 4096
```

For GRPO, also reduce `max_completion_length`. Memory scales quadratically with sequence length when using standard attention.

### Step 5: Use Flash Attention

Reduces attention memory from O(n^2) to O(n):

```yaml
flash_attention: true
```

### Step 6: Offload with DeepSpeed

For extreme cases, offload optimizer states or parameters to CPU:

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
```

### Diagnosing the Specific Culprit

Use the `profiler_steps` config option to capture GPU memory snapshots:

```yaml
profiler_steps: [1, 2]
```

This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.

## Common Errors

| Error Message | Likely Cause | Fix |
|---------------|-------------|-----|
| `exitcode: -9` | System RAM exhaustion | Reduce dataset size, `dataset_num_proc`, or number of data workers |
| `exitcode: -7` (DeepSpeed) | DeepSpeed version issue | `pip install -U deepspeed` |
| `CUDA out of memory` | GPU VRAM exhaustion | Follow OOM debugging steps above |
| `RuntimeError: NCCL communicator was aborted` | GPU communication failure | See [NCCL docs](nccl.qmd); check `NCCL_DEBUG=INFO` output |
| `ValueError: Asking to pad but the tokenizer does not have a padding token` | Missing pad token | Add `special_tokens: { pad_token: "<\|endoftext\|>" }` to config |
| `'DummyOptim' object has no attribute 'step'` | DeepSpeed on single GPU | Remove `deepspeed:` section from config |
| `unable to load strategy X` then `None is not callable` | Reward module not importable | Run `cd experiments && python -c "import my_rewards"` to check |
| `generation_batch_size not divisible by num_generations` | micro_batch_size too small | Set `micro_batch_size >= num_generations` and make it divisible |
| `'weight' must be 2-D` | FSDP1 flattened parameters | Use `fsdp_version: 2` or skip `unwrap_model` when FSDP is enabled |
| `CheckpointError` (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex_attention | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| `BFloat16` TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl's `weight_serde.py` (auto bf16 to fp16 conversion) |
| `Content end boundary is before start boundary` | Chat template parsing issue | Check `eos_token` matches template; file a GitHub issue if persistent |
| `CAS service error` during data processing | HuggingFace XET issue | Set `export HF_HUB_DISABLE_XET=1` |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set `async_prefetch: false` with FSDP |

## Profiling

### PyTorch Profiler

Axolotl supports PyTorch profiler integration via the config:

```yaml
profiler_steps: [1, 2, 3]
```

This captures profiler traces for the specified steps. View them in TensorBoard:

```bash
tensorboard --logdir output_dir/runs
```

Or open the `.json` trace file in `chrome://tracing`.

### CUDA Memory Snapshots

For detailed memory analysis, use PyTorch's memory snapshot API. Add this to your training script or use it interactively:

```python
import torch

# Enable memory history tracking
torch.cuda.memory._record_memory_history()

# ... run your training step ...

# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
```

Visualize with PyTorch's memory visualizer:

```bash
python -m torch.cuda._memory_viz trace_plot memory_snapshot.pickle -o snapshot.html
```

### Quick GPU Memory Check

During training, monitor GPU utilization in a separate terminal:

```bash
watch -n 1 nvidia-smi
```

For programmatic access within axolotl, the logged metrics `memory/max_alloc` and `memory/max_reserved` come from `torch.cuda.max_memory_allocated()` and `torch.cuda.max_memory_reserved()`. Note these report PyTorch's view of memory, which may differ from `nvidia-smi` (see [FAQ](faq.qmd)).

## W&B and Logging

### Enabling Logging

```yaml
wandb_project: my-project
wandb_entity: my-team  # optional
wandb_run_id: run-123  # optional, for resuming
wandb_name: experiment-name  # optional
logging_steps: 1  # log every step (recommended for RL)
```

### Debug Logging

For detailed axolotl-internal debug output:

```bash
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
```

::: {.callout-tip}
## Always Log to a File
Pipe training output to a log file so you can inspect it after the run:

```bash
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
```
:::

### What Axolotl Logs

**SFT metrics** (logged every `logging_steps`):

- `train/loss`, `eval/loss` -- training and validation loss
- `train/grad_norm` -- gradient L2 norm (before clipping)
- `train/learning_rate` -- current learning rate
- `memory/max_alloc`, `memory/max_reserved` -- peak GPU memory

**GRPO/RL metrics** (logged every step):

- `rewards/<name>/mean`, `rewards/<name>/std` -- per-reward-function statistics
- `reward`, `reward_std` -- aggregated reward across all reward functions
- `frac_reward_zero_std` -- fraction of prompt groups where all completions got the same reward
- `completions/mean_length`, `completions/min_length`, `completions/max_length` -- completion token lengths
- `completions/clipped_ratio` -- fraction of completions that hit the max length
- `completions/mean_terminated_length`, `completions/min_terminated_length`, `completions/max_terminated_length` -- lengths of naturally terminated completions
- `kl` -- KL divergence between policy and reference
- `entropy` -- policy entropy (measure of output diversity)
- `clip_ratio/region_mean`, `clip_ratio/low_mean`, `clip_ratio/high_mean` -- PPO clipping statistics
- `sampling/sampling_logp_difference/mean`, `sampling/sampling_logp_difference/max` -- log-probability difference between policy and sampling distribution
- `sampling/importance_sampling_ratio/min`, `sampling/importance_sampling_ratio/mean`, `sampling/importance_sampling_ratio/max` -- IS ratio statistics for off-policy correction
- `num_tokens` -- total tokens processed
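
To make the group-based reward statistics concrete, here is how `reward_std` and `frac_reward_zero_std` fall out of a per-prompt reward matrix (a sketch of the definitions above, not axolotl's code):

```python
import torch

# Illustrative rewards: 4 prompt groups x 8 generations each
rewards = torch.randint(0, 2, (4, 8)).float()

per_group_std = rewards.std(dim=1)                # spread within each prompt group
reward_std = per_group_std.mean().item()          # logged as `reward_std`
frac_zero = (per_group_std == 0).float().mean()   # logged as `frac_reward_zero_std`
print(reward_std, frac_zero.item())
```

When every completion in a group earns the same reward, that group's advantages are all zero and contribute no gradient, which is why `frac_reward_zero_std` pinned at 1.0 means no learning.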

### Reading W&B Charts

For a healthy GRPO run, expect to see:

1. **`reward/mean`**: Gradual upward trend. May start near 0 and reach 0.3--0.8 depending on task difficulty. Not monotonic -- fluctuations are normal.
2. **`entropy`**: Gradual decrease from initial values (often 0.3--0.6) as the model becomes more confident. Should not collapse to near-zero.
3. **`grad_norm`**: Mostly in the 0.001--1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
4. **`kl`**: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
5. **`completions/mean_length`**: Should reflect the task's natural answer length. If it steadily increases to `max_completion_length`, the model may be reward-hacking by generating longer outputs.
@@ -1,318 +0,0 @@
---
title: "vLLM Serving for GRPO Training"
description: "How to configure and run vLLM as a generation backend for GRPO reinforcement learning in Axolotl."
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
execute:
  enabled: false
---

## Overview {#sec-overview}

GRPO (Group Relative Policy Optimization) trains a language model by generating completions, scoring them with reward functions, and updating the policy to favor higher-reward outputs. The generation step is the bottleneck: producing thousands of tokens per training step with the policy model is slow using standard HuggingFace generation.

Axolotl uses [vLLM](https://github.com/vllm-project/vllm) as a high-throughput generation backend. vLLM runs as a separate process (either on a dedicated GPU or colocated on the training GPU) and serves completions via an HTTP API. The trainer sends prompts to vLLM, receives completions, scores them, and performs gradient updates.

```
┌──────────────────────┐       HTTP        ┌──────────────────────┐
│  Trainer (GPU 1)     │ ───────────────── │  vLLM Server (GPU 0) │
│                      │  prompts/compls   │                      │
│  - Policy model      │ ◄──────────────── │  - Same base model   │
│  - Reward scoring    │                   │  - Fast generation   │
│  - Gradient updates  │    weight sync    │  - LoRA adapter      │
│  - LoRA adapter      │ ─────────────────►│    (periodically     │
│                      │  (every N steps)  │     updated)         │
└──────────────────────┘                   └──────────────────────┘
```

::: {.callout-important}
vLLM must serve the **same base model** specified in your training config. If the models do not match, weight synchronization will silently produce incorrect results.
:::

## Server Mode {#sec-server-mode}

Server mode runs vLLM as an external process on dedicated GPU(s). This is the recommended configuration for most setups.

### Starting the Server

Use the `axolotl vllm-serve` command with your training config:

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```

```bash
# Terminal 2: Start training on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```

The server reads vLLM settings from the `vllm:` section of your config and starts an HTTP server (default: `http://0.0.0.0:8000`).

::: {.callout-tip}
Use `tmux` or `screen` to manage the vLLM server process. Typical startup time is 30-90 seconds depending on model size and whether CUDA graphs are captured.
:::

### Minimal Server Config

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85
  dtype: auto
  max_model_len: 4096

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

### Multi-GPU vLLM

For larger models, use tensor parallelism across multiple GPUs:

```yaml
vllm:
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
```

```bash
# vLLM on GPUs 2,3; training on GPUs 0,1
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo_config.yaml
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo_config.yaml --num-processes 2
```

::: {.callout-note}
Due to how TRL maps vLLM device indices, the vLLM instance should use the **last** N GPUs (highest device indices), while training uses the first N.
:::

## Colocate Mode {#sec-colocate-mode}

Colocate mode runs vLLM on the same GPU as the trainer. This is useful when you only have a single GPU.

```yaml
trl:
  use_vllm: true
  vllm_mode: colocate
  vllm_enable_sleep_mode: true
```

With `vllm_enable_sleep_mode: true`, vLLM offloads its VRAM allocation when not actively generating, freeing memory for training. When the trainer needs new completions, vLLM wakes up and reclaims VRAM.

::: {.callout-warning}
Colocate mode is significantly slower than server mode because generation and training cannot overlap. The GPU alternates between the two workloads. This mode is practical only for smaller models (up to ~3B on a 24 GB GPU).
:::

**When to use colocate mode:**

- You have exactly one GPU
- The model fits in memory with both vLLM and training active (with sleep mode), or is small enough to time-share
- You accept the performance tradeoff for simpler setup (no separate vLLM process to manage)

**When to use server mode:**

- You have two or more GPUs
- You want maximum throughput (generation overlaps with training via async prefetch)
- You are running larger models (7B+)

## LoRA Sync {#sec-lora-sync}

LoRA sync is the recommended weight synchronization method when training with LoRA adapters. Instead of merging adapter weights into the base model and broadcasting the full merged weights over NCCL, it saves only the LoRA adapter files to the filesystem and tells vLLM to load them natively.

### How It Works

1. The trainer calls `model.save_pretrained()` to write the LoRA adapter weights to a temporary directory
2. The trainer sends an HTTP POST to `/set_lora_adapter/` on the vLLM server
3. vLLM loads the adapter using its native LoRA support (Punica kernels)
4. Generation uses the updated adapter on the next request

### Benefits

- **Smaller sync payload**: Transfers ~40 MB of LoRA weights instead of ~1.4 GB+ of merged model weights (for a typical 0.5-3B model)
- **No NCCL communicator**: Eliminates the need for a cross-GPU NCCL communication channel, removing GPU contention between vLLM generation and weight sync
- **Faster sync**: ~200 ms per sync vs. 350 ms to 5+ seconds for NCCL merge sync
- **Simpler multi-GPU**: No need to set up NCCL groups between trainer and vLLM processes

### Configuration

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true    # Enables LoRA sync mode
  vllm_sync_interval: 5   # Sync every 5 training steps
```

Setting `vllm_lora_sync: true` automatically selects the LoRA-aware vLLM serve script (`axolotl.scripts.vllm_serve_lora`). You do not need to set `vllm.serve_module` manually.

::: {.callout-important}
LoRA sync requires that you are training with a LoRA adapter (`adapter: lora` or `adapter: qlora`). It is not applicable to full fine-tuning.
:::

## Weight Synchronization {#sec-weight-sync}

During GRPO training, the policy model on the trainer is continuously updated via gradient steps. The vLLM server, however, still holds the old weights. Periodically, the trainer must push updated weights to vLLM so that future generations reflect the improved policy.

### Sync Interval

The `vllm_sync_interval` parameter controls how often weights are synced:

```yaml
trl:
  vllm_sync_interval: 5  # Sync every 5 optimizer steps
```

**Tradeoffs:**

- **Lower interval** (e.g., 1-3): Fresher generations, better on-policy data, but more sync overhead per step
- **Higher interval** (e.g., 5-10): Less overhead, but generations become increasingly off-policy between syncs
- **Recommended**: 3-5 for most setups. Axolotl includes importance sampling correction (`vllm_importance_sampling_correction: true`) to handle mild distribution mismatch from stale vLLM weights.

### Sync Methods

| Method | Config | Payload | Mechanism | Typical Time |
|--------|--------|---------|-----------|--------------|
| **LoRA sync** | `vllm_lora_sync: true` | LoRA adapter only (~40 MB) | Filesystem + HTTP | ~200 ms |
| **NCCL merge sync** | Default (no lora_sync) | Full merged weights (~1.4 GB+) | HTTP trigger + NCCL broadcast | 350 ms - 5 s |

::: {.callout-tip}
If you are training with LoRA (which is recommended for GRPO), always enable `vllm_lora_sync: true`. The performance difference is substantial, especially as training progresses and NCCL contention increases.
:::

### Importance Sampling Correction

When vLLM weights are stale (between syncs), the generated data is slightly off-policy. Axolotl can correct for this:

```yaml
trl:
  vllm_importance_sampling_correction: true
  importance_sampling_level: token  # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5    # KL threshold for masking stale sequences
```

- **Token-level IS** is recommended when using Liger kernel (sequence-level has numerical issues with chunked computation)
- **Off-policy sequence masking (OPSM)** drops sequences that have diverged too far from the current policy, providing a safety net against stale data
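
Schematically, the token-level correction and OPSM reduce to something like the following (a rough sketch of the idea only; the variable names and the exact masking rule are our assumptions, not TRL's implementation):

```python
import torch

def corrected_weights(logp_policy, logp_sampling, mask_threshold=0.5):
    """logp_*: [batch, seq] per-token log-probs under the current policy and
    under the (possibly stale) vLLM weights that generated the tokens."""
    ratio = torch.exp(logp_policy - logp_sampling)      # token-level IS ratio
    drift = (logp_policy - logp_sampling).mean(dim=-1)  # per-sequence drift
    keep = (drift.abs() < mask_threshold).float()       # OPSM: drop stale sequences
    return ratio * keep.unsqueeze(-1)
```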

## Restart Requirements {#sec-restart}

::: {.callout-warning}
**vLLM must be restarted between training runs.** Weight syncs from a previous run leave the server in a corrupted state. If you start a new training run against a stale vLLM server, the model may fail to learn.
:::

### When to Restart

- Before every new training experiment
- After a training run crashes or is interrupted
- If you change the base model in your config

### How to Restart

Killing vLLM reliably requires terminating both the main process and its background EngineCore subprocess:

```bash
# Kill all vLLM-related processes
pkill -9 -f "vllm|EngineCore"

# Verify GPU memory is freed
nvidia-smi

# Restart the server
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```

::: {.callout-tip}
A single `kill` often does not fully stop vLLM. Always use `kill -9` and verify with `nvidia-smi` that GPU memory has been released before restarting.
:::

### Health Check

The vLLM server exposes a health endpoint. Wait for it to return 200 before starting training:

```bash
# For the LoRA serve script (trailing slash required)
curl http://localhost:8000/health/

# For the default TRL serve script
curl http://localhost:8000/health
```
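
If you script your launches, a small poll loop beats eyeballing the terminal. This helper is our own sketch (not part of axolotl) and targets the LoRA serve script's endpoint:

```python
import time
import urllib.request

def wait_for_vllm(url="http://localhost:8000/health/", timeout_s=300):
    """Poll the vLLM health endpoint until it returns HTTP 200."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                if resp.status == 200:
                    return
        except OSError:
            pass  # server still starting up
        time.sleep(2)
    raise TimeoutError(f"vLLM not healthy after {timeout_s}s: {url}")

wait_for_vllm()
```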

## Configuration Reference {#sec-config-reference}

### vLLM Server Options (`vllm:` section)

These control the vLLM server process started by `axolotl vllm-serve`.

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `host` | str | `0.0.0.0` | Host address for the vLLM server |
| `port` | int | `8000` | Port for the vLLM server |
| `device` | str | `auto` | Device to use for vLLM |
| `tensor_parallel_size` | int | `None` | Number of GPUs for tensor parallelism |
| `data_parallel_size` | int | `None` | Number of data parallel replicas |
| `gpu_memory_utilization` | float | `0.9` | Fraction of GPU memory for vLLM (0.0-1.0) |
| `dtype` | str | `auto` | Data type (`auto`, `float16`, `bfloat16`) |
| `max_model_len` | int | `None` | Maximum model context length. Set explicitly if the default is too large for your GPU |
| `enable_prefix_caching` | bool | `None` | Enable prefix caching for repeated prompt prefixes |
| `enable_reasoning` | bool | `None` | Enable reasoning mode for models with thinking tokens |
| `reasoning_parser` | str | `None` | Parser for reasoning output |
| `enforce_eager` | bool | `None` | Disable CUDA graph capture (required for some architectures like Qwen3.5 hybrid attention) |
| `serve_module` | str | `None` | Python module for vLLM serve script. Auto-set when `vllm_lora_sync: true` |
| `worker_extension_cls` | str | `None` | vLLM worker extension class for weight sync |

### Trainer vLLM Options (`trl:` section)

These control how the trainer interacts with vLLM.

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | str | `None` | `server` (external process) or `colocate` (same GPU) |
| `vllm_server_host` | str | `0.0.0.0` | Host of the vLLM server to connect to |
| `vllm_server_port` | int | `8000` | Port of the vLLM server to connect to |
| `vllm_server_timeout` | int | `None` | Timeout in seconds for vLLM requests |
| `vllm_lora_sync` | bool | `false` | Sync LoRA adapters via filesystem instead of NCCL merge |
| `vllm_sync_interval` | int | `None` | Sync weights every N optimizer steps |
| `vllm_enable_sleep_mode` | bool | `None` | Offload vLLM VRAM when idle (colocate mode) |
| `vllm_guided_decoding_regex` | str | `None` | Regex constraint for guided decoding |

For async pipeline and off-policy correction options, see the [GRPO Configuration Reference](grpo.qmd#configuration-reference).

## Complete Example {#sec-complete-example}

For a full working GRPO config including vLLM, LoRA sync, async generation, rewards, and dataset setup, see the [GRPO Quick Start](grpo.qmd#quick-start). That config includes all the vLLM settings covered in this guide.

```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml

# Wait for health check to pass
curl http://localhost:8000/health/

# Terminal 2: Start training
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```

## Troubleshooting {#sec-troubleshooting}

| Problem | Likely Cause | Solution |
|---------|-------------|----------|
| Training hangs waiting for vLLM | Server not started or wrong port | Check `curl http://localhost:8000/health/` and verify `vllm_server_host`/`vllm_server_port` match |
| OOM on vLLM GPU | `gpu_memory_utilization` too high or `max_model_len` too large | Reduce `gpu_memory_utilization` to 0.7 or set `max_model_len` explicitly |
| OOM on training GPU | Batch too large for policy logprobs | Reduce `micro_batch_size` or `num_generations` |
| Accuracy stays at zero | Stale vLLM from previous run | Restart vLLM: `pkill -9 -f "vllm\|EngineCore"`, verify with `nvidia-smi`, restart |
| `ResponseValidationError` from vLLM | Missing logprobs in response | Ensure you are using the correct serve module (auto-selected with `vllm_lora_sync: true`) |
| Weight sync takes 5+ seconds | NCCL contention with vLLM generation | Switch to `vllm_lora_sync: true` to eliminate NCCL |
| `async_prefetch` deadlocks with FSDP | Background threads run unsynchronized FSDP collectives | Set `async_prefetch: false` when using FSDP or DeepSpeed multi-GPU |
@@ -1,211 +0,0 @@
# Energy-Based Fine-Tuning (EBFT)

EBFT is an integration of ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) into axolotl.

## Overview

EBFT fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.

**Key advantages over SFT:**
- Operates on model rollouts (not teacher forcing), reducing distribution shift
- Provides dense sequence-level supervision without a task-specific verifier
- Improves both downstream accuracy and validation cross-entropy simultaneously

**Key advantages over RLVR:**
- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
- Maintains distributional calibration (low feature-matching loss)

## Two Modes

EBFT supports two modes depending on your data format:

### Structured Mode (`mode: structured`, default)
For **QA/instruction data** with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).
- Extends GRPOTrainer — uses vLLM for fast rollout generation
- RLOO advantages and clipped policy gradient from GRPO
- Feature-matching rewards replace external reward functions

### Strided Mode (`mode: strided`)
For **unstructured text** without prompt/completion splits (e.g., raw code, prose, SwallowCode).
- Uses **strided block-parallel generation** — multiple short rollouts at different anchor points within a document
- No vLLM needed — generation uses custom strided attention masks
- Uses **torch flex_attention** with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
- Compatible with gradient checkpointing via automatic dtype normalization
- This is the core EBFT algorithm from the paper (Section F)

### Common to both modes:
- **Frozen feature network** — deep copy of the model at initialization (frozen, eval mode)
- **Feature extraction** — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
- **Feature-matching rewards** — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7)
- **SVD whitening** — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
- **CFM loss tracking** — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss`
- **FSDP2 compatible** — feature network stays outside FSDP wrapping (frozen, inference-only)

## Quick Start

### Structured Mode (QA data + vLLM)

```bash
# 1. Start vLLM server (LoRA serve module auto-selected when vllm_lora_sync: true)
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen3-4b-ebft-structured-async.yaml

# 2. Train on a separate GPU
CUDA_VISIBLE_DEVICES=1 axolotl train examples/ebft/qwen3-4b-ebft-structured-async.yaml
```

### Strided Mode (unstructured text)

```bash
# No vLLM needed — strided generation is built-in
axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml
```

## Configuration

### Common EBFT Settings

```yaml
rl: ebft

ebft:
  # Feature network: which layers to extract hidden states from
  # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
  feature_layers: [0.25, 0.5, 0.75]

  # How to pool per-token hidden states into sequence embeddings
  # Options: "last_token" (recommended), "mean_pooling", "concat"
  embed_method: last_token

  # SVD whitening — strongly recommended (paper shows largest degradation without it)
  use_whitening: true

  # Reward = alignment_coef * alignment - diversity_coef * diversity
  # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
  # diversity uses raw dot product — both are bounded after whitening.
  alignment_coef: 1.0
  diversity_coef: 1.0

  # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
  # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
  ce_coef: 0.0
```

### Strided Mode Settings

```yaml
ebft:
  mode: strided
  stride: 8                  # tokens between anchor points (paper default: 8)
  context_length: 8          # context window per block (paper default: 8)
  generate_max_len: 8        # tokens generated per block (paper default: 8)
  n_samples_per_prompt: 4    # independent rollouts per document (>= 2 for RLOO)
  temperature: 0.6
  rl_coef: 1.0               # RL loss weight
  advantage_estimator: rloo  # rloo (recommended), group_norm, or reinforce
```

### Structured Mode Settings (via TRL)

```yaml
trl:
  num_generations: 4           # samples per prompt
  max_completion_length: 256   # max tokens to generate
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
```

### Dataset Format

**Structured mode** — QA data with prompt + ground-truth completion:
```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
```
Transform returns: `{"prompt": ..., "ground_truth": ...}`

**Strided mode** — raw text tokenized to fixed length:
```yaml
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
```
Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}`

## How It Works

### Structured Mode
1. **Generate**: For each prompt, generate `num_generations` completions via vLLM
2. **Extract features**: Forward both generated and ground-truth sequences through the frozen feature network
3. **Compute rewards**: `2 * alignment - 2 * diversity` (paper eq 7)
4. **RLOO advantages**: subtract leave-one-out group mean
5. **Policy gradient**: clipped PPO-style loss
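
As a toy illustration of steps 3 and 4 (our sketch, not the trainer's code): given per-sample alignment and diversity scores for one prompt group, the reward and RLOO advantages are a few lines:

```python
import torch

# Feature-matching scores for 4 generations of one prompt (illustrative values)
alignment = torch.tensor([0.82, 0.75, 0.64, 0.71])
diversity = torch.tensor([0.10, 0.08, 0.15, 0.09])

rewards = 2 * alignment - 2 * diversity          # reward per paper eq (7)

n = rewards.numel()
loo_mean = (rewards.sum() - rewards) / (n - 1)   # leave-one-out group mean
advantages = rewards - loo_mean                  # RLOO advantages (step 4)
print(advantages)
```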

### Strided Mode
1. **Anchor selection**: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document (see the sketch after this list)
2. **Block-parallel generation**: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks
3. **Feature extraction**: Forward the full sequence (prompt + generated) through the frozen feature network **with the strided attention mask** — this is critical for correct feature representations
4. **Per-block rewards**:
   - **Alignment** = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2]
   - **Diversity** = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors
   - **Reward** = `alignment_coef * alignment - diversity_coef * diversity`
5. **RLOO advantages**: leave-one-out baseline across `n_samples_per_prompt` rollouts per block
6. **Policy gradient**: REINFORCE loss on generated tokens, weighted by per-block advantages
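
A small sketch of the anchor arithmetic in step 1, using the defaults (`stride=8`, `context_length=8`, `generate_max_len=8`). The block count follows the formula; the exact anchor placement is our assumption for illustration:

```python
def block_anchors(seq_len, stride=8, ctx_len=8, gen_len=8):
    # Number of blocks per the formula above (integer division assumed)
    num_blocks = (seq_len - gen_len - ctx_len) // stride + 1
    # Illustrative placement: each block gets ctx_len tokens of context,
    # then generates gen_len tokens starting at its anchor
    return [ctx_len + i * stride for i in range(num_blocks)]

print(len(block_anchors(1024)))  # 127 blocks for a 1024-token document
```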

### Tracked Metrics

| Metric | Description |
|--------|-------------|
| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) |
| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) |
| `ebft/mean_reward` | alignment - diversity (should trend upward) |
| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq. 2, lower = better) |
| `ebft/rl_loss` | REINFORCE policy gradient loss |
| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) |
| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) |

## Tips and Recommendations

### Reward coefficients
- **`use_whitening: true`**: Strongly recommended. The paper's ablation (Figure 7) shows that removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`; a generic whitening sketch follows this list.
- **`diversity_coef`**: Default 1.0. Per the paper's Variant (i) (eq. 49), alignment uses cosine similarity while diversity uses the raw dot product. After whitening, both are bounded and on compatible scales.
- **`n_samples_per_prompt`**: Must be >= 2 for diversity and RLOO; 4 is the paper's default.
- **`ce_coef`**: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances the CE and RL signals; `0.1` causes CE to dominate the gradient; `0.0` gives pure feature matching.
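
The trainer's exact whitening transform is not shown here; below is a generic SVD whitening sketch for a batch of feature vectors, which illustrates why whitened features end up on compatible scales:

```python
import torch

def svd_whiten(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Whiten feature vectors x: [n, d] so their empirical covariance is
    (approximately) the identity (PCA whitening via SVD)."""
    x = x - x.mean(dim=0, keepdim=True)
    _, s, vh = torch.linalg.svd(x, full_matrices=False)
    # Rotate into the singular basis, then rescale each direction to unit variance.
    return (x @ vh.T) / (s / (x.shape[0] - 1) ** 0.5 + eps)
```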

### Feature extraction
- **`feature_layers: [0.25, 0.5, 0.75]`**: Extracts and concatenates hidden states from 25%, 50%, and 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
- **`embed_method: last_token`**: Uses the last token's hidden state per block; the paper shows this outperforms mean pooling (Figure 7). A minimal extraction sketch follows this list.
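
A minimal sketch of the fractional-depth extraction described above, assuming a Hugging Face-style model called with `output_hidden_states=True` (the exact layer-index rounding is an assumption):

```python
import torch
import torch.nn.functional as F

def extract_features(hidden_states, last_token_idx, fractions=(0.25, 0.5, 0.75)):
    """hidden_states: tuple of [batch, seq, d] tensors (entry 0 is the
    embedding layer). Pools the last token at several fractional depths,
    L2-normalizes each layer independently, and concatenates."""
    n_layers = len(hidden_states) - 1
    feats = []
    for frac in fractions:
        layer = hidden_states[max(1, round(frac * n_layers))]
        pooled = layer[:, last_token_idx, :]       # last-token pooling
        feats.append(F.normalize(pooled, dim=-1))  # per-layer L2 norm
    return torch.cat(feats, dim=-1)                # [batch, 3 * d]
```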

### Performance
- **`torch_compile: true`**: Recommended for strided mode; provides additional speedup via graph compilation.
- **flex_attention**: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). It works with gradient checkpointing via automatic dtype normalization and falls back to eager attention with dense 4D masks if `flex_attention` is unavailable. A usage sketch follows this list.
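
For illustration, this is the general `flex_attention` block-mask pattern (PyTorch >= 2.5); the mask below is a plain local-causal window, not the trainer's actual strided mask, which additionally encodes the anchor/block structure:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

CTX = 8  # local context window, mirroring context_length above

def local_causal(b, h, q_idx, kv_idx):
    # Causal, restricted to the last CTX tokens before each query.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= CTX)

block_mask = create_block_mask(local_causal, B=None, H=None, Q_LEN=256, KV_LEN=256)
q = k = v = torch.randn(1, 4, 256, 64, device="cuda")  # [batch, heads, seq, head_dim]
out = flex_attention(q, k, v, block_mask=block_mask)
```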

### Memory
- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
- **LoRA** is recommended to reduce trainable-parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU (see the sketch after this list).
- **FSDP2** is supported — the feature network stays outside FSDP wrapping since it is frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from the pretrained weights.
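
A hedged sketch of the two-GPU layout (the trainer does this automatically; `base_model` and `input_ids` here are hypothetical stand-ins for a loaded Hugging Face model and a tokenized batch):

```python
import copy
import torch

feature_net = copy.deepcopy(base_model).eval()  # frozen copy of the base model
feature_net.requires_grad_(False)
device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
feature_net.to(device)

with torch.no_grad():  # inference-only; stays outside FSDP wrapping
    out = feature_net(input_ids.to(device), output_hidden_states=True)
```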

## Example Configs

| Config | Mode | Model | Description |
|--------|------|-------|-------------|
| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM |
| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM |
| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA |
| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation |

## Citation

```bibtex
@article{jelassi2026matching,
  title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
  author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
  journal={arXiv preprint arXiv:2603.12248},
  year={2026}
}
```

@@ -1,28 +0,0 @@
"""
Dataset transform for nvidia/OpenCodeInstruct with EBFT.

Maps the dataset's `input` (prompt) and `output` (code solution) fields
to the format expected by the EBFT trainer.
"""


def transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "user", "content": example["input"]},
            ],
            "ground_truth": example["output"],
        }

    return transform_fn, {
        "remove_columns": [
            "id",
            "domain",
            "generation_algorithm",
            "llm_judgement",
            "unit_tests",
            "tests_execution_status",
            "average_test_score",
        ]
    }

@@ -1,31 +0,0 @@
"""
Dataset transform for unstructured text data with strided EBFT.

Tokenizes raw text into fixed-length input_ids for the strided trainer.
Sequences are padded to sequence_len for uniform batching.
"""


def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        text = example.get("question", example.get("text", ""))
        if tokenizer is None:
            return {"prompt": text}

        encoded = tokenizer(
            text,
            truncation=True,
            max_length=seq_len,
            padding="max_length",
            add_special_tokens=True,
            return_tensors=None,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": list(encoded["input_ids"]),
        }

    return transform_fn, {"remove_columns": ["question", "answer"]}

@@ -1,80 +0,0 @@
"""
Dataset transform for structured (prompt, completion) data with strided EBFT.

Tokenizes prompt and completion separately, concatenates into a single
input_ids sequence, and marks prompt tokens with labels=-100 so the
strided trainer knows where to place anchors (completion span only).

Works with datasets that have chat-style fields (e.g., nvidia/OpenCodeInstruct).
"""


def transform(cfg, *args, **kwargs):
    seq_len = cfg.sequence_len

    def transform_fn(example, tokenizer=None):
        # Extract prompt and completion from the example
        prompt_text = example.get(
            "input", example.get("prompt", example.get("question", ""))
        )
        completion_text = example.get(
            "output", example.get("completion", example.get("answer", ""))
        )

        if tokenizer is None:
            return {"prompt": prompt_text}

        pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id

        # Tokenize prompt and completion separately
        prompt_enc = tokenizer(
            prompt_text,
            truncation=False,
            add_special_tokens=True,
            return_tensors=None,
        )
        completion_enc = tokenizer(
            completion_text,
            truncation=False,
            add_special_tokens=False,
            return_tensors=None,
        )

        prompt_ids = prompt_enc["input_ids"]
        completion_ids = completion_enc["input_ids"]

        # Truncate to fit within seq_len (prioritize keeping prompt + some completion)
        total_len = len(prompt_ids) + len(completion_ids)
        if total_len > seq_len:
            # Truncate completion first, then prompt if needed
            max_completion = seq_len - len(prompt_ids)
            if max_completion < 1:
                # Prompt alone exceeds seq_len — truncate prompt, keep at least 1 completion token
                prompt_ids = prompt_ids[: seq_len - 1]
                completion_ids = completion_ids[:1]
            else:
                completion_ids = completion_ids[:max_completion]

        input_ids = prompt_ids + completion_ids
        prompt_length = len(prompt_ids)

        # Labels: -100 for prompt tokens, input_ids for completion tokens
        labels = [-100] * prompt_length + completion_ids

        # Pad to seq_len
        pad_len = seq_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * pad_len
        labels = labels + [-100] * pad_len
        input_ids = input_ids + [pad_id] * pad_len

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_length": prompt_length,
        }

    # Signal to remove all original columns (filtered to existing ones at map time)
    return transform_fn, {
        "remove_columns": "__all__",
    }

@@ -1,64 +0,0 @@
# EBFT validation config — no vLLM, uses HF generate for simplicity
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode-novllm.yaml

base_model: meta-llama/Llama-3.2-1B
chat_template: llama3
rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 128
  temperature: 1.0
  use_vllm: false
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%]

sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 10

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-validation

wandb_project: ebft
wandb_run_id:
wandb_watch:
wandb_log_model:

logging_steps: 1
save_steps: 100

@@ -1,81 +0,0 @@
# EBFT: Energy-Based Fine-Tuning with Llama-3.2-1B on OpenCodeInstruct
#
# Paper: "Matching Features, Not Tokens" (Jelassi et al., 2026)
# https://arxiv.org/abs/2603.12248
#
# Prerequisites:
# 1. Start vLLM server on a separate GPU:
#    CUDA_VISIBLE_DEVICES=1 python -m trl.scripts.vllm_serve \
#      --model meta-llama/Llama-3.2-1B \
#      --host 0.0.0.0 --port 8000 \
#      --gpu-memory-utilization 0.4 --dtype bfloat16
#
# 2. Run training:
#    CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-opencode.yaml

base_model: meta-llama/Llama-3.2-1B
chat_template: llama3

# --- Training method ---
rl: ebft

# --- EBFT configuration ---
ebft:
  feature_layers: [0.25, 0.5, 0.75]  # extract hidden states at 25%, 50%, 75% depth
  embed_method: last_token           # pool to sequence embedding via last token
  use_whitening: false               # SVD whitening (disable for speed in small runs)
  alignment_coef: 1.0                # cosine similarity with ground-truth features
  diversity_coef: 1.0                # pairwise similarity penalty
  ce_coef: 0.0                       # cross-entropy on ground-truth (0 = pure feature matching)

# --- Generation settings (via TRL/GRPO infrastructure) ---
trl:
  num_generations: 4          # samples per prompt for RLOO
  max_completion_length: 256  # max generated tokens
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

# --- Dataset ---
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:1%]  # first 1% for validation runs

# --- Training hyperparameters ---
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 50

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5
weight_decay: 0.01

# --- LoRA (recommended to reduce memory with frozen feature network) ---
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

# --- Hardware ---
bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama-1b-opencode

# --- Logging ---
use_tensorboard: true
logging_steps: 1
save_steps: 25

@@ -1,65 +0,0 @@
# EBFT Strided Structured Mode: For structured (prompt, completion) data
# Uses strided block-parallel generation on completion spans — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided-structured.yaml

base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided              # strided block-parallel generation
  stride: 8                  # tokens between anchor points
  context_length: 8          # context window per block
  generate_max_len: 8        # tokens to generate per block
  n_samples_per_prompt: 4    # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03              # small CE weight for structured data
  advantage_estimator: rloo
  min_completion_prefix: 8   # skip anchors too close to the prompt boundary

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
# max_steps: 10

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 5

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: false  # strided EBFT overrides to flex_attention (or eager fallback) at runtime
flex_attention: true    # fused flex_attention kernel compiles itself; don't set torch_compile: true
                        # (full-model compile conflicts with gradient checkpointing + flex_attention)
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true   # required for flex_attention (non-reentrant causes CheckpointError)

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-strided-structured

wandb_project: ebft
logging_steps: 1
save_steps: 100

@@ -1,60 +0,0 @@
# EBFT Strided Mode: For unstructured text data (raw code, prose)
# Uses strided block-parallel generation — no vLLM needed.
#
# Run: CUDA_VISIBLE_DEVICES=0 axolotl train examples/ebft/llama-1b-ebft-strided.yaml

base_model: meta-llama/Llama-3.2-1B
rl: ebft

ebft:
  mode: strided              # strided block-parallel generation
  stride: 8                  # tokens between anchor points
  context_length: 8          # context window per block
  generate_max_len: 8        # tokens to generate per block
  n_samples_per_prompt: 4    # rollouts per document
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:100]

sequence_len: 256
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 1
max_steps: 5

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 2

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
flash_attention: false  # strided EBFT overrides to flex_attention (or eager fallback) at runtime
gradient_checkpointing: true

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-strided-validation

wandb_project: ebft
logging_steps: 1
save_steps: 100

@@ -1,69 +0,0 @@
# EBFT Strided: LoRA Llama-3.2-3B on SwallowCode, 100 steps
# Actor on GPU 0, frozen feature network on GPU 1
#
# Run: CUDA_VISIBLE_DEVICES=0,1 python -m axolotl.cli.train examples/ebft/llama-3b-ebft-strided-fft.yaml

base_model: meta-llama/Llama-3.2-3B
rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0  # paper recommends 0.03 for a mixed objective; 0.1 causes CE to dominate
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-5
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true

bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
torch_compile: true
gradient_checkpointing_kwargs:
  use_reentrant: true
ddp: false
device_map:
  "": 0

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama3b-strided

wandb_project: ebft
wandb_name: llama3b-strided-lora-100steps
logging_steps: 1
save_steps: 50

@@ -1,58 +0,0 @@
# EBFT Strided: Full-parameter Llama-3.1-8B on SwallowCode, 100 steps
# Feature network is CPU-offloaded to fit in a single 32GB GPU
#
# Run: CUDA_VISIBLE_DEVICES=0 python -m axolotl.cli.train examples/ebft/llama-8b-ebft-strided-fft.yaml

base_model: meta-llama/Llama-3.1-8B
rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  top_p: 1.0
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.0
  advantage_estimator: rloo

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
    split: train[:5000]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 100

learning_rate: 1.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
weight_decay: 0.01

bf16: auto
flash_attention: false  # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

special_tokens:
  pad_token: "<|end_of_text|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-llama8b-strided

wandb_project: ebft
wandb_name: llama8b-strided-fft-100steps
logging_steps: 1
save_steps: 50

@@ -1,86 +0,0 @@
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
#    CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-4b-ebft-structured-async.yaml
#
# 2. Run training on GPU 1:
#    CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#      axolotl train examples/ebft/qwen35-4b-ebft-structured-async.yaml

base_model: Qwen/Qwen3.5-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false
  async_prefetch: true
  vllm_server_timeout: 300

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 2048
  serve_module: axolotl.scripts.vllm_serve_lora
  enforce_eager: true

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10

learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured-async

wandb_project: ebft
logging_steps: 1
save_steps: 50

@@ -1,77 +0,0 @@
# EBFT Structured Mode: Qwen3.5-4B (hybrid linear attention)
#
# Qwen3.5 uses hybrid attention: linear attention (conv1d) on 3/4 of layers,
# full attention every 4th layer. This tests EBFT compatibility.
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
#    CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3.5-4B \
#      --gpu-memory-utilization 0.5 --max-model-len 2048 --enforce-eager
#
# 2. Run training on GPU 1:
#    CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#      axolotl train examples/ebft/qwen35-4b-ebft-structured.yaml

base_model: Qwen/Qwen3.5-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false  # disable Qwen3.5 thinking mode for shorter completions

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10

learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-4b-structured

wandb_project: ebft
logging_steps: 1
save_steps: 50

@@ -1,82 +0,0 @@
# EBFT Structured Mode: Qwen3.5-9B (hybrid linear attention)
#
# Prerequisites:
# 1. Start vLLM on GPU 0:
#    CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen35-9b-ebft-structured.yaml
#
# 2. Run training on GPU 1:
#    CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
#      axolotl train examples/ebft/qwen35-9b-ebft-structured.yaml

base_model: Qwen/Qwen3.5-9B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0
  diversity_coef: 1.0
  ce_coef: 0.0

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
  generation_kwargs:
    stop_token_ids: [248044, 248046]  # <|endoftext|>, <|im_end|>
  chat_template_kwargs:
    enable_thinking: false
  vllm_server_timeout: 300

vllm:
  gpu_memory_utilization: 0.7
  max_model_len: 2048
  serve_module: axolotl.scripts.vllm_serve_lora
  enforce_eager: true

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

sequence_len: 1024
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
max_steps: 10

learning_rate: 3.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 3
weight_decay: 0.01

adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
# Target full-attention q/k/v/o on layers 3,7,11,15,19,23,27,31 + MLP on all layers
# Avoids linear_attn modules (in_proj_qkv, in_proj_z, etc.) which break vLLM LoRA
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"

bf16: auto
flash_attention: true
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

val_set_size: 0.0
output_dir: ./outputs/ebft-qwen35-9b-structured

wandb_project: ebft
logging_steps: 1
save_steps: 50

@@ -1,104 +0,0 @@
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
#
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
#
# Key notes:
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
#   Use sdp_attention instead.
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
#   LoRA to the text backbone via the lora_target_linear_modules regex.
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
#   directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
#   via the transformers ExpertsInterface.
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
#   (no `mlp.` prefix, unlike Qwen/Mixtral).
# - micro_batch_size: 1 fits 2048 seq_len on a 32GB GPU with SDP attention.
#   Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.

base_model: google/gemma-4-26B-A4B

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
  - axolotl.integrations.liger.LigerPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
torch_compile: false
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false

chat_template: gemma4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:10%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-26b-a4b-qlora

sequence_len: 2048
sample_packing: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0

# Restrict LoRA to the text backbone only (skip vision/audio encoders).
# lora_target_modules is intentionally empty — all module targeting is done
# via regex in lora_target_linear_modules below.
lora_target_modules: []
lora_target_linear_modules:
  - language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj

# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj

lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

bnb_config_kwargs:
  bnb_4bit_use_double_quant: true

wandb_project: gemma4-qlora
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
activation_offloading: true
logging_steps: 1

# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
flash_attention: false
sdp_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

@@ -58,14 +58,6 @@ datasets:
- **LoRA kernels**: Incompatible with this model. Must be explicitly disabled (`lora_*_kernel: false`).
- Read more on how to load your own dataset at the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

### GGUF / llama.cpp loading error (missing tensors)

If you see `missing tensor 'blk.X.attn_norm.weight'` when loading a GLM-4 / GLM4-MoE model in llama.cpp, this is likely
caused by `num_nextn_predict_layers` being set to `1` in `config.json` while the MTP weights were not exported (possible
after PEFT/QLoRA training).

**Fix:** Set `"num_nextn_predict_layers": 0` in your `config.json` before converting to GGUF.

## Optimization Guides

Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

@@ -1,74 +0,0 @@
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16

# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

chat_template: tokenizer_default
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 4096
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
  # Attention projection layers (present in ~12 attention layers out of 88)
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
#   - up_proj
#   - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:

@@ -1,48 +0,0 @@
# Nemotron-H (nvidia/NVIDIA-Nemotron-3-*)

Hybrid Mamba2 / Attention / MoE architecture (`model_type: nemotron_h`).

| Model | Total params | Active params | Layers |
|---|---|---|---|
| NVIDIA-Nemotron-3-Super-120B-A12B-BF16 | 120B | ~12B | 88 |
| NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 | 30B | ~3B | — |

## Requirements

```bash
pip install mamba-ssm causal-conv1d  # fast Mamba2 CUDA kernels
```

## Architecture notes

- Three block types interleaved across the stack: **Mamba2** (selective SSM), **Attention** (sparse), **MoE** (mixture-of-experts).
- Only ~12 out of 88 blocks are attention layers (120B variant).
- MLP activation is `relu2` via `mlp_hidden_act` (not the usual `hidden_act`).

## LoRA kernel patches

All three LoRA Triton kernel patches must be disabled:

```yaml
lora_qkv_kernel: false  # attention lives in NemotronHBlock.mixer, not layer.self_attn
lora_o_kernel: false    # same reason
lora_mlp_kernel: false  # relu2 (mlp_hidden_act) is not supported by lora_mlp_kernel
```

## MoE expert weights

NemotronH experts store `up_proj` and `down_proj` as 3D `nn.Parameter` tensors
(shape `[num_experts, out_dim, in_dim]`), **not** `nn.Linear` modules — there is no
`gate_proj`. To fine-tune them alongside attention, use `lora_target_parameters`
instead of `lora_target_modules`:

```yaml
lora_target_parameters:
  - up_proj
  - down_proj
```

## Limitations

- **MoE Triton kernels**: `lora_mlp_kernel` is not supported for NemotronH's MoE expert layers. The expert weights are 3D `nn.Parameter` tensors (not `nn.Linear`), which the Triton kernel does not support. Keep `lora_mlp_kernel: false`.
- **Gradient checkpointing**: Only supported when `sample_packing: true`. Without sample packing, the upstream model sets `supports_gradient_checkpointing = False`.

@@ -1,74 +0,0 @@
# See examples/nemotron-h/README.md for architecture notes and requirements.
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

chat_template: tokenizer_default
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 4096
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
#   - up_proj
#   - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:

@@ -31,11 +31,10 @@ lora_target_modules:
  - k_proj
  - v_proj
  - o_proj
# Add gate_up_proj and down_proj to also target shared experts (nn.Linear):
# - gate_up_proj
# - down_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'

# Target routed experts (3D nn.Parameter tensors, not nn.Linear — use lora_target_parameters):
# Target experts
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

@@ -31,11 +31,11 @@ lora_target_modules:
  - k_proj
  - v_proj
  - o_proj
# Add gate_up_proj and down_proj to also target shared experts (nn.Linear):
# - gate_up_proj
# - down_proj

# Target routed experts (3D nn.Parameter tensors, not nn.Linear — use lora_target_parameters):
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'

# Target experts
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

@@ -31,11 +31,11 @@ lora_target_modules:
  - k_proj
  - v_proj
  - o_proj
# Add gate_up_proj and down_proj to also target shared experts (nn.Linear):
# - gate_up_proj
# - down_proj

# Target routed experts (3D nn.Parameter tensors, not nn.Linear — use lora_target_parameters):
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'

# Target experts
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

@@ -1,18 +1,8 @@
base_model: Qwen/Qwen3.5-35B-A3B-Base
base_model: Qwen/Qwen3.5-35B-A3B

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
  - axolotl.integrations.liger.LigerPlugin
use_kernels: true
use_scattermoe: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true

torch_compile: false
strict: false

chat_template: qwen3_5
datasets:

@@ -23,7 +13,6 @@ datasets:
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

@@ -42,19 +31,15 @@ lora_target_modules:
  - k_proj
  - v_proj
  - o_proj
# Add gate_up_proj and down_proj to also target shared experts (nn.Linear):
# - gate_up_proj
# - down_proj

# Target routed experts (3D nn.Parameter tensors, not nn.Linear — use lora_target_parameters):
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'

# Target experts
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

lora_qkv_kernel: true
lora_o_kernel: true
lora_mlp_kernel: false

wandb_project:
wandb_entity:
wandb_watch:

@@ -62,17 +47,22 @@ wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

gradient_checkpointing: true
activation_offloading: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

@@ -59,21 +59,12 @@ lora_target_parameters:

### Shared Experts (MoE)

Shared experts use `nn.Linear` (unlike routed experts, which are 3D `nn.Parameter` tensors), so they can be targeted via `lora_target_modules`. To also train shared expert projections alongside attention, uncomment `gate_up_proj` and `down_proj` in `lora_target_modules`:
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:

```yaml
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# Add gate_up_proj and down_proj to also target shared experts (nn.Linear):
# - gate_up_proj
# - down_proj
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
```

Use `lora_target_parameters` (see [Routed Experts](#routed-experts-moe) above) to target routed experts separately.

### TIPS

- For inference hyperparameters, please see the respective model card details.

@@ -12,7 +12,7 @@ packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.5.0
transformers==5.3.0
accelerate==1.13.0
datasets==4.5.0
deepspeed>=0.18.6,<0.19.0

@@ -61,12 +61,12 @@ zstandard==0.22.0
fastcore

# lm eval harness
lm_eval==0.4.11
lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2

torchao==0.17.0
torchao==0.16.0
openenv-core==0.1.0
schedulefree==1.4.1

@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11

mistral-common==1.11.0
mistral-common==1.10.0

2
setup.py
@@ -89,7 +89,7 @@ def parse_requirements(extras_require_map):
        ]
        if not install_xformers:
            _install_requires.pop(_install_requires.index(xformers_version))
        extras_require_map["vllm"] = ["vllm>=0.17.1"]
        extras_require_map["vllm"] = ["vllm==0.17.1"]
    elif (major, minor) >= (2, 9):
        extras_require_map.pop("fbgemm-gpu")
        extras_require_map["fbgemm-gpu"] = [

@@ -4,7 +4,6 @@ import os
from pathlib import Path

import httpcore
import httpx
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError

@@ -49,7 +48,7 @@
            "Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
        )
        return False
    except (HTTPError, httpcore.ConnectError, httpx.ConnectError):
    except (HTTPError, httpcore.ConnectError):
        LOG.warning(
            "Error accessing HuggingFace. This may be due to a network issue or rate limiting."
        )

@@ -4,11 +4,9 @@ from pathlib import Path
from typing import Union

import fire
import torch

from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

@@ -19,26 +17,12 @@ LOG = get_logger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
    """
    Merges LoRA adapters with base model using either memory-efficient or legacy approach.
    Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
    along with the LoRA adapters to combine them into a single base model.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
    """
    merge_method = str(getattr(cfg, "merge_method", "memory_efficient"))
    if merge_method == "legacy":
        LOG.debug("Using legacy LoRA merging method...")
        _do_merge_lora_legacy(cfg=cfg)
    else:
        LOG.debug("Using memory-efficient LoRA merging method...")
        _do_merge_lora_efficient(cfg=cfg)


def _do_merge_lora_legacy(*, cfg: DictDefault) -> None:
    """
    Legacy LoRA merging using merge_and_unload.
    Loads the full model into memory before merging.
    """
    LOG.debug("Using legacy LoRA merging method...")
    model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)

    LOG.info("Running merge of LoRA with base model...")

@@ -68,58 +52,6 @@ def _do_merge_lora_legacy(*, cfg: DictDefault) -> None:
        processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))


def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
    """
    Memory-efficient LoRA merging using shard-by-shard processing.
    Does not load the full model into memory.

    Supports standard LoRA, RSLoRA, and DoRA. Unsupported methods (AdaLoRA, VeRA)
    will raise NotImplementedError — use legacy method for those.
    """
    LOG.debug("Using memory-efficient LoRA merging method...")

    output_path = Path(cfg.output_dir) / "merged"
    safe_tensors = getattr(cfg, "save_safetensors", True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Detect NF4 quantization from config to simulate QLoRA training dynamics.
    # Check both current and original (pre-override) config values since do_cli
    # forces load_in_4bit=False for the legacy path.
    simulate_nf4 = bool(
        getattr(cfg, "load_in_4bit", False)
        or getattr(cfg, "_original_load_in_4bit", False)
        or getattr(cfg, "adapter", None) == "qlora"
        or getattr(cfg, "_original_adapter", None) == "qlora"
    )

    bnb_config_kwargs = getattr(cfg, "bnb_config_kwargs", None) or {}
    nf4_blocksize = bnb_config_kwargs.get("blocksize", None)
    nf4_double_quant = bnb_config_kwargs.get(
        "bnb_4bit_use_double_quant",
        getattr(cfg, "bnb_4bit_use_double_quant", True),
    )

    # Detect MoE expert quantization
    simulate_nf4_experts = bool(
        getattr(cfg, "quantize_moe_experts", False)
        or getattr(cfg, "_original_quantize_moe_experts", False)
    )

    merge_lora_sharded_efficient(
        base_model_path=cfg.base_model,
        lora_adapter_path=cfg.lora_model_dir,
        output_path=output_path,
        safe_tensors=safe_tensors,
        device=device,
        simulate_nf4=simulate_nf4,
        simulate_nf4_experts=simulate_nf4_experts,
        nf4_blocksize=nf4_blocksize,
        nf4_double_quant=nf4_double_quant,
    )

    LOG.debug("Memory-efficient LoRA merge completed successfully!")


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various

@@ -134,12 +66,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
        ValueError: If target directory for LoRA merged model does not exist.
    """

    # Pre-load config to detect original quantization settings before overrides
    raw_cfg = load_cfg(config, **kwargs)
    original_load_in_4bit = getattr(raw_cfg, "load_in_4bit", False)
    original_adapter = getattr(raw_cfg, "adapter", None)
    original_quantize_moe_experts = getattr(raw_cfg, "quantize_moe_experts", False)

    parsed_cfg = load_cfg(
        config,
        merge_lora=True,

@@ -154,16 +80,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
        **kwargs,
    )

    # Stash original quantization settings for NF4 simulation in efficient merge
    parsed_cfg._original_load_in_4bit = original_load_in_4bit
    parsed_cfg._original_adapter = original_adapter
    parsed_cfg._original_quantize_moe_experts = original_quantize_moe_experts

    if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
        parsed_cfg.lora_model_dir = parsed_cfg.output_dir
    if not Path(parsed_cfg.lora_model_dir).exists():
        raise ValueError(
            f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
            f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
        )

    do_merge_lora(cfg=parsed_cfg)
@@ -5,7 +5,7 @@ CLI to post-training quantize a model using torchao
from pathlib import Path
from typing import Union

from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig

from axolotl.cli.config import load_cfg
from axolotl.loaders import load_processor, load_tokenizer

@@ -93,22 +93,17 @@ def do_quantize(
        weight_dtype, activation_dtype, group_size
    )

    ao_config = TorchAoConfig(
        quant_type=quantization_config,
        include_input_output_embeddings=quantize_embedding,
    )
    model.config.quantization_config = ao_config

    LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
    try:
        model.save_pretrained(
            str(Path(output_dir) / "quantized"),
            progressbar=True,
        )
    except NotImplementedError:
        LOG.warning(
            "Model weight conversions do not support reverse_op, "
            "retrying save with save_original_format=False"
        )
        model.save_pretrained(
            str(Path(output_dir) / "quantized"),
            progressbar=True,
            save_original_format=False,
        )
    model.save_pretrained(
        str(Path(output_dir) / "quantized"),
        progressbar=True,
    )
    tokenizer.save_pretrained(
        str(Path(output_dir) / "quantized"),
        progressbar=True,

@@ -84,11 +84,8 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
            storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
        ),
    )

        trainer.fit()
        return

    do_train(parsed_cfg, parsed_cli_args)
        return trainer.fit()
    return do_train(parsed_cfg, parsed_cli_args)


def ray_train_func(kwargs: dict):

@@ -1,982 +0,0 @@
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _simulate_nf4_roundtrip(
|
||||
tensor: torch.Tensor,
|
||||
blocksize: Optional[int] = None,
|
||||
compress_statistics: bool = True,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Simulate NF4 quantization roundtrip to match QLoRA training dynamics.
|
||||
|
||||
During QLoRA training, base weights are quantized to NF4 and dequantized on-the-fly
|
||||
for each forward pass. The LoRA adapters learn to compensate for the quantization
|
||||
noise in the dequantized weights. To match this at merge time, we apply the same
|
||||
quantize → dequantize roundtrip so the merged result reflects what the model saw
|
||||
during training.
|
||||
|
||||
Args:
|
||||
tensor: Base model weight tensor (fp16/bf16/fp32)
|
||||
blocksize: NF4 quantization block size (default: bitsandbytes default)
|
||||
compress_statistics: Whether to use double quantization
|
||||
device: Device for quantization computation. bitsandbytes requires a
|
||||
CUDA device; defaults to "cuda" when available.
|
||||
|
||||
Returns:
|
||||
Tensor after NF4 quantize → dequantize roundtrip, in original dtype
|
||||
"""
|
||||
import bitsandbytes.functional as bnb_F
|
||||
|
||||
quant_device: torch.device
|
||||
if device is None:
|
||||
quant_device = torch.device("cuda")
|
||||
elif isinstance(device, str):
|
||||
quant_device = torch.device(device)
|
||||
else:
|
||||
quant_device = device
|
||||
|
||||
if quant_device.type == "cuda" and not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
"NF4 simulation requires CUDA but no GPU is available. "
|
||||
"Either run on a machine with a GPU or disable NF4 simulation."
|
||||
)
|
||||
|
||||
original_dtype = tensor.dtype
|
||||
original_shape = tensor.shape
|
||||
|
||||
# bitsandbytes requires float32 input for quantization and contiguous+CUDA tensor
|
||||
flat = tensor.reshape(-1).to(torch.float32).contiguous().to(quant_device)
|
||||
|
||||
quant_kwargs = {
|
||||
"quant_type": "nf4",
|
||||
"compress_statistics": compress_statistics,
|
||||
}
|
||||
if blocksize is not None:
|
||||
quant_kwargs["blocksize"] = blocksize
|
||||
|
||||
quantized, quant_state = bnb_F.quantize_4bit(flat, **quant_kwargs)
|
||||
dequantized = bnb_F.dequantize_4bit(quantized, quant_state, quant_type="nf4")
|
||||
|
||||
return dequantized.reshape(original_shape).to(original_dtype).cpu()
|
||||
|
||||
|
||||
def find_lora_weights(
    lora_state: Dict[str, torch.Tensor],
    key: str,
    weight_renamings: Optional[Dict[str, str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Find corresponding LoRA A and B weights for a given key.

    Also tries keys after applying weight renamings (from transformers v5
    conversion mappings) in case the checkpoint key names differ from the
    runtime model key names used by the LoRA adapter.
    """
    import re

    clean_key = key[:-7] if key.endswith(".weight") else key

    # Try the direct key first
    a_key = f"base_model.model.{clean_key}.lora_A.weight"
    b_key = f"base_model.model.{clean_key}.lora_B.weight"

    lora_a = lora_state.get(a_key)
    lora_b = lora_state.get(b_key)

    if lora_a is not None and lora_b is not None:
        return lora_a, lora_b

    # Try renamed keys (checkpoint format → runtime format)
    if weight_renamings:
        for src_pattern, tgt_pattern in weight_renamings.items():
            renamed_key = re.sub(src_pattern, tgt_pattern, clean_key)
            if renamed_key != clean_key:
                a_key = f"base_model.model.{renamed_key}.lora_A.weight"
                b_key = f"base_model.model.{renamed_key}.lora_B.weight"
                lora_a = lora_state.get(a_key)
                lora_b = lora_state.get(b_key)
                if lora_a is not None and lora_b is not None:
                    return lora_a, lora_b

    return None, None

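# Illustrative sketch (hypothetical key, added for exposition): the PEFT adapter
# key layout that find_lora_weights assumes for a standard Linear target.
def _demo_lora_key_layout() -> tuple[str, str]:
    base_key = "model.layers.0.self_attn.q_proj.weight"
    clean_key = base_key[:-7]  # strip ".weight"
    # lora_A has shape (r, in_features); lora_B has shape (out_features, r)
    a_key = f"base_model.model.{clean_key}.lora_A.weight"
    b_key = f"base_model.model.{clean_key}.lora_B.weight"
    return a_key, b_key
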
def _find_param_wrapper_lora(
    lora_state: Dict[str, torch.Tensor],
    key: str,
    tensor_shape: Optional[tuple] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[str]]:
    """
    Find LoRA weights from a ParamWrapper (lora_target_parameters) that targets
    a parent module containing this weight as a sub-parameter.

    For example, base weight key 'model.layers.0.mlp.experts.down_proj' may have
    LoRA at 'base_model.model.model.layers.0.mlp.experts.lora_A.weight' (targeting
    the 'experts' module with 'down_proj' as the parameter_name).

    When tensor_shape is provided, validates that the LoRA dimensions match the
    target tensor (important when multiple ParamWrappers are nested and each
    nesting level has different LoRA dimensions).

    Returns (lora_A, lora_B, parameter_name) or (None, None, None).
    """
    clean_key = key[:-7] if key.endswith(".weight") else key
    # Strip trailing parameter name to get the parent module path
    # e.g., "model.layers.0.mlp.experts.down_proj" → parent="model.layers.0.mlp.experts", param="down_proj"
    parts = clean_key.rsplit(".", 1)
    if len(parts) != 2:
        return None, None, None

    parent_key, param_name = parts

    # PEFT's ParamWrapper nesting: when multiple parameters are targeted on
    # the same module, it nests wrappers. The outer wrapper's LoRA is at
    # parent.lora_A/B and inner wrappers use parent.base_layer.lora_A/B,
    # parent.base_layer.base_layer.lora_A/B, etc.
    prefixes_to_try = [
        f"base_model.model.{parent_key}",
    ]
    # Walk up .base_layer nesting levels (typically 1-2 deep)
    for depth in range(1, 4):
        bl = ".base_layer" * depth
        prefixes_to_try.append(f"base_model.model.{parent_key}{bl}")

    for prefix in prefixes_to_try:
        a_key = f"{prefix}.lora_A.weight"
        b_key = f"{prefix}.lora_B.weight"
        lora_a = lora_state.get(a_key)
        lora_b = lora_state.get(b_key)
        if lora_a is None or lora_b is None:
            continue

        # When tensor_shape is given, verify dimensions match before returning.
        # This prevents returning a mismatched LoRA from a different nesting level.
        if tensor_shape is not None and len(tensor_shape) >= 3:
            num_experts = tensor_shape[0]
            if not (
                lora_a.shape[0] == lora_b.shape[1]
                and lora_a.shape[0] % num_experts == 0
                and lora_a.shape[1] == tensor_shape[1]
                and lora_b.shape[0] == tensor_shape[2]
            ):
                continue  # Dimensions don't match, try next nesting level

        return lora_a, lora_b, param_name

    return None, None, None

def _build_peft_layer_and_get_delta(
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    lora_config_dict: Dict,
    base_tensor: torch.Tensor,
    adapter_name: str = "default",
    is_param_wrapper: bool = False,
    magnitude: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Use PEFT's own layer classes to compute the LoRA delta weight.

    Instead of re-implementing the merge math for every LoRA variant, this
    constructs a lightweight PEFT layer, loads the A/B weights, and calls
    ``get_delta_weight`` (or ``merge`` for DoRA) which handles standard LoRA,
    RSLoRA, DoRA, and ParamWrapper (expert-blocked) LoRA.

    Returns the delta tensor (same shape as base_tensor).
    """
    import warnings

    import torch.nn as nn

    r_total = lora_a.shape[0]
    in_features = lora_a.shape[1]
    out_features = lora_b.shape[0]
    lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
    use_rslora = bool(lora_config_dict.get("use_rslora", False))
    use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None

    if is_param_wrapper:
        from peft.tuners.lora.layer import ParamWrapper

        num_experts = base_tensor.shape[0]
        r = r_total // num_experts

        class _FakeModule(nn.Module):
            pass

        fake = _FakeModule()
        fake.register_parameter(
            "weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            layer = ParamWrapper(
                fake,
                adapter_name=adapter_name,
                parameter_name="weight",
                r=r,
                lora_alpha=lora_alpha,
                use_rslora=use_rslora,
            )
        layer.lora_A[adapter_name].weight.data = lora_a
        layer.lora_B[adapter_name].weight.data = lora_b
        return layer.get_delta_weight(adapter_name)
    else:
        from peft.tuners.lora.layer import Linear as LoraLinear

        base_layer = nn.Linear(in_features, out_features, bias=False)
        base_layer.weight.data = base_tensor.clone()

        fan_in_fan_out = bool(
            lora_config_dict.get("fan_in_fan_out", False)
            or lora_config_dict.get("lora_fan_in_fan_out", False)
        )

        layer = LoraLinear(
            base_layer,
            adapter_name=adapter_name,
            r=r_total,
            lora_alpha=lora_alpha,
            fan_in_fan_out=fan_in_fan_out,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )
        layer.lora_A[adapter_name].weight.data = lora_a
        layer.lora_B[adapter_name].weight.data = lora_b

        if use_dora:
            # DoRA merges magnitude normalization into the weight directly.
            # Use PEFT's merge() which handles DoRA internally, then
            # compute the delta as merged_weight - original_weight.
            mag_layer = layer.lora_magnitude_vector[adapter_name]
            mag_layer.weight = nn.Parameter(magnitude)
            layer.merge(adapter_names=[adapter_name])
            return base_layer.weight.data - base_tensor

        return layer.get_delta_weight(adapter_name)

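# Illustrative sketch of the math PEFT's get_delta_weight performs for the plain
# Linear case (RSLoRA changes only the scaling; DoRA and ParamWrapper differ).
# Delegating to PEFT above keeps the merge in sync with fan_in_fan_out handling
# and variant-specific scaling; this sketch is just the core identity.
def _demo_lora_delta(
    lora_a: torch.Tensor,  # (r, in_features)
    lora_b: torch.Tensor,  # (out_features, r)
    lora_alpha: float,
    use_rslora: bool = False,
) -> torch.Tensor:
    r = lora_a.shape[0]
    scaling = lora_alpha / (math.sqrt(r) if use_rslora else r)
    # Low-rank update: (out, r) @ (r, in) -> (out, in), same shape as the base weight
    return (lora_b @ lora_a) * scaling
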
def get_model_shards(model_path: Path) -> list[Path]:
    """Find all model shards in the given path."""
    shards: list[Path] = []

    patterns = ["model*.safetensors", "pytorch_model*.bin"]

    for pattern in patterns:
        shards.extend(model_path.glob(pattern))
        if shards:
            break

    return sorted(shards)

def copy_non_model_files(
    input_path: Path, output_path: Path, model_shards: list[Path]
) -> None:
    """
    Copy all non-model files to the output directory.

    Args:
        input_path: Source directory
        output_path: Destination directory
        model_shards: List of model shard files to skip
    """
    LOG.info("Copying non-model files to output directory...")

    shard_names = {shard.name for shard in model_shards}

    for filepath in input_path.glob("*"):
        if filepath.is_dir():
            continue
        if filepath.name in shard_names:
            continue
        if (
            filepath.name.startswith("model") and filepath.suffix == ".safetensors"
        ) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"):
            continue
        if filepath.suffix == ".gguf":
            continue
        # Skip weight-map index files — they reference shard filenames that may
        # change during the merge (e.g. .bin → .safetensors). A correct index
        # is regenerated after all shards have been written.
        if filepath.name.endswith(".index.json"):
            continue

        LOG.debug(f"Copying {filepath.name} to output")
        shutil.copy2(filepath, output_path)

def _find_dora_magnitude(
    lora_state: Dict[str, torch.Tensor],
    key: str,
    weight_renamings: Optional[Dict[str, str]] = None,
) -> Optional[torch.Tensor]:
    """
    Find DoRA magnitude vector for a given key.
    """
    import re

    clean_key = key[:-7] if key.endswith(".weight") else key
    mag_key = f"base_model.model.{clean_key}.lora_magnitude_vector"
    result = lora_state.get(mag_key)
    if result is not None:
        return result

    if weight_renamings:
        for src_pattern, tgt_pattern in weight_renamings.items():
            renamed_key = re.sub(src_pattern, tgt_pattern, clean_key)
            if renamed_key != clean_key:
                mag_key = f"base_model.model.{renamed_key}.lora_magnitude_vector"
                result = lora_state.get(mag_key)
                if result is not None:
                    return result

    return None

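# Illustrative sketch of how the magnitude vector found above is used. This
# mirrors the DoRA formulation; the actual merge path delegates the real
# computation to PEFT's merge() in _build_peft_layer_and_get_delta.
def _demo_dora_merge(
    base_weight: torch.Tensor,  # (out, in)
    delta: torch.Tensor,  # scaled B @ A, same shape as base_weight
    magnitude: torch.Tensor,  # (out,) learned per-output-channel magnitude
) -> torch.Tensor:
    directional = base_weight + delta
    # L2 norm over the input dimension, one norm per output channel
    norm = directional.norm(p=2, dim=1, keepdim=True)
    return magnitude.view(-1, 1) * directional / norm
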
def _should_nf4_roundtrip(
    key: str,
    tensor: torch.Tensor,
    simulate_nf4: bool,
    simulate_nf4_experts: bool,
) -> bool:
    """Determine if a tensor should undergo NF4 quantization roundtrip."""
    if tensor.ndim < 2:
        return False
    if simulate_nf4:
        return True
    if simulate_nf4_experts and tensor.ndim >= 3 and "expert" in key.lower():
        return True
    return False

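# Illustrative sketch (hypothetical keys): which tensors the predicate selects
# when only expert quantization was simulated during training.
def _demo_nf4_selection() -> tuple[bool, bool]:
    experts = torch.zeros(8, 128, 64)  # 3D MoE expert tensor, 8 experts
    q_proj = torch.zeros(64, 64)  # plain 2D projection weight
    return (
        _should_nf4_roundtrip("mlp.experts.gate_up_proj", experts, False, True),  # True
        _should_nf4_roundtrip("self_attn.q_proj.weight", q_proj, False, True),  # False
    )
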
def _merge_tensor_with_lora(
    tensor: torch.Tensor,
    key: str,
    lora_state: Dict[str, torch.Tensor],
    scale: float,
    lora_config_dict: Dict,
    device: str,
    simulate_nf4: bool = False,
    simulate_nf4_experts: bool = False,
    nf4_blocksize: Optional[int] = None,
    nf4_double_quant: bool = True,
    use_dora: bool = False,
    weight_renamings: Optional[Dict[str, str]] = None,
) -> tuple[torch.Tensor, bool]:
    """
    Helper function to merge a single tensor with its corresponding LoRA weights.

    Args:
        tensor: Base model tensor
        key: Tensor key/name
        lora_state: Dictionary containing LoRA weights
        scale: LoRA scaling factor (alpha/r)
        lora_config_dict: LoRA configuration dictionary
        device: Device to perform computations on
        simulate_nf4: Whether to simulate NF4 quantization roundtrip for all weights
        simulate_nf4_experts: Whether to simulate NF4 roundtrip for MoE expert tensors only
        nf4_blocksize: Block size for NF4 quantization
        nf4_double_quant: Whether to use double quantization
        use_dora: Whether to apply DoRA (Weight-Decomposed LoRA) merging
        weight_renamings: Optional key renamings from transformers conversion mapping

    Returns:
        Tuple of (merged tensor, whether LoRA was applied)
    """
    lora_a, lora_b = find_lora_weights(lora_state, key, weight_renamings)

    do_nf4 = _should_nf4_roundtrip(key, tensor, simulate_nf4, simulate_nf4_experts)

    if lora_a is not None and lora_b is not None:
        LOG.debug(f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}")

        original_dtype = tensor.dtype

        # Simulate NF4 quantization roundtrip to match QLoRA training dynamics
        if do_nf4:
            tensor = _simulate_nf4_roundtrip(
                tensor,
                blocksize=nf4_blocksize,
                compress_statistics=nf4_double_quant,
                device=device,
            )

        magnitude = (
            _find_dora_magnitude(lora_state, key, weight_renamings)
            if use_dora
            else None
        )
        delta = _build_peft_layer_and_get_delta(
            lora_a.to(device),
            lora_b.to(device),
            lora_config_dict,
            tensor.to(device),
            magnitude=magnitude.to(device) if magnitude is not None else None,
        )
        merged_tensor = (
            (tensor.to(device).to(torch.float32) + delta.to(torch.float32))
            .to(original_dtype)
            .detach()
            .cpu()
        )
        return merged_tensor, True
    else:
        # Try ParamWrapper LoRA (lora_target_parameters) — the LoRA targets a
        # parent module and this weight is a sub-parameter of that module.
        if tensor.ndim >= 3:
            pw_a, pw_b, param_name = _find_param_wrapper_lora(
                lora_state, key, tensor_shape=tuple(tensor.shape)
            )
            if pw_a is not None and pw_b is not None:
                LOG.debug(
                    f"Merging ParamWrapper LoRA for {key} "
                    f"(param={param_name}): {pw_a.shape}, {pw_b.shape}"
                )
                if do_nf4:
                    tensor = _simulate_nf4_roundtrip(
                        tensor,
                        blocksize=nf4_blocksize,
                        compress_statistics=nf4_double_quant,
                        device=device,
                    )
                original_dtype = tensor.dtype
                delta = _build_peft_layer_and_get_delta(
                    pw_a.to(device),
                    pw_b.to(device),
                    lora_config_dict,
                    tensor.to(device),
                    is_param_wrapper=True,
                )
                merged = (
                    (tensor.to(device).to(torch.float32) + delta.to(torch.float32))
                    .to(original_dtype)
                    .detach()
                    .cpu()
                )
                return merged, True

        if do_nf4:
            tensor = _simulate_nf4_roundtrip(
                tensor,
                blocksize=nf4_blocksize,
                compress_statistics=nf4_double_quant,
                device=device,
            )
        return tensor.detach().cpu(), False

def _get_conversion_info(base_model_path: Path) -> tuple[Dict[str, str], list]:
    """
    Load the model's config.json and check if transformers has WeightRenaming
    or WeightConverter mappings for this model type.

    Returns:
    - dict of {source_pattern: target_pattern} for simple renamings
    - list of WeightConverter objects for fuse/unfuse operations
    """
    import json as _json

    config_path = base_model_path / "config.json"
    if not config_path.exists():
        return {}, []

    try:
        with open(config_path) as f:
            model_config = _json.load(f)
    except (OSError, _json.JSONDecodeError):
        return {}, []

    model_type = model_config.get("model_type")
    if not model_type:
        return {}, []

    try:
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping
        from transformers.core_model_loading import WeightConverter, WeightRenaming
    except ImportError:
        return {}, []

    conversions = get_checkpoint_conversion_mapping(model_type)
    if not conversions:
        return {}, []

    renamings = {}
    weight_converters = []
    for conv in conversions:
        if isinstance(conv, WeightRenaming):
            # WeightRenaming stores patterns as lists internally
            src_list = (
                conv.source_patterns
                if isinstance(conv.source_patterns, list)
                else [conv.source_patterns]
            )
            tgt_list = (
                conv.target_patterns
                if isinstance(conv.target_patterns, list)
                else [conv.target_patterns]
            )
            if len(src_list) == 1 and len(tgt_list) == 1:
                renamings[src_list[0]] = tgt_list[0]
        elif isinstance(conv, WeightConverter):
            weight_converters.append(conv)

    return renamings, weight_converters

def _fuse_and_unfuse_with_merge(
    shard_tensors: Dict[str, torch.Tensor],
    weight_converters: list,
    lora_state: Dict[str, torch.Tensor],
    scale: float,
    lora_config_dict: Dict,
    device: str,
    simulate_nf4: bool = False,
    simulate_nf4_experts: bool = False,
    nf4_blocksize: Optional[int] = None,
    nf4_double_quant: bool = True,
    use_dora: bool = False,
    weight_renamings: Optional[Dict[str, str]] = None,
) -> tuple[Dict[str, torch.Tensor], int, set]:
    """
    For tensors matching WeightConverter patterns (MoE expert weights):
    1. Fuse checkpoint-format tensors into runtime-format (e.g., per-expert → fused 3D)
    2. Apply NF4 roundtrip + LoRA merge on the fused tensor
    3. Save the merged tensor in fused (runtime) format so the result loads
       without re-fusing

    Returns:
    - Updated tensor dict
    - Count of merged LoRA targets
    - Set of keys that were processed (fused/merged) and should be
      skipped by the per-tensor merge pass to avoid double NF4 roundtrip
    """
    import re

    from transformers.core_model_loading import Concatenate, MergeModulelist

    result = dict(shard_tensors)  # Start with all tensors
    merged_count = 0
    processed_keys: set = set()  # Keys that were fuse/merge processed

    for converter in weight_converters:
        src_patterns = (
            converter.source_patterns
            if isinstance(converter.source_patterns, list)
            else [converter.source_patterns]
        )
        tgt_patterns = (
            converter.target_patterns
            if isinstance(converter.target_patterns, list)
            else [converter.target_patterns]
        )

        # Build regex for each source pattern
        pattern_regexes = []
        for pat in src_patterns:
            regex_str = re.escape(pat).replace(r"\.\*\.", r"\.(\d+)\.")
            regex_str = (
                regex_str.rstrip(r"\$") if regex_str.endswith(r"\$") else regex_str
            )
            pattern_regexes.append(re.compile(r"(.*\.)?" + regex_str + "$"))

        # Group matching keys by layer prefix and source pattern
        # {layer_prefix: {pat_idx: {expert_idx: (key, tensor)}}}
        layer_groups: Dict[str, Dict[int, Dict[int, tuple[str, torch.Tensor]]]] = {}

        for key in list(result.keys()):
            for pat_idx, pat_regex in enumerate(pattern_regexes):
                match = pat_regex.match(key)
                if match:
                    prefix = match.group(1) or ""
                    # Extract expert index from the matched portion
                    remaining = key[len(prefix) :]
                    expert_match = re.search(r"\.(\d+)\.", remaining)
                    expert_idx = int(expert_match.group(1)) if expert_match else 0

                    layer_groups.setdefault(prefix, {}).setdefault(pat_idx, {})[
                        expert_idx
                    ] = (key, result[key])
                    break

        # Process each layer group
        for prefix, pat_groups in layer_groups.items():
            # Skip layers with no matching source patterns
            if not pat_groups:
                continue

            # Step 1: Fuse — MergeModulelist (stack experts) per source pattern
            fused_per_pattern = {}
            original_keys_per_pattern: Dict[int, list[str]] = {}
            num_experts = None

            for pat_idx in sorted(pat_groups.keys()):
                expert_data = pat_groups[pat_idx]
                sorted_indices = sorted(expert_data.keys())
                if num_experts is None:
                    num_experts = len(sorted_indices)

                sorted_tensors = [expert_data[idx][1] for idx in sorted_indices]
                original_keys_per_pattern[pat_idx] = [
                    expert_data[idx][0] for idx in sorted_indices
                ]
                fused_per_pattern[src_patterns[pat_idx]] = torch.stack(
                    sorted_tensors, dim=0
                )

            # Apply remaining operations (Concatenate)
            fused_tensor = None
            has_concat = False
            concat_dim = 1  # default

            for op in converter.operations:
                if isinstance(op, MergeModulelist):
                    pass  # Already handled
                elif isinstance(op, Concatenate):
                    has_concat = True
                    concat_dim = op.dim
                    tensors_to_cat = [
                        fused_per_pattern[sp]
                        for sp in src_patterns
                        if sp in fused_per_pattern
                    ]
                    if len(tensors_to_cat) > 1:
                        fused_tensor = torch.cat(tensors_to_cat, dim=concat_dim)
                    elif tensors_to_cat:
                        fused_tensor = tensors_to_cat[0]

            if not has_concat and len(fused_per_pattern) == 1:
                fused_tensor = next(iter(fused_per_pattern.values()))

            if fused_tensor is None:
                continue

            # Step 2: Build the fused key name and merge LoRA
            fused_key = prefix + tgt_patterns[0]

            # Apply NF4 roundtrip on the fused tensor (matching training dynamics)
            do_nf4 = _should_nf4_roundtrip(
                fused_key, fused_tensor, simulate_nf4, simulate_nf4_experts
            )
            if do_nf4:
                fused_tensor = _simulate_nf4_roundtrip(
                    fused_tensor,
                    blocksize=nf4_blocksize,
                    compress_statistics=nf4_double_quant,
                    device=device,
                )

            # Try to find and merge LoRA weights for the fused key
            lora_a, lora_b = find_lora_weights(lora_state, fused_key, weight_renamings)
            if lora_a is not None and lora_b is not None:
                LOG.debug(
                    f"Merging LoRA for fused key {fused_key}: {lora_a.shape}, {lora_b.shape}"
                )
                original_dtype = fused_tensor.dtype
                magnitude = (
                    _find_dora_magnitude(lora_state, fused_key, weight_renamings)
                    if use_dora
                    else None
                )
                delta = _build_peft_layer_and_get_delta(
                    lora_a.to(device),
                    lora_b.to(device),
                    lora_config_dict,
                    fused_tensor.to(device),
                    magnitude=magnitude.to(device) if magnitude is not None else None,
                )
                fused_tensor = (
                    (
                        fused_tensor.to(device).to(torch.float32)
                        + delta.to(torch.float32)
                    )
                    .to(original_dtype)
                    .detach()
                    .cpu()
                )
                merged_count += 1

            # Step 3: Save in fused format (runtime format) so that the merged
            # model can be loaded directly without needing WeightConverter
            # fusion during from_pretrained (which can OOM for large MoE models).
            # Remove the original per-expert keys and save the fused tensor
            # under the runtime key name.
            for pat_idx in sorted(original_keys_per_pattern.keys()):
                for ok in original_keys_per_pattern[pat_idx]:
                    result.pop(ok, None)
                    processed_keys.add(ok)

            result[fused_key] = fused_tensor.detach().cpu()
            processed_keys.add(fused_key)

    return result, merged_count, processed_keys

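# Illustrative sketch of the fuse step with hypothetical per-expert tensors: two
# source patterns (e.g. gate_proj/up_proj) are first stacked across experts
# (MergeModulelist), then concatenated (Concatenate) into one runtime tensor.
def _demo_fuse_experts(
    gate_weights: list[torch.Tensor],  # E tensors, each (out, in)
    up_weights: list[torch.Tensor],  # E tensors, each (out, in)
    concat_dim: int = 1,
) -> torch.Tensor:
    gate = torch.stack(gate_weights, dim=0)  # (E, out, in)
    up = torch.stack(up_weights, dim=0)  # (E, out, in)
    # Fused 3D tensor — the same view the runtime model, the NF4 roundtrip,
    # and the LoRA merge above all operate on: (E, 2*out, in) for concat_dim=1
    return torch.cat([gate, up], dim=concat_dim)
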
def merge_lora_sharded_efficient(
    base_model_path: Union[str, Path],
    lora_adapter_path: Union[str, Path],
    output_path: Union[str, Path],
    device: str = "cpu",
    safe_tensors: bool = True,
    simulate_nf4: bool = False,
    simulate_nf4_experts: bool = False,
    nf4_blocksize: Optional[int] = None,
    nf4_double_quant: bool = True,
) -> None:
    """
    Memory-efficient LoRA merging that processes shards individually
    without loading the full model into memory.

    Args:
        simulate_nf4: Apply NF4 roundtrip to ALL weight tensors (for QLoRA)
        simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
            (for quantize_moe_experts). Expert tensors are identified by having
            "expert" in the key name and ndim >= 3.
    """
    base_model_path = Path(base_model_path)
    lora_adapter_path = Path(lora_adapter_path)
    output_path = Path(output_path)

    if "/" in str(base_model_path) and not base_model_path.exists():
        base_model_path = Path(snapshot_download(str(base_model_path)))

    # Check for weight conversion requirements (transformers v5)
    weight_renamings, weight_converters = _get_conversion_info(base_model_path)
    if weight_renamings:
        LOG.debug(f"Found {len(weight_renamings)} weight renamings for this model type")
    if weight_converters:
        LOG.debug(
            f"Found {len(weight_converters)} weight converters (fuse/unfuse) for this model type. "
            f"Will fuse→merge within each shard and save in fused format."
        )

    os.makedirs(output_path, exist_ok=True)

    config_file = lora_adapter_path / "adapter_config.json"
    if not config_file.exists():
        raise FileNotFoundError(f"LoRA config not found: {config_file}")

    lora_config_dict = LoraConfig.from_json_file(str(config_file))
    if not lora_config_dict.get("r") or lora_config_dict["r"] <= 0:
        raise ValueError("LoRA config 'r' must be > 0")

    use_dora = bool(lora_config_dict.get("use_dora", False))

    unsupported_methods = []

    # Check for AdaLoRA (Adaptive LoRA)
    if lora_config_dict.get("use_adalora", False):
        unsupported_methods.append("AdaLoRA (Adaptive LoRA)")

    # Check for VeRA (Vector-based Random Matrix Adaptation)
    if lora_config_dict.get("use_vera", False):
        unsupported_methods.append("VeRA (Vector-based Random Matrix Adaptation)")

    # Check for other advanced LoRA variants by task_type
    task_type = lora_config_dict.get("task_type", "")
    if task_type and task_type not in [
        "CAUSAL_LM",
        "SEQ_2_SEQ_LM",
        "TOKEN_CLS",
        "SEQ_CLS",
        "QUESTION_ANS",
    ]:
        unsupported_methods.append(f"Task type: {task_type}")

    # Check for rank adaptation patterns (AdaLoRA indicators)
    # Use .get() so empty dicts/None don't false-positive
    if any(
        lora_config_dict.get(key)
        for key in ["rank_pattern", "alpha_pattern", "target_rank"]
    ):
        unsupported_methods.append("AdaLoRA (rank adaptation detected)")

    # Check for advanced initialization methods
    init_lora_weights = lora_config_dict.get("init_lora_weights", "")
    if init_lora_weights and init_lora_weights not in [
        "gaussian",
        "loftq",
        True,
        False,
    ]:
        unsupported_methods.append(f"Advanced initialization: {init_lora_weights}")

    if unsupported_methods:
        methods_str = ", ".join(unsupported_methods)
        raise NotImplementedError(
            f"Memory-efficient LoRA merge only supports standard LoRA. "
            f"Detected unsupported methods: {methods_str}. "
            f"Please use the legacy merge method for advanced LoRA variants."
        )

    use_rslora = bool(lora_config_dict.get("use_rslora", False))
    if use_rslora:
        scale = float(lora_config_dict["lora_alpha"]) / math.sqrt(
            float(lora_config_dict["r"])
        )
    else:
        scale = float(lora_config_dict["lora_alpha"]) / float(lora_config_dict["r"])

    LOG.debug(f"LoRA scale factor: {scale} (rslora={use_rslora})")

    if simulate_nf4:
        LOG.info(
            "NF4 simulation enabled: base weights will undergo quantize→dequantize "
            "roundtrip before LoRA merge to match QLoRA training dynamics"
        )

    lora_file = lora_adapter_path / "adapter_model.safetensors"
    if not lora_file.exists():
        lora_file = lora_adapter_path / "adapter_model.bin"
        if not lora_file.exists():
            raise FileNotFoundError(
                f"LoRA adapter weights not found in {lora_adapter_path}"
            )

    LOG.debug(f"Loading LoRA weights from {lora_file}")

    if lora_file.suffix == ".safetensors":
        lora_state = safetensors.torch.load_file(lora_file)
    else:
        lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)  # nosec B614
    LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")

    model_shards = get_model_shards(base_model_path)
    if not model_shards:
        raise FileNotFoundError(f"No model shards found in {base_model_path}")

    LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}")
    copy_non_model_files(base_model_path, output_path, model_shards)

    merged_count = 0
    total_tensors = 0
    # The HF index convention stores the total checkpoint size in bytes
    total_size_bytes = 0
    # Track weight_map for index regeneration: {tensor_key: shard_filename}
    weight_map: Dict[str, str] = {}

    for shard_path in tqdm(model_shards, desc="Merging shards"):
        merged_tensors = {}
        metadata = {}

        # Load all tensors from the shard
        if shard_path.suffix == ".safetensors":
            with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
                if hasattr(f, "metadata") and f.metadata():
                    metadata = f.metadata()
                shard_tensors = {key: f.get_tensor(key) for key in f.keys()}
        else:
            shard_tensors = torch.load(  # nosec B614: loading trusted model weights
                shard_path, map_location="cpu", weights_only=True
            )

        total_tensors += len(shard_tensors)

        # Step 1: Handle fused weight conversions (MoE experts) if applicable
        fused_keys: set = set()
        if weight_converters:
            shard_tensors, fused_merged, fused_keys = _fuse_and_unfuse_with_merge(
                shard_tensors,
                weight_converters,
                lora_state,
                scale,
                lora_config_dict,
                device,
                simulate_nf4=simulate_nf4,
                simulate_nf4_experts=simulate_nf4_experts,
                nf4_blocksize=nf4_blocksize,
                nf4_double_quant=nf4_double_quant,
                use_dora=use_dora,
                weight_renamings=weight_renamings,
            )
            merged_count += fused_merged

        # Step 2: Merge remaining (non-fused) tensors with LoRA
        # Skip keys already processed by fuse/unfuse to avoid double NF4 roundtrip
        for key, tensor in shard_tensors.items():
            if key in fused_keys:
                merged_tensors[key] = tensor.detach().cpu()
                continue
            merged_tensor, was_merged = _merge_tensor_with_lora(
                tensor,
                key,
                lora_state,
                scale,
                lora_config_dict,
                device,
                simulate_nf4=simulate_nf4,
                simulate_nf4_experts=simulate_nf4_experts,
                nf4_blocksize=nf4_blocksize,
                nf4_double_quant=nf4_double_quant,
                use_dora=use_dora,
                weight_renamings=weight_renamings,
            )
            merged_tensors[key] = merged_tensor
            if was_merged:
                merged_count += 1

        output_shard_path = output_path / shard_path.name
        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}

        if safe_tensors:
            if not str(output_shard_path).endswith(".safetensors"):
                output_shard_path = output_path / (shard_path.stem + ".safetensors")
            safetensors.torch.save_file(
                merged_tensors, output_shard_path, metadata=metadata
            )
        else:
            if shard_path.suffix == ".safetensors":
                safetensors.torch.save_file(
                    merged_tensors, output_shard_path, metadata=metadata
                )
            else:
                torch.save(merged_tensors, output_shard_path)

        for tensor_key, merged in merged_tensors.items():
            weight_map[tensor_key] = output_shard_path.name
            total_size_bytes += merged.numel() * merged.element_size()

        del merged_tensors, shard_tensors
        if device != "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # Regenerate weight-map index if the model was sharded
    if len(model_shards) > 1 and weight_map:
        import json as _json

        index_name = (
            "model.safetensors.index.json"
            if safe_tensors
            else "pytorch_model.bin.index.json"
        )
        index = {
            "metadata": {"total_size": total_size_bytes},
            "weight_map": weight_map,
        }
        with open(output_path / index_name, "w") as f:
            _json.dump(index, f, indent=2)
        LOG.debug(f"Wrote weight-map index: {index_name}")

    if merged_count == 0:
        LOG.warning(
            "No LoRA weights were matched to base model tensors. "
            "This may indicate a key name mismatch between the checkpoint format "
            "and the LoRA adapter. Consider using merge_method: legacy."
        )
    LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")

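# Illustrative usage sketch (paths are hypothetical, not from the diff):
def _demo_merge_usage() -> None:
    merge_lora_sharded_efficient(
        base_model_path="./base-model",
        lora_adapter_path="./outputs/lora-out",
        output_path="./outputs/merged",
        device="cuda",
        simulate_nf4=True,  # roundtrip base weights to match QLoRA training
    )
    # Shards are loaded, merged, and written one at a time, so peak memory stays
    # near the size of a single shard plus the LoRA state dict.
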
@@ -38,14 +38,18 @@ def do_vllm_serve(
    cfg = load_cfg(config)
    model = cfg.base_model

    # Determine serve module: explicit CLI/config > default (axolotl's LoRA-aware serve).
    # We default to axolotl's serve module instead of TRL's because TRL's sends
    # truncate_prompt_tokens which is unsupported in vLLM 0.17+.
    # Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
    serve_module = cli_args.get("serve_module") or getattr(
        cfg.vllm, "serve_module", None
    )
    if serve_module is None:
    if (
        serve_module is None
        and getattr(cfg, "trl", None)
        and getattr(cfg.trl, "vllm_lora_sync", False)
    ):
        serve_module = "axolotl.scripts.vllm_serve_lora"
    if serve_module is None:
        serve_module = "trl.scripts.vllm_serve"
    vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
    tensor_parallel_size = 1
    data_parallel_size = 1
@@ -75,12 +79,6 @@ def do_vllm_serve(
        cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
    )

    cli_enforce_eager = cli_args.get("enforce_eager")
    cfg_enforce_eager = getattr(cfg.vllm, "enforce_eager", None)
    raw_enforce_eager = (
        cfg_enforce_eager if cli_enforce_eager is None else cli_enforce_eager
    )
    enforce_eager = bool(raw_enforce_eager) if raw_enforce_eager is not None else False
    base_kwargs = dict(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
@@ -91,7 +89,6 @@ def do_vllm_serve(
        dtype=dtype,
        max_model_len=max_model_len,
        enable_prefix_caching=enable_prefix_caching,
        enforce_eager=enforce_eager,
    )

    # Use LoRAScriptArguments when serving with native LoRA support
@@ -101,12 +98,6 @@ def do_vllm_serve(
        lora_kwargs = {}
        if hasattr(cfg, "lora_r") and cfg.lora_r:
            lora_kwargs["max_lora_rank"] = cfg.lora_r
        # Disable native LoRA in vLLM if not using vllm_lora_sync
        # (merged weight sync via batch_update doesn't need vLLM LoRA mode)
        if not getattr(cfg.trl, "vllm_lora_sync", False):
            lora_kwargs["enable_lora"] = False
        if getattr(cfg.vllm, "worker_extension_cls", None):
            lora_kwargs["worker_extension_cls"] = cfg.vllm.worker_extension_cls
        vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
    else:
        vllm_script_args = AxolotlScriptArguments(

@@ -23,5 +23,4 @@ MOE_ARCH_BLOCK = {
    "glm4_moe": "Glm4MoeDecoderLayer",
    "glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
    "glm_moe_dsa": "GlmMoeDsaDecoderLayer",
    "nemotron_h": "NemotronHMoE",
}

@@ -118,7 +118,7 @@ def load_preference_datasets(
    train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)

    total_num_steps: int | None = None
    if cfg.rl not in {RLType.GRPO, RLType.EBFT}:
    if cfg.rl is not RLType.GRPO:
        total_num_steps = int(
            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
        )

@@ -329,7 +329,7 @@ class TrainerBuilderBase(abc.ABC):
            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif self.cfg.optimizer == "ao_adamw_fp8":
            from torchao.optim.adam import AdamWFp8
            from torchao.prototype.low_bit_optim import AdamWFp8

            optimizer_cls = AdamWFp8
            optimizer_kwargs.update(adam_kwargs)

@@ -78,11 +78,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            trainer_cls = AxolotlKTOTrainer
        elif self.cfg.rl is RLType.SIMPO:
            trainer_cls = AxolotlCPOTrainer
        elif self.cfg.rl is RLType.EBFT:
            from axolotl.core.trainers.ebft import EBFTStrategy

            trainer_cls = EBFTStrategy.get_trainer_class(self.cfg)
            trainer_kwargs.update(EBFTStrategy.set_trainer_kwargs(self.cfg))
        else:
            raise ValueError(f"Unsupported RL: {self.cfg.rl}")

@@ -127,6 +122,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
            training_args_kwargs["beta"] = self.cfg.orpo_alpha

        if self.cfg.rpo_alpha is not None:
            training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

        if self.cfg.use_wandb:
            training_args_kwargs["run_name"] = self.cfg.wandb_name

@@ -173,22 +171,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            )
            training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
            blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
            if not async_grpo:
                # Filter out async/fast-async-only fields not in standard GRPOConfig.
                # These are defined in FastAsyncGRPOConfig and only used by
                # AxolotlAsyncGRPOConfig. Standard GRPOConfig rejects them.
                import dataclasses

                from trl import GRPOConfig as _BaseGRPOConfig

                from axolotl.core.trainers.grpo.fast_async_trainer import (
                    FastAsyncGRPOConfig,
                )

                async_only_fields = {
                    f.name for f in dataclasses.fields(FastAsyncGRPOConfig)
                } - {f.name for f in dataclasses.fields(_BaseGRPOConfig)}
                blocklist_args_kwargs.extend(list(async_only_fields))
            if self.cfg.rl is RLType.GDPO:
                training_args_kwargs.setdefault(
                    "multi_objective_aggregation", "normalize_then_sum"
@@ -197,13 +179,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
        elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
            training_args_cls = AxolotlDPOConfig
            training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))

        elif self.cfg.rl is RLType.EBFT:
            from axolotl.core.trainers.ebft import EBFTStrategy

            training_args_cls = EBFTStrategy.get_training_args_class(self.cfg)
            training_args_kwargs.update(EBFTStrategy.set_training_args_kwargs(self.cfg))
            blocklist_args_kwargs = EBFTStrategy.get_blocklist_args_kwargs(self.cfg)
        else:
            raise ValueError(f"Unsupported RL: {self.cfg.rl}")

@@ -236,9 +211,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
        if (
            self.cfg.adapter
            and self.peft_config
            and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
            and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
        ):
            trainer_kwargs["peft_config"] = self.peft_config
        if self.cfg.precompute_ref_log_probs is not None:
            trainer_kwargs["precompute_ref_log_probs"] = (
                self.cfg.precompute_ref_log_probs
            )

        trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)

@@ -2,34 +2,13 @@

# flake8: noqa

from axolotl.utils import make_lazy_getattr

from .base import AxolotlTrainer

# noinspection PyUnresolvedReferences
__all__ = [
    "AxolotlTrainer",
    "AxolotlCPOTrainer",
    "AxolotlDPOTrainer",
    "AxolotlEBFTTrainer",
    "AxolotlKTOTrainer",
    "AxolotlMambaTrainer",
    "AxolotlORPOTrainer",
    "AxolotlPRMTrainer",
    "AxolotlRewardTrainer",
    "AxolotlStridedEBFTTrainer",
]

_LAZY_IMPORTS = {
    "AxolotlDPOTrainer": ".dpo.trainer",
    "AxolotlStridedEBFTTrainer": ".ebft.strided",
    "AxolotlEBFTTrainer": ".ebft.trainer",
    "AxolotlMambaTrainer": ".mamba",
    "AxolotlCPOTrainer": ".trl",
    "AxolotlKTOTrainer": ".trl",
    "AxolotlORPOTrainer": ".trl",
    "AxolotlPRMTrainer": ".trl",
    "AxolotlRewardTrainer": ".trl",
}

__getattr__ = make_lazy_getattr(_LAZY_IMPORTS, __name__, globals())
from .dpo.trainer import AxolotlDPOTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,
    AxolotlORPOTrainer,
    AxolotlPRMTrainer,
    AxolotlRewardTrainer,
)

@@ -381,15 +381,6 @@ class AxolotlTrainer(
            # Store per-step trainable tokens for throughput calculation
            self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()

        # Gemma4 requires mm_token_type_ids during training (even for text-only).
        # Inject zeros (= text token type) when not provided by the data collator.
        if (
            "mm_token_type_ids" not in inputs
            and "input_ids" in inputs
            and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
        ):
            inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])

        if self.args.orpo_alpha:
            return self.orpo_compute_loss(
                model,
@@ -414,13 +405,15 @@ class AxolotlTrainer(
    def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
        concatenated_batch = {}

        max_length = max(inputs["input_ids"].shape[1], inputs["rejected_ids"].shape[1])
        max_length = max(
            inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
        )
        # Concatenate positive and negative inputs
        concatenated_batch["input_ids"] = pad_to_length(
            inputs["input_ids"], max_length, pad_token
        )
        concatenated_batch["rejected_ids"] = pad_to_length(
            inputs["rejected_ids"], max_length, pad_token
        concatenated_batch["rejected_input_ids"] = pad_to_length(
            inputs["rejected_input_ids"], max_length, pad_token
        )
        concatenated_batch["labels"] = pad_to_length(
            inputs["labels"], max_length, label_pad_token
@@ -439,7 +432,7 @@ class AxolotlTrainer(
        ).to(device=device)

        input_ids = torch.cat(
            [concatenated_batch["input_ids"], concatenated_batch["rejected_ids"]],
            [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
            dim=0,
        ).to(device=device)
        attention_mask = torch.cat(
@@ -517,24 +510,12 @@ class AxolotlTrainer(
        )

        # Perform a single forward pass
        forward_kwargs = {
            "input_ids": concat_inputs["input_ids"],
            "attention_mask": concat_inputs["attention_mask"],
            "labels": concat_inputs["labels"],
        }
        # Gemma4 requires mm_token_type_ids during training (even for text-only)
        if (
            getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
            and "mm_token_type_ids" not in concat_inputs
        ):
            forward_kwargs["mm_token_type_ids"] = torch.zeros_like(
                concat_inputs["input_ids"]
            )
        elif "mm_token_type_ids" in concat_inputs:
            forward_kwargs["mm_token_type_ids"] = concat_inputs["mm_token_type_ids"]

        outputs = model(
            **forward_kwargs,
            **{
                "input_ids": concat_inputs["input_ids"],
                "attention_mask": concat_inputs["attention_mask"],
                "labels": concat_inputs["labels"],
            },
            output_hidden_states=True,
        )

@@ -21,7 +21,7 @@ class DPOStrategy:
    def set_training_args_kwargs(cls, cfg):
        training_args_kwargs = {}
        if cfg.rl is RLType.IPO:
            training_args_kwargs["loss_type"] = ["ipo"]
            training_args_kwargs["loss_type"] = "ipo"
        # Label smoothing is not compatible with IPO
        if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
            training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
@@ -30,10 +30,8 @@ class DPOStrategy:
            training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
        if cfg.dpo_padding_free is not None:
            training_args_kwargs["padding_free"] = cfg.dpo_padding_free
        if cfg.dpo_norm_loss is not None:
            training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
        if cfg.dpo_use_liger_kernel is not None:
            training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
        if cfg.precompute_ref_log_probs is not None:
            training_args_kwargs["precompute_ref_log_probs"] = (
                cfg.precompute_ref_log_probs
            )
        return training_args_kwargs

@@ -2,7 +2,8 @@
Axolotl specific DPO args
"""

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from trl import DPOConfig

@@ -14,3 +15,6 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    DPO config for DPO training
    """

    dpo_norm_loss: bool | None = False
    rpo_alpha: Optional[float] = field(default=None)

@@ -6,7 +6,6 @@ from typing import Any, Dict, Union

import torch
from torch import nn
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from trl import DPOTrainer

from axolotl.core.trainers.mixins import (
@@ -19,7 +18,6 @@ from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_ds_tagging,
    sanitize_kwargs_for_tagging,
)
from axolotl.utils.data.utils import remove_double_bos_token


class AxolotlDPOTrainer(
@@ -55,31 +53,36 @@ class AxolotlDPOTrainer(

        return super().push_to_hub(*args, **kwargs)

    def _tokenize(
        self,
        processing_class: PreTrainedTokenizerBase | ProcessorMixin,
        input: str | list,
        **kwargs,
    ) -> dict[str, list]:
        """
        Override TRL's tokenization in DPO trainer to fix double bos_token bug (eg. llama).
        """
        result = super()._tokenize(
            processing_class=processing_class, input=input, **kwargs
    @staticmethod
    def tokenize_row(
        features,
        processing_class,
        max_prompt_length: int | None = None,
        max_completion_length: int | None = None,
        add_special_tokens: bool = True,
        is_chat: bool = False,
    ) -> Dict:
        res = DPOTrainer.tokenize_row(
            features,
            processing_class,
            max_prompt_length=max_prompt_length,
            max_completion_length=max_completion_length,
            add_special_tokens=add_special_tokens,
            is_chat=is_chat,
        )
        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
            for key in res.keys():
                res[key] = res[key][1:]

        # Handle multimodal models
        tokenizer = (
            getattr(processing_class, "tokenizer", None)
            if isinstance(processing_class, ProcessorMixin)
            else processing_class
        )
        if processing_class.bos_token and processing_class.bos_token_id is not None:
            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
                res["rejected_input_ids"] = res["rejected_input_ids"][1:]

        bos_token_id = getattr(tokenizer, "bos_token_id", None) if tokenizer else None
        if bos_token_id is not None:
            result = remove_double_bos_token(result, bos_token_id)

        return result
        return res

    def training_step(
        self,
@@ -91,3 +94,20 @@ class AxolotlDPOTrainer(
        gc.collect()
        torch.cuda.empty_cache()
        return loss

    def concatenated_forward(
        self,
        model: nn.Module,
        batch: dict[str, Union[list, torch.LongTensor]],
        is_ref_model: bool = False,
    ) -> dict[str, torch.Tensor]:
        if self.args.dpo_norm_loss:
            # fmt: off
            loss_type: list[str] = self.loss_type  # type: ignore[has-type]
            # fmt: on
            # concatenated_forward handles avg token logprob for ipo case already
            self.loss_type = ["ipo"]
            res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
            self.loss_type = loss_type
            return res
        return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

@@ -1,229 +0,0 @@
"""EBFT (Energy-Based Fine-Tuning) Strategy for training

Two modes:
- structured: For QA data with prompt/completion splits. Uses GRPOTrainer + vLLM.
- strided: For unstructured text (raw code, prose). Uses strided block-parallel generation.
"""

from typing import Any

from axolotl.core.trainers.ebft.args import (
    AxolotlAsyncEBFTConfig,
    AxolotlEBFTConfig,
    AxolotlStridedEBFTConfig,
)
from axolotl.utils.dict import DictDefault


def _get_ebft_mode(cfg: DictDefault) -> str:
    """Determine EBFT mode from config."""
    if cfg.ebft and hasattr(cfg.ebft, "mode") and cfg.ebft.mode:
        return cfg.ebft.mode
    return "structured"


class EBFTStrategy:
    """Strategy for EBFT training — dispatches between structured and strided modes."""

    @classmethod
    def get_trainer_class(cls, cfg: DictDefault | None = None):
        mode = _get_ebft_mode(cfg) if cfg else "structured"
        if mode == "strided":
            from axolotl.core.trainers.ebft.strided import AxolotlStridedEBFTTrainer

            return AxolotlStridedEBFTTrainer

        # Structured mode: async or sync
        # use_data_producer also triggers async trainer (needed for LoRA sync
        # without async_prefetch, since sync trainer lacks LoRA sync support)
        use_async = (
            cfg
            and cfg.trl
            and (
                getattr(cfg.trl, "async_prefetch", False)
                or getattr(cfg.trl, "use_data_producer", False)
            )
        )
        if use_async:
            from axolotl.core.trainers.ebft.trainer import AxolotlAsyncEBFTTrainer

            return AxolotlAsyncEBFTTrainer
        from axolotl.core.trainers.ebft.trainer import AxolotlEBFTTrainer

        return AxolotlEBFTTrainer

    @classmethod
    def get_training_args_class(cls, cfg: DictDefault | None = None):
        mode = _get_ebft_mode(cfg) if cfg else "structured"
        if mode == "strided":
            return AxolotlStridedEBFTConfig

        # Structured mode: async or sync config
        use_async = (
            cfg
            and cfg.trl
            and (
                getattr(cfg.trl, "async_prefetch", False)
                or getattr(cfg.trl, "use_data_producer", False)
            )
        )
        if use_async:
            return AxolotlAsyncEBFTConfig
        return AxolotlEBFTConfig

    @classmethod
    def is_strided(cls, cfg: DictDefault) -> bool:
        return _get_ebft_mode(cfg) == "strided"

    @classmethod
    def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
        """Map axolotl YAML config fields to training args kwargs."""
        kwargs: dict[str, Any] = {}
        mode = _get_ebft_mode(cfg)

        # Common EBFT fields
        ebft = cfg.ebft
        if ebft:
            if ebft.feature_layers is not None:
                kwargs["ebft_feature_layers"] = ebft.feature_layers
            if ebft.embed_method is not None:
                kwargs["ebft_embed_method"] = ebft.embed_method
            if ebft.use_whitening is not None:
                kwargs["ebft_use_whitening"] = ebft.use_whitening
            if ebft.alignment_coef is not None:
                kwargs["ebft_alignment_coef"] = ebft.alignment_coef
            if ebft.diversity_coef is not None:
                kwargs["ebft_diversity_coef"] = ebft.diversity_coef
            if ebft.ce_coef is not None:
                kwargs["ebft_ce_coef"] = ebft.ce_coef
            if getattr(ebft, "adaptive_max_tokens", None) is not None:
                kwargs["ebft_adaptive_max_tokens"] = ebft.adaptive_max_tokens
            if getattr(ebft, "gt_length_multiplier", None) is not None:
                kwargs["ebft_gt_length_multiplier"] = ebft.gt_length_multiplier

        if mode == "strided":
            # Strided-specific fields
            if ebft:
                if ebft.stride is not None:
                    kwargs["ebft_stride"] = ebft.stride
                if ebft.context_length is not None:
                    kwargs["ebft_context_length"] = ebft.context_length
                if ebft.generate_max_len is not None:
                    kwargs["ebft_generate_max_len"] = ebft.generate_max_len
                if ebft.n_samples_per_prompt is not None:
                    kwargs["ebft_n_samples_per_prompt"] = ebft.n_samples_per_prompt
                if ebft.temperature is not None:
                    kwargs["ebft_temperature"] = ebft.temperature
                if ebft.top_p is not None:
                    kwargs["ebft_top_p"] = ebft.top_p
                if ebft.rl_coef is not None:
                    kwargs["ebft_rl_coef"] = ebft.rl_coef
                if ebft.advantage_estimator is not None:
                    kwargs["ebft_advantage_estimator"] = ebft.advantage_estimator
                if ebft.min_completion_prefix is not None:
                    kwargs["ebft_min_completion_prefix"] = ebft.min_completion_prefix
        else:
            # Structured mode: map TRL config fields
            trl = cfg.trl
            if trl:
                if trl.use_vllm:
                    kwargs["use_vllm"] = trl.use_vllm
                if trl.vllm_mode:
                    kwargs["vllm_mode"] = trl.vllm_mode
                    if trl.vllm_mode == "colocate":
                        kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode
                        vllm_cfg = cfg.vllm
                        if vllm_cfg:
                            kwargs["vllm_gpu_memory_utilization"] = (
                                vllm_cfg.gpu_memory_utilization
                            )
                            kwargs["vllm_tensor_parallel_size"] = (
                                vllm_cfg.tensor_parallel_size
                            )
                kwargs["vllm_server_host"] = trl.vllm_server_host or (
                    trl.vllm.host if trl.vllm else None
                )
                kwargs["vllm_server_port"] = trl.vllm_server_port or (
                    trl.vllm.port if trl.vllm else None
                )
                if trl.vllm_server_timeout:
                    kwargs["vllm_server_timeout"] = trl.vllm_server_timeout

                if trl.num_generations:
                    kwargs["num_generations"] = trl.num_generations
                if trl.max_completion_length is not None:
                    kwargs["max_completion_length"] = trl.max_completion_length
                if trl.temperature is not None:
                    kwargs["temperature"] = trl.temperature
                if trl.top_p is not None:
                    kwargs["top_p"] = trl.top_p
                if trl.top_k is not None:
                    kwargs["top_k"] = trl.top_k
                if trl.min_p is not None:
                    kwargs["min_p"] = trl.min_p
                if trl.num_iterations is not None:
                    kwargs["num_iterations"] = trl.num_iterations
                if trl.epsilon is not None:
                    kwargs["epsilon"] = trl.epsilon
                if trl.epsilon_high is not None:
                    kwargs["epsilon_high"] = trl.epsilon_high
                if trl.scale_rewards is not None:
                    kwargs["scale_rewards"] = trl.scale_rewards
                if trl.loss_type is not None:
                    kwargs["loss_type"] = trl.loss_type
                if trl.mask_truncated_completions is not None:
                    kwargs["mask_truncated_completions"] = (
                        trl.mask_truncated_completions
                    )
                if trl.log_completions is not None:
                    kwargs["log_completions"] = trl.log_completions
                if trl.num_completions_to_print is not None:
                    kwargs["num_completions_to_print"] = trl.num_completions_to_print
                if trl.sync_ref_model:
                    kwargs["sync_ref_model"] = trl.sync_ref_model
                if trl.repetition_penalty is not None:
                    kwargs["repetition_penalty"] = trl.repetition_penalty
                if trl.generation_kwargs is not None:
                    kwargs["generation_kwargs"] = trl.generation_kwargs
                if trl.chat_template_kwargs is not None:
                    kwargs["chat_template_kwargs"] = trl.chat_template_kwargs

                # Async prefetch fields (only pass when enabled — sync config doesn't have these)
                if getattr(trl, "async_prefetch", False):
                    kwargs["async_prefetch"] = trl.async_prefetch
                if getattr(trl, "vllm_sync_interval", None) is not None:
                    kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
                if getattr(trl, "vllm_lora_sync", False):
                    kwargs["vllm_lora_sync"] = trl.vllm_lora_sync

        return kwargs

    @classmethod
    def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
        return []

    @classmethod
    def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
        return {}

    @classmethod
    def get_blocklist_args_kwargs(cls, cfg: DictDefault | None = None) -> list[str]:
        mode = _get_ebft_mode(cfg) if cfg else "structured"
        if mode == "strided":
            return [
                "dataset_num_proc",
                "max_length",
                "max_prompt_length",
                "include_tokens_per_second",
                "beta",
            ]
        return [
            "dataset_num_proc",
            "max_length",
            "include_tokens_per_second",
            "max_prompt_length",
        ]

    @classmethod
    def get_collator(cls, *args, **kwargs):
        return None
@@ -1,133 +0,0 @@
"""
EBFT-specific training arguments.

Two config classes:
- AxolotlEBFTConfig: extends GRPOConfig for structured QA data (uses vLLM generation)
- AxolotlStridedEBFTConfig: extends TrainingArguments for unstructured text (strided generation)
"""

from dataclasses import dataclass, field

from transformers import TrainingArguments
from trl import GRPOConfig

from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins


# -- Shared EBFT fields as a mixin --
@dataclass
class EBFTFieldsMixin:
    """Common fields shared between structured and strided EBFT configs."""

    ebft_feature_layers: list[float] = field(
        default_factory=lambda: [0.25, 0.5, 0.75],
        metadata={"help": "Fractional layer depths for feature extraction"},
    )
    ebft_embed_method: str = field(
        default="last_token",
        metadata={"help": "Pooling method: 'last_token', 'mean_pooling', or 'concat'"},
    )
    ebft_use_whitening: bool = field(
        default=False,
        metadata={"help": "Apply SVD whitening to feature embeddings"},
    )
    ebft_alignment_coef: float = field(
        default=1.0,
        metadata={"help": "Coefficient for alignment reward (cosine similarity)"},
    )
    ebft_diversity_coef: float = field(
        default=1.0,
        metadata={"help": "Coefficient for diversity penalty"},
    )
    ebft_ce_coef: float = field(
        default=0.0,
        metadata={"help": "Cross-entropy loss coefficient on ground-truth tokens"},
    )
    ebft_adaptive_max_tokens: bool = field(
        default=True,
        metadata={"help": "Set per-batch max_tokens based on ground-truth length"},
    )
    ebft_gt_length_multiplier: float = field(
        default=1.5,
        metadata={
            "help": "Multiplier for ground-truth token count when computing adaptive max_tokens"
        },
    )


# -- Structured mode: extends GRPOConfig for QA data with vLLM --
@dataclass
class AxolotlEBFTConfig(EBFTFieldsMixin, AxolotlTrainingMixins, GRPOConfig):
    """EBFT config for structured QA data — extends GRPOConfig."""

    vllm_lora_sync: bool = field(
        default=False,
        metadata={
            "help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
        },
    )


# -- Async structured mode: extends FastAsyncGRPOConfig --
@dataclass
class AxolotlAsyncEBFTConfig(
    EBFTFieldsMixin, AxolotlTrainingMixins, FastAsyncGRPOConfig
):
    """EBFT config for async structured QA data — extends FastAsyncGRPOConfig.

    Includes all async fields: async_prefetch, vllm_lora_sync,
    skip_zero_advantage_batches, streaming_partial_batch, replay_buffer_size, etc.
    """

    vllm_lora_sync: bool = field(
        default=False,
        metadata={
            "help": "Sync LoRA adapters to vLLM via filesystem instead of NCCL weight merge."
        },
    )


# -- Strided mode: extends TrainingArguments for unstructured text --
@dataclass
class AxolotlStridedEBFTConfig(
    EBFTFieldsMixin, AxolotlTrainingMixins, TrainingArguments
):
    """EBFT config for unstructured text with strided block-parallel generation."""

    ebft_stride: int = field(
        default=8,
        metadata={"help": "Stride between anchor points (in tokens)"},
    )
    ebft_context_length: int = field(
        default=8,
        metadata={"help": "Context window size for each block"},
    )
    ebft_generate_max_len: int = field(
        default=8,
        metadata={"help": "Number of tokens to generate per block"},
    )
    ebft_n_samples_per_prompt: int = field(
        default=4,
        metadata={"help": "Number of independent rollouts per document"},
    )
    ebft_temperature: float = field(
        default=0.6,
        metadata={"help": "Sampling temperature for strided generation"},
|
||||
)
|
||||
ebft_top_p: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Top-p nucleus sampling threshold"},
|
||||
)
|
||||
ebft_rl_coef: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "RL policy gradient loss coefficient"},
|
||||
)
|
||||
ebft_advantage_estimator: str = field(
|
||||
default="rloo",
|
||||
metadata={"help": "Advantage estimator: 'rloo', 'group_norm', or 'reinforce'"},
|
||||
)
|
||||
ebft_min_completion_prefix: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Minimum tokens into completion before placing anchors"},
|
||||
)
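# Construction sketch (illustrative, not part of the original file): the strided
# config is a plain dataclass, so it can be built directly; output_dir comes from
# TrainingArguments and every ebft_* field falls back to the defaults above.
#
#   args = AxolotlStridedEBFTConfig(
#       output_dir="./outputs",
#       ebft_stride=8,
#       ebft_n_samples_per_prompt=4,
#       ebft_advantage_estimator="rloo",
#   )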
@@ -1,308 +0,0 @@
"""
Fused Triton kernels for strided EBFT.

These kernels eliminate intermediate tensor materializations that dominate
the elementwise/fill category (~40% of CUDA time in profiling).

Kernels:
1. fused_log_softmax_gather: log_softmax + gather in one pass (no full vocab materialization)
2. fused_reinforce_loss: -logp * advantage * mask, reduced to a scalar
3. fused_cosine_similarity: batched cosine similarity without intermediate tensors
4. fused_diversity_penalty: pairwise dot products with self-exclusion, per sample
"""

import torch
import triton
import triton.language as tl

# ---------------------------------------------------------------------------
# 1. Fused log_softmax + gather (selective log softmax)
# ---------------------------------------------------------------------------
# Instead of: log_softmax(logits, dim=-1) → (B, S, V) → gather(index=labels)
# We compute: for each (b, s), the log_softmax value at logits[b, s, labels[b, s]]
# This avoids materializing the full (B, S, V) log_softmax output.


@triton.jit
def _fused_log_softmax_gather_kernel(
    logits_ptr,  # (B*S, V) row-major
    labels_ptr,  # (B*S,) int64
    output_ptr,  # (B*S,) float32
    V: tl.constexpr,  # vocab size
    BLOCK_V: tl.constexpr,  # tile width over vocab
):
    """Compute log_softmax(logits)[label] for each row without materializing the full output."""
    row = tl.program_id(0)

    logits_row_ptr = logits_ptr + row * V
    label = tl.load(labels_ptr + row)

    # Pass 1: find max for numerical stability
    max_val = -float("inf")
    for v_start in range(0, V, BLOCK_V):
        v_offsets = v_start + tl.arange(0, BLOCK_V)
        mask = v_offsets < V
        vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
        max_val = tl.maximum(max_val, tl.max(vals, axis=0))

    # Pass 2: compute sum(exp(x - max))
    sum_exp = 0.0
    for v_start in range(0, V, BLOCK_V):
        v_offsets = v_start + tl.arange(0, BLOCK_V)
        mask = v_offsets < V
        vals = tl.load(logits_row_ptr + v_offsets, mask=mask, other=-float("inf"))
        sum_exp += tl.sum(tl.exp(vals - max_val), axis=0)

    log_sum_exp = tl.log(sum_exp)

    # Gather: log_softmax[label] = logits[label] - max - log_sum_exp
    target_logit = tl.load(logits_row_ptr + label)
    result = target_logit - max_val - log_sum_exp

    tl.store(output_ptr + row, result)


def fused_log_softmax_gather(
    logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
    """Compute log_softmax(logits, dim=-1).gather(-1, labels) without materializing the full output.

    Args:
        logits: (B, S, V) or (B*S, V) float tensor (bf16 or fp32)
        labels: (B, S) or (B*S,) int64 tensor of target indices

    Returns:
        (B, S) or (B*S,) float32 tensor of selected log probabilities
    """
    orig_shape = logits.shape[:-1]
    V = logits.shape[-1]
    logits_2d = logits.reshape(-1, V).contiguous()
    labels_1d = labels.reshape(-1).contiguous()
    n_rows = logits_2d.shape[0]

    output = torch.empty(n_rows, device=logits.device, dtype=torch.float32)

    # Choose BLOCK_V: must be a power of 2, large enough for good occupancy
    BLOCK_V = min(triton.next_power_of_2(V), 65536)

    _fused_log_softmax_gather_kernel[(n_rows,)](
        logits_2d,
        labels_1d,
        output,
        V=V,
        BLOCK_V=BLOCK_V,
    )

    return output.view(orig_shape)
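# Reference implementation (illustrative sketch, not part of the original file):
# the fused kernel should agree with the eager log_softmax + gather path, which
# materializes the full (B, S, V) tensor and serves as a correctness baseline.
def _reference_log_softmax_gather(
    logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
    # Full materialization: (B, S, V) float32 log-probs, then gather the label column.
    logp = torch.log_softmax(logits.float(), dim=-1)
    return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)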
# ---------------------------------------------------------------------------
# 2. Fused masked REINFORCE loss reduction
# ---------------------------------------------------------------------------
# Instead of: (-logp * adv * mask).sum() / mask.sum()
# We do the masked multiply-accumulate in one kernel, returning (sum, count).


@triton.jit
def _fused_reinforce_loss_kernel(
    logps_ptr,  # (N,) float32 per-token log probs
    advs_ptr,  # (N,) float32 advantages
    mask_ptr,  # (N,) bool action mask
    partial_sum_ptr,  # (n_blocks,) partial sums
    partial_cnt_ptr,  # (n_blocks,) partial counts
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    block_id = tl.program_id(0)
    offsets = block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    valid = offsets < N

    logps = tl.load(logps_ptr + offsets, mask=valid, other=0.0)
    advs = tl.load(advs_ptr + offsets, mask=valid, other=0.0)
    m = tl.load(mask_ptr + offsets, mask=valid, other=0).to(tl.float32)

    # -logp * advantage * mask
    loss = -logps * advs * m
    block_sum = tl.sum(loss, axis=0)
    block_cnt = tl.sum(m, axis=0)

    tl.store(partial_sum_ptr + block_id, block_sum)
    tl.store(partial_cnt_ptr + block_id, block_cnt)


def fused_reinforce_loss(
    per_token_logps: torch.Tensor,
    advantages: torch.Tensor,
    action_mask: torch.Tensor,
) -> torch.Tensor:
    """Compute masked REINFORCE loss: (-logp * adv * mask).sum() / mask.sum().

    All inputs should be flat or will be flattened. Returns a scalar loss tensor.
    """
    logps_flat = per_token_logps.reshape(-1).contiguous()
    advs_flat = advantages.reshape(-1).contiguous()
    mask_flat = action_mask.reshape(-1).contiguous()
    N = logps_flat.shape[0]

    BLOCK_N = 1024
    n_blocks = triton.cdiv(N, BLOCK_N)

    partial_sum = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)
    partial_cnt = torch.empty(n_blocks, device=logps_flat.device, dtype=torch.float32)

    _fused_reinforce_loss_kernel[(n_blocks,)](
        logps_flat,
        advs_flat,
        mask_flat,
        partial_sum,
        partial_cnt,
        N=N,
        BLOCK_N=BLOCK_N,
    )

    total_sum = partial_sum.sum()
    total_cnt = partial_cnt.sum().clamp(min=1)
    return total_sum / total_cnt
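# Eager equivalent (illustrative sketch): useful for unit-testing the kernel.
def _reference_reinforce_loss(
    per_token_logps: torch.Tensor,
    advantages: torch.Tensor,
    action_mask: torch.Tensor,
) -> torch.Tensor:
    m = action_mask.reshape(-1).float()
    loss = -per_token_logps.reshape(-1) * advantages.reshape(-1) * m
    return loss.sum() / m.sum().clamp(min=1)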
# ---------------------------------------------------------------------------
# 3. Fused cosine similarity (batched, for EBFT rewards)
# ---------------------------------------------------------------------------
# Instead of: F.cosine_similarity(gen, gt, dim=-1) which normalizes then dots,
# we fuse the dot product, norm computation, and division into one kernel.


@triton.jit
def _fused_cosine_sim_kernel(
    a_ptr,  # (N, D) first set of vectors
    b_ptr,  # (N, D) second set of vectors
    out_ptr,  # (N,) cosine similarities
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    row = tl.program_id(0)
    a_row_ptr = a_ptr + row * D
    b_row_ptr = b_ptr + row * D

    dot = 0.0
    norm_a = 0.0
    norm_b = 0.0

    for d_start in range(0, D, BLOCK_D):
        d_offsets = d_start + tl.arange(0, BLOCK_D)
        mask = d_offsets < D
        a_vals = tl.load(a_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)
        b_vals = tl.load(b_row_ptr + d_offsets, mask=mask, other=0.0).to(tl.float32)

        dot += tl.sum(a_vals * b_vals, axis=0)
        norm_a += tl.sum(a_vals * a_vals, axis=0)
        norm_b += tl.sum(b_vals * b_vals, axis=0)

    denom = tl.sqrt(norm_a) * tl.sqrt(norm_b)
    denom = tl.where(denom > 1e-8, denom, 1e-8)
    result = dot / denom

    tl.store(out_ptr + row, result)


def fused_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute cosine similarity along the last dimension.

    Args:
        a, b: (..., D) tensors of the same shape

    Returns:
        (...,) tensor of cosine similarities
    """
    orig_shape = a.shape[:-1]
    D = a.shape[-1]
    a_2d = a.reshape(-1, D).contiguous()
    b_2d = b.reshape(-1, D).contiguous()
    N = a_2d.shape[0]

    output = torch.empty(N, device=a.device, dtype=torch.float32)

    BLOCK_D = min(triton.next_power_of_2(D), 4096)

    _fused_cosine_sim_kernel[(N,)](
        a_2d,
        b_2d,
        output,
        D=D,
        BLOCK_D=BLOCK_D,
    )

    return output.view(orig_shape)
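# Eager equivalent (illustrative sketch): F.cosine_similarity in float32 should
# agree with the fused kernel up to the 1e-8 denominator clamp (eager default eps=1e-8).
def _reference_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cosine_similarity(a.float(), b.float(), dim=-1)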
# ---------------------------------------------------------------------------
# 4. Fused pairwise diversity penalty
# ---------------------------------------------------------------------------
# Instead of: bmm(gen, gen.T) → mask diagonal → sum / (n-1)
# We compute the pairwise dot products and self-exclusion in one kernel.


@triton.jit
def _fused_diversity_kernel(
    emb_ptr,  # (B, N, D) embeddings, row-major
    out_ptr,  # (B, N) diversity penalties
    N: tl.constexpr,  # n_samples
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """For each (b, i), compute the mean dot product to all j != i."""
    b = tl.program_id(0)
    i = tl.program_id(1)

    # Pointer to emb[b, i, :]
    emb_bi_ptr = emb_ptr + (b * N + i) * D

    total_sim = 0.0
    for j in range(N):
        emb_bj_ptr = emb_ptr + (b * N + j) * D

        dot = 0.0
        for d_start in range(0, D, BLOCK_D):
            d_offsets = d_start + tl.arange(0, BLOCK_D)
            d_mask = d_offsets < D
            a_vals = tl.load(emb_bi_ptr + d_offsets, mask=d_mask, other=0.0).to(
                tl.float32
            )
            b_vals = tl.load(emb_bj_ptr + d_offsets, mask=d_mask, other=0.0).to(
                tl.float32
            )
            dot += tl.sum(a_vals * b_vals, axis=0)

        # Exclude self-similarity (j == i)
        is_other = j != i
        total_sim += dot * is_other

    result = total_sim / (N - 1)
    tl.store(out_ptr + b * N + i, result)


def fused_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
    """Compute the mean pairwise dot product (excluding self) per sample.

    Args:
        embeddings: (B, N, D) tensor where N is n_samples

    Returns:
        (B, N) tensor of diversity penalties
    """
    B, N, D = embeddings.shape
    embeddings = embeddings.contiguous()
    output = torch.zeros(B, N, device=embeddings.device, dtype=torch.float32)
    if N <= 1:
        return output  # diversity is 0 when there's only one sample

    BLOCK_D = min(triton.next_power_of_2(D), 4096)

    _fused_diversity_kernel[(B, N)](
        embeddings,
        output,
        N=N,
        D=D,
        BLOCK_D=BLOCK_D,
    )

    return output
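# Eager equivalent (illustrative sketch) of the fused kernel, mirroring the
# bmm → mask-diagonal → mean pipeline named in the section comment above.
def _reference_diversity_penalty(embeddings: torch.Tensor) -> torch.Tensor:
    B, N, _ = embeddings.shape
    if N <= 1:
        return torch.zeros(B, N, device=embeddings.device, dtype=torch.float32)
    sims = torch.bmm(embeddings.float(), embeddings.float().transpose(1, 2))  # (B, N, N)
    eye = torch.eye(N, device=embeddings.device, dtype=torch.bool)
    sims = sims.masked_fill(eye.unsqueeze(0), 0.0)  # zero self-similarity
    return sims.sum(dim=-1) / (N - 1)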
@@ -1,264 +0,0 @@
"""
Feature-matching reward utilities for Energy-Based Fine-Tuning (EBFT).

Ported from: feature-002/ebft_openrlhf/openrlhf/utils/embedding_utils.py
Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def extract_hidden_states(
    model: nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_indices: list[int],
    batch_size: int | None = None,
) -> torch.Tensor:
    """
    Forward pass through the model, extracting and concatenating hidden states
    at the specified layer indices.

    Args:
        model: The frozen feature network
        input_ids: (B, S) token ids
        attention_mask: (B, S) attention mask
        layer_indices: List of layer indices to extract (e.g., [8, 16, 24] for a 32-layer model)
        batch_size: If set, process in chunks to reduce peak memory

    Returns:
        Concatenated hidden states: (B, S, num_layers * H)
    """
    if batch_size is None:
        batch_size = input_ids.shape[0]

    # Use the inner transformer body (skips lm_head) when available.
    # This avoids the expensive hidden_dim × vocab_size matmul whose
    # output (logits) is never used — only hidden_states are needed.
    body = getattr(model, "model", None)
    if body is not None and hasattr(body, "forward"):
        forward_model = body
    else:
        forward_model = model

    all_features = []
    for i in range(0, input_ids.shape[0], batch_size):
        chunk_ids = input_ids[i : i + batch_size]
        chunk_mask = attention_mask[i : i + batch_size]

        outputs = forward_model(
            chunk_ids,
            attention_mask=chunk_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # hidden_states is a tuple of (num_layers + 1) tensors, each (B, S, H);
        # index 0 is the embedding layer output
        hidden_states = outputs.hidden_states
        chunk_features = []
        for idx in layer_indices:
            chunk_features.append(hidden_states[idx])

        # Concatenate across the feature dimension: (B, S, num_layers * H)
        all_features.append(torch.cat(chunk_features, dim=-1))

    return torch.cat(all_features, dim=0)
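# Usage sketch (illustrative): fractional depths from the config map to indices
# as int(frac * num_layers), e.g. [0.25, 0.5, 0.75] on a 32-layer model:
#
#   layer_indices = [int(f * 32) for f in (0.25, 0.5, 0.75)]  # -> [8, 16, 24]
#   feats = extract_hidden_states(model, input_ids, attention_mask, layer_indices)
#   # feats: (B, S, 3 * H)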
def apply_embed_method(
    hidden_states: torch.Tensor,
    method: str,
    attention_mask: torch.Tensor | None = None,
    prompt_lengths: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Pool per-token hidden states into per-sequence embeddings.

    Args:
        hidden_states: (B, S, D) concatenated hidden states
        method: One of "last_token", "mean_pooling", "completion_mean", "concat"
        attention_mask: (B, S) mask for mean pooling
        prompt_lengths: (B,) number of prompt tokens per sample (for completion_mean)

    Returns:
        Sequence embeddings: (B, D) for last_token/mean_pooling/completion_mean,
        (B, 3*D) for concat
    """
    if method == "last_token":
        if attention_mask is not None:
            # Find the last non-padding position per sample
            last_idx = attention_mask.sum(dim=1).long() - 1  # (B,)
            return hidden_states[torch.arange(hidden_states.shape[0]), last_idx]
        return hidden_states[:, -1, :]

    if method == "mean_pooling":
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()  # (B, S, 1)
            return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return hidden_states.mean(dim=1)

    if method == "completion_mean":
        # Mean pool over completion tokens only (exclude the prompt)
        if prompt_lengths is None:
            raise ValueError("completion_mean requires prompt_lengths")
        B, S, _ = hidden_states.shape
        positions = torch.arange(S, device=hidden_states.device).unsqueeze(0)  # (1, S)
        comp_mask = positions >= prompt_lengths.unsqueeze(1)  # (B, S)
        if attention_mask is not None:
            comp_mask = comp_mask & attention_mask.bool()
        mask = comp_mask.unsqueeze(-1).float()  # (B, S, 1)
        return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    if method == "concat":
        B, S, D = hidden_states.shape
        if attention_mask is not None:
            valid_lens = attention_mask.sum(dim=1).long()  # (B,)
        else:
            valid_lens = torch.full(
                (B,), S, device=hidden_states.device, dtype=torch.long
            )
        # Compute quartile positions relative to the valid length per sample
        q1 = (valid_lens // 4).clamp(min=0, max=S - 1)
        q2 = (valid_lens // 2).clamp(min=0, max=S - 1)
        q3 = (3 * valid_lens // 4).clamp(min=0, max=S - 1)
        batch_idx = torch.arange(B, device=hidden_states.device)
        return torch.cat(
            [
                hidden_states[batch_idx, q1],
                hidden_states[batch_idx, q2],
                hidden_states[batch_idx, q3],
            ],
            dim=-1,
        )

    raise ValueError(f"Unknown embed_method: {method}")
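# Shape sketch (illustrative): for hidden_states of shape (B, S, D),
#   last_token / mean_pooling / completion_mean -> (B, D)
#   concat (quartile positions at 1/4, 1/2, 3/4 of the valid length) -> (B, 3 * D)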
@torch.no_grad()
def get_alignment_rewards(
    gen_embedding: torch.Tensor,
    gt_embedding: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the alignment reward as cosine similarity between generated
    and ground-truth feature embeddings.

    Args:
        gen_embedding: (B, D) generated sequence embeddings
        gt_embedding: (B, D) ground-truth sequence embeddings.
            If num_generations > 1, gt_embedding should be repeated
            to match gen_embedding's batch dim.

    Returns:
        Alignment rewards: (B,) cosine similarities in [-1, 1]
    """
    return F.cosine_similarity(gen_embedding, gt_embedding, dim=-1)


@torch.no_grad()
def get_diversity_rewards(
    gen_embedding: torch.Tensor,
    num_generations: int,
) -> torch.Tensor:
    """
    Compute the diversity penalty as mean pairwise dot-product similarity
    between samples from the same prompt (excluding self-similarity).

    Args:
        gen_embedding: (B, D) generated embeddings where B = num_prompts * num_generations
        num_generations: Number of generations per prompt

    Returns:
        Diversity penalties: (B,) mean similarity to other samples from the same prompt
    """
    if num_generations <= 1:
        return torch.zeros(gen_embedding.shape[0], device=gen_embedding.device)

    num_prompts = gen_embedding.shape[0] // num_generations

    # Reshape to (num_prompts, num_generations, D)
    reshaped = gen_embedding.view(num_prompts, num_generations, -1)

    # Pairwise dot products within each group: (num_prompts, num_generations, num_generations)
    sims = torch.bmm(reshaped, reshaped.transpose(1, 2))

    # Zero out self-similarity (diagonal)
    eye = torch.eye(num_generations, device=sims.device, dtype=torch.bool)
    sims = sims.masked_fill(eye.unsqueeze(0), 0.0)

    # Mean similarity to other samples: (num_prompts, num_generations)
    diversity = sims.sum(dim=-1) / (num_generations - 1)

    # Flatten back to (B,)
    return diversity.view(-1)


def whiten_embeddings_batched(
    phi: torch.Tensor,
    phi_gt: torch.Tensor,
    whiten_tol: float = 1e-5,
    normalize: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Whiten generated embeddings using SVD, then apply the same transform to ground-truth.

    Whitening decorrelates feature dimensions so no single direction dominates
    the feature-matching loss. Uses a pseudo-inverse for rank-deficient cases.

    Note: Singular values scale with sqrt(B), so reward magnitudes are
    batch-size dependent. This is acceptable because B = n_samples_per_prompt,
    which is fixed during training (typically 2-4).

    Args:
        phi: (B, D) generated embeddings (used to estimate covariance)
        phi_gt: (B, D) ground-truth embeddings
        whiten_tol: Tolerance for singular value cutoff
        normalize: If True, L2-normalize after whitening

    Returns:
        Whitened (phi, phi_gt) tuple, each (B, D)
    """
    phi_f = phi.float()
    phi_gt_f = phi_gt.float()

    # Feature-space SVD: operate on phi_f.T (D, B) so U lives in feature space,
    # shape (D, min(D, B))
    try:
        U, S, _ = torch.linalg.svd(phi_f.T.unsqueeze(0), full_matrices=False)
    except torch._C._LinAlgError:
        # Fallback: add small noise
        noise = 1e-6 * phi_f.abs().mean()
        try:
            U, S, _ = torch.linalg.svd(
                (phi_f.T + noise * torch.randn_like(phi_f.T)).unsqueeze(0),
                full_matrices=False,
            )
        except torch._C._LinAlgError:
            if normalize:
                return (
                    F.normalize(phi, p=2, dim=-1),
                    F.normalize(phi_gt, p=2, dim=-1),
                )
            return phi, phi_gt

    U, S = U.squeeze(0), S.squeeze(0)  # U: (D, min(D,B)), S: (min(D,B),)

    # Safe inverse of singular values
    s_max = S.max()
    inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))

    # W = U @ diag(inv_s) @ U^T — feature-space whitening matrix (D, D)
    W = (U * inv_s.unsqueeze(0)) @ U.T  # (D, D)
    phi_w = (phi_f @ W).to(phi.dtype)  # (B, D)
    phi_gt_w = (phi_gt_f @ W).to(phi_gt.dtype)  # (B, D)

    if normalize:
        phi_w = F.normalize(phi_w, p=2, dim=-1)
        phi_gt_w = F.normalize(phi_gt_w, p=2, dim=-1)

    return phi_w, phi_gt_w
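# Sanity note (sketch, not from the original file): writing phi = V S U^T
# (so phi.T = U S V^T as above), the whitened features are
#   phi @ W = V S diag(1/s) U^T = V U^T,
# whose Gram matrix (phi @ W).T @ (phi @ W) = U U^T is the projector onto the
# data subspace, i.e. unit variance along every direction the batch spans.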
File diff suppressed because it is too large
@@ -1,531 +0,0 @@
"""
EBFT Trainer — Energy-Based Fine-Tuning integrated via GRPOTrainer.

Extends AxolotlGRPOTrainer by plugging feature-matching rewards into
the standard GRPO reward function interface.

Paper: "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"
(Jelassi et al., 2026) https://arxiv.org/abs/2603.12248
"""

import contextlib
import copy
from typing import TYPE_CHECKING, Any

import torch
from datasets import Dataset, IterableDataset
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback

from axolotl.core.trainers.ebft.args import AxolotlEBFTConfig
from axolotl.core.trainers.ebft.rewards import (
    apply_embed_method,
    extract_hidden_states,
    get_alignment_rewards,
    get_diversity_rewards,
    whiten_embeddings_batched,
)
from axolotl.core.trainers.grpo.trainer import (
    AxolotlAsyncGRPOTrainer,
    AxolotlGRPOTrainer,
)
from axolotl.utils.logging import get_logger

if TYPE_CHECKING:
    from collections import defaultdict

    from accelerate import Accelerator
    from trl.generation.vllm_generation import VLLMGeneration

LOG = get_logger(__name__)


class EBFTMixin:
    """
    Mixin that adds EBFT feature-matching reward logic to any GRPO-based trainer.

    Provides:
    - Frozen feature network setup (shared weights for PEFT, deepcopy otherwise)
    - _feature_matching_reward() callable for the GRPO reward function interface
    - _sequential_rollout() for multi-turn conversations
    """

    # Type stubs for attributes provided by the composed GRPOTrainer base class.
    # These are not defined here but accessed via cooperative multiple inheritance.
    if TYPE_CHECKING:
        accelerator: Accelerator
        model: PreTrainedModel
        args: AxolotlEBFTConfig
        processing_class: PreTrainedTokenizerBase
        num_generations: int
        vllm_generation: VLLMGeneration
        _metrics: defaultdict

    _tag_names = ["trl", "ebft", "axolotl"]

    def __init__(
        self,
        model: str | PreTrainedModel,
        args: AxolotlEBFTConfig | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
        eval_dataset: Dataset
        | IterableDataset
        | dict[str, Dataset | IterableDataset]
        | None = None,
        processing_class: PreTrainedTokenizerBase | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[
            torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
        ] = (None, None),
        peft_config: Any | None = None,
    ):
        # Pass our feature-matching reward function to GRPOTrainer.
        # It will be called with (prompts, completions, **kwargs) where
        # kwargs includes all extra dataset fields like "ground_truth".
        super().__init__(  # type: ignore[call-arg]
            model=model,
            reward_funcs=[self._feature_matching_reward],
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
            peft_config=peft_config,
        )
        assert args is not None

        # --- Feature network setup ---
        unwrapped = self.accelerator.unwrap_model(self.model)
        # Check for a PEFT model — use hasattr for robustness across DDP/FSDP wrapping
        self._share_feature_weights = isinstance(unwrapped, PeftModel) or hasattr(
            unwrapped, "disable_adapter"
        )

        if self._share_feature_weights:
            # Share weights: use the actor's base model with adapters disabled.
            # Saves a full model copy (~8 GB for a 4B model).
            self.feature_network = None
            param_gb = sum(p.numel() for p in unwrapped.parameters()) * 2 / 1e9
            LOG.info(
                f"EBFT feature network shares actor weights (PEFT disable_adapter). "
                f"Saving ~{param_gb:.1f} GB"
            )
        else:
            LOG.info("Creating frozen feature network for EBFT (deepcopy)...")
            self.feature_network = copy.deepcopy(unwrapped)
            for param in self.feature_network.parameters():
                param.requires_grad = False
            self.feature_network.eval()

        # Compute layer indices from fractional depths.
        # Handle VLM models where num_hidden_layers lives on text_config.
        config = unwrapped.config
        if hasattr(config, "text_config") and hasattr(
            config.text_config, "num_hidden_layers"
        ):
            config = config.text_config
        num_layers = config.num_hidden_layers
        self.feature_layer_indices = [
            int(frac * num_layers) for frac in args.ebft_feature_layers
        ]
        LOG.info(
            f"EBFT feature extraction from layers {self.feature_layer_indices} "
            f"(of {num_layers} total), embed_method={args.ebft_embed_method}"
        )
        if args.ebft_adaptive_max_tokens:
            LOG.info(
                f"EBFT adaptive max_tokens enabled "
                f"(gt_length_multiplier={args.ebft_gt_length_multiplier})"
            )

    _adaptive_max_lock = None  # initialized lazily

    def _generate_only(self, inputs, rank0_only=False):
        """Override to set per-batch max_tokens based on ground-truth length.

        Uses a lock to prevent race conditions in async mode where concurrent
        background threads could interleave mutations of max_completion_length.
        """
        import threading

        args = self.args
        if (
            args.ebft_adaptive_max_tokens
            and hasattr(self, "vllm_generation")
            and inputs
        ):
            gt_texts = [
                x.get("ground_truth", "") for x in inputs if x.get("ground_truth")
            ]
            if gt_texts:
                gt_token_counts = [
                    len(self.processing_class.encode(gt, add_special_tokens=False))
                    for gt in gt_texts
                ]
                multiplier = args.ebft_gt_length_multiplier
                max_completion = self.vllm_generation.max_completion_length
                adaptive_max = max(
                    min(int(c * multiplier), max_completion) for c in gt_token_counts
                )
                adaptive_max = max(adaptive_max, 64)

                if self._adaptive_max_lock is None:
                    self._adaptive_max_lock = threading.Lock()
                with self._adaptive_max_lock:
                    original = self.vllm_generation.max_completion_length
                    self.vllm_generation.max_completion_length = adaptive_max
                    try:
                        return super()._generate_only(inputs, rank0_only)
                    finally:
                        self.vllm_generation.max_completion_length = original

        return super()._generate_only(inputs, rank0_only)
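    # Worked example (sketch): with gt_length_multiplier=1.5, a configured
    # max_completion_length of 512, and GT token counts [40, 200, 800]:
    #   per-sample caps: min(60, 512)=60, min(300, 512)=300, min(1200, 512)=512
    #   adaptive_max = max(60, 300, 512) = 512, then floored at 64 -> 512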
    @torch.no_grad()
    def _feature_matching_reward(
        self,
        prompts: list,
        completions: list,
        ground_truth: list[str] | None = None,
        remaining_turns: list | None = None,
        **kwargs,
    ) -> list[float]:
        """
        Compute feature-matching rewards for generated completions.

        This is called by GRPOTrainer's _generate_and_score_completions()
        as a standard reward function. The `ground_truth` field comes from
        the dataset via reward_kwargs.

        For multi-turn conversations, `remaining_turns` contains the subsequent
        user/assistant turn pairs. When present, we do sequential rollouts:
        generate each assistant turn conditioned on history + previous generations,
        then compute feature-matching rewards on the full generated conversation.

        Args:
            prompts: List of prompt strings/messages
            completions: List of generated completion strings
            ground_truth: List of reference completion strings (from the dataset)
            remaining_turns: List of remaining conversation turns after the
                first assistant turn (for multi-turn rollouts)

        Returns:
            List of scalar rewards, one per completion
        """
        if ground_truth is None:
            LOG.warning("No ground_truth field in dataset — using zero rewards")
            return [0.0] * len(prompts)

        device = self.accelerator.device
        args = self.args
        num_gens = self.num_generations

        # --- Multi-turn sequential rollout ---
        # If remaining_turns is provided, generate subsequent assistant turns
        # by calling vLLM for each turn, building up the full conversation.
        if remaining_turns is not None and hasattr(self, "vllm_generation"):
            completions = self._sequential_rollout(
                prompts, completions, remaining_turns, num_gens
            )

        # --- Tokenize generated sequences: prompt + completion ---
        gen_texts = []
        gen_prompt_texts = []
        for p, c in zip(prompts, completions, strict=True):
            if isinstance(p, list):
                prompt_text = self.processing_class.apply_chat_template(
                    p, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt_text = p
            if isinstance(c, list):
                comp_text = c[0].get("content", "") if c else ""
            else:
                comp_text = c
            gen_texts.append(prompt_text + comp_text)
            gen_prompt_texts.append(prompt_text)

        gen_encoded = self.processing_class(
            text=gen_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(self.args, "max_length", None)
            or getattr(self.args, "max_seq_length", None)
            or 2048,
            add_special_tokens=False,
        )
        gen_ids = gen_encoded["input_ids"].to(device)
        gen_mask = gen_encoded["attention_mask"].to(device)

        # Compute prompt lengths for completion_mean pooling
        gen_prompt_lengths = torch.tensor(
            [
                len(self.processing_class.encode(pt, add_special_tokens=False))
                for pt in gen_prompt_texts
            ],
            device=device,
        )

        # --- Tokenize ground-truth sequences: prompt + ground_truth ---
        # For multi-turn (remaining_turns present), render the full GT conversation
        # through the chat template to preserve role markers between turns.
        gt_texts = []
        gt_prompt_texts = []
        for i, (p, gt) in enumerate(zip(prompts, ground_truth, strict=True)):
            if i % num_gens != 0:
                continue  # Only need one GT per prompt group
            if isinstance(p, list):
                prompt_text = self.processing_class.apply_chat_template(
                    p, tokenize=False, add_generation_prompt=True
                )
                # Multi-turn: build the full GT conversation with remaining turns
                if remaining_turns is not None:
                    prompt_idx = i // num_gens
                    turns = (
                        remaining_turns[prompt_idx]
                        if prompt_idx < len(remaining_turns)
                        else []
                    )
                    if turns:
                        gt_conv = list(p) + [{"role": "assistant", "content": gt}]
                        gt_conv.extend(turns)
                        full_gt_text = self.processing_class.apply_chat_template(
                            gt_conv, tokenize=False, add_generation_prompt=False
                        )
                        gt_texts.append(full_gt_text)
                        gt_prompt_texts.append(prompt_text)
                        continue
            else:
                prompt_text = p
            gt_texts.append(prompt_text + gt)
            gt_prompt_texts.append(prompt_text)

        gt_encoded = self.processing_class(
            text=gt_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(self.args, "max_length", None)
            or getattr(self.args, "max_seq_length", None)
            or 2048,
            add_special_tokens=False,
        )
        gt_ids = gt_encoded["input_ids"].to(device)
        gt_mask = gt_encoded["attention_mask"].to(device)

        gt_prompt_lengths = torch.tensor(
            [
                len(self.processing_class.encode(pt, add_special_tokens=False))
                for pt in gt_prompt_texts
            ],
            device=device,
        )

        # --- Extract features from the frozen feature network ---
        # INVARIANT: disable_adapter() yields the unmodified base weights because
        # _sync_peft_weights_no_merge and _sync_lora_adapter never call
        # merge_adapter() — they compute merged weights as new tensors or save
        # the adapter to the filesystem. Base weights are never modified in-place.
        if self._share_feature_weights:
            unwrapped = self.accelerator.unwrap_model(self.model)
            feature_ctx = unwrapped.disable_adapter()
        else:
            unwrapped = self.feature_network
            feature_ctx = contextlib.nullcontext()

        with feature_ctx:
            was_training = unwrapped.training
            unwrapped.eval()
            gen_hidden = extract_hidden_states(
                unwrapped, gen_ids, gen_mask, self.feature_layer_indices
            )
            gt_hidden = extract_hidden_states(
                unwrapped, gt_ids, gt_mask, self.feature_layer_indices
            )
            if was_training:
                unwrapped.train()

        # --- Pool to sequence-level embeddings ---
        gen_emb = apply_embed_method(
            gen_hidden,
            args.ebft_embed_method,
            gen_mask,
            prompt_lengths=gen_prompt_lengths,
        )
        gt_emb = apply_embed_method(
            gt_hidden,
            args.ebft_embed_method,
            gt_mask,
            prompt_lengths=gt_prompt_lengths,
        )

        # --- Optional whitening ---
        batch_size = gen_emb.shape[0]
        if args.ebft_use_whitening and batch_size > 1:
            num_prompts = batch_size // num_gens
            gen_reshaped = gen_emb.view(num_prompts, num_gens, -1)
            whitened_gen_list = []
            whitened_gt_list = []
            for i in range(num_prompts):
                w_gen, w_gt = whiten_embeddings_batched(
                    gen_reshaped[i], gt_emb[i : i + 1]
                )
                whitened_gen_list.append(w_gen)
                whitened_gt_list.append(w_gt)
            gen_emb = torch.cat(whitened_gen_list, dim=0)
            gt_emb = torch.cat(whitened_gt_list, dim=0)
        else:
            gen_emb = torch.nn.functional.normalize(gen_emb, p=2, dim=-1)
            gt_emb = torch.nn.functional.normalize(gt_emb, p=2, dim=-1)

        # Repeat gt_emb: each GT repeated num_generations times
        gt_emb_expanded = gt_emb.repeat_interleave(num_gens, dim=0)

        # --- Compute rewards ---
        alignment = get_alignment_rewards(gen_emb, gt_emb_expanded)
        diversity = get_diversity_rewards(gen_emb, num_gens)

        # Scale by 2 per paper equation (7):
        # r_j = 2*φ(ŷ_j)^T*φ(y) - 2/(n-1) * Σ_{j'≠j} φ(ŷ_j)^T*φ(ŷ_{j'})
        alignment = alignment * 2
        diversity = diversity * 2

        rewards = (
            args.ebft_alignment_coef * alignment - args.ebft_diversity_coef * diversity
        )

        # Compute the CFM loss: ||E[φ(ŷ)] - φ(y)||^2 (paper eq 2)
        gen_reshaped = gen_emb.view(-1, num_gens, gen_emb.shape[-1])
        mean_gen = gen_reshaped.mean(dim=1)  # (num_prompts, D)
        cfm_loss = ((mean_gen - gt_emb) ** 2).sum(dim=-1).mean()

        # Log feature-matching metrics to console and wandb
        _align = alignment.mean().item()
        _divers = diversity.mean().item()
        _reward = rewards.mean().item()
        _cfm = cfm_loss.item()

        LOG.info(
            f"ebft reward | "
            f"align {_align:+.3f} ^ | "
            f"divers {_divers:+.3f} v | "
            f"cfm {_cfm:.3f} v | "
            f"reward {_reward:+.3f} ^"
        )

        # Log to wandb via the trainer's _metrics (picked up by GRPO's logging)
        mode = "train" if self.model.training else "eval"
        if hasattr(self, "_metrics"):
            self._metrics[mode]["ebft/alignment"].append(_align)
            self._metrics[mode]["ebft/diversity"].append(_divers)
            self._metrics[mode]["ebft/cfm_loss"].append(_cfm)
            self._metrics[mode]["ebft/reward"].append(_reward)

        return rewards.cpu().tolist()
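    # Shape bookkeeping sketch (illustrative): with P prompts and n generations,
    #   gen_emb: (P*n, D); gt_emb: (P, D) -> repeat_interleave -> (P*n, D)
    #   alignment, diversity, rewards: (P*n,)
    # matching equation (7): r_j = 2 φ(ŷ_j)·φ(y) - (2/(n-1)) Σ_{j'≠j} φ(ŷ_j)·φ(ŷ_{j'})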
    @torch.no_grad()
    def _sequential_rollout(
        self,
        prompts: list,
        first_completions: list,
        remaining_turns: list,
        num_gens: int,
    ) -> list:
        """
        Extend single-turn completions into multi-turn conversations.

        For each prompt group, takes the first generated assistant turn and
        sequentially generates subsequent assistant turns by calling vLLM,
        building up a full multi-turn conversation.

        Args:
            prompts: List of prompt message lists (repeated num_gens times)
            first_completions: List of generated first-turn completions
            remaining_turns: List of remaining turn pairs after the first assistant turn.
                Each element is a list of dicts: [{"role": "user", "content": "..."},
                {"role": "assistant", "content": "...GT..."}]
            num_gens: Number of generations per prompt

        Returns:
            Extended completions incorporating all generated turns
        """
        vllm_client = self.vllm_generation.vllm_client
        max_tokens = getattr(self.args, "max_completion_length", 256)
        temperature = getattr(self.args, "temperature", 0.7)
        gen_kwargs = getattr(self.args, "generation_kwargs", None) or {}

        extended_completions = []

        for idx in range(len(prompts)):
            prompt_msgs = prompts[idx] if isinstance(prompts[idx], list) else []
            first_comp = first_completions[idx]

            # Extract the first completion text
            if isinstance(first_comp, list):
                first_text = first_comp[0].get("content", "") if first_comp else ""
            else:
                first_text = first_comp

            # Get remaining turns for this prompt (same for all num_gens copies)
            prompt_idx = idx // num_gens
            turns = (
                remaining_turns[prompt_idx] if prompt_idx < len(remaining_turns) else []
            )

            if not turns:
                extended_completions.append(first_text)
                continue

            # Build the conversation with the generated first turn
            conv = list(prompt_msgs) + [{"role": "assistant", "content": first_text}]

            # Generate subsequent turns
            for turn in turns:
                if turn["role"] == "user":
                    conv.append(turn)
                elif turn["role"] == "assistant":
                    try:
                        result = vllm_client.chat(
                            messages=[conv],
                            n=1,
                            max_tokens=max_tokens,
                            temperature=temperature,
                            generation_kwargs=gen_kwargs,
                        )
                        gen_ids = result.get("completion_ids", [[]])[0]
                        gen_text = self.processing_class.decode(
                            gen_ids, skip_special_tokens=True
                        )
                    except Exception as e:
                        LOG.warning(f"Multi-turn rollout generation failed: {e}")
                        gen_text = ""

                    conv.append({"role": "assistant", "content": gen_text})

            # Render the full conversation through the chat template, then extract
            # everything after the original prompt as the "completion" text.
            # This preserves role markers and formatting between turns.
            full_rendered = self.processing_class.apply_chat_template(
                conv, tokenize=False, add_generation_prompt=False
            )
            prompt_rendered = self.processing_class.apply_chat_template(
                prompt_msgs, tokenize=False, add_generation_prompt=True
            )
            completion_text = full_rendered[len(prompt_rendered) :]
            extended_completions.append(completion_text)

        return extended_completions


class AxolotlEBFTTrainer(EBFTMixin, AxolotlGRPOTrainer):
    """EBFT trainer using synchronous GRPO (standard vLLM generation)."""


class AxolotlAsyncEBFTTrainer(EBFTMixin, AxolotlAsyncGRPOTrainer):
    """EBFT trainer using async GRPO (prefetches the next batch during training)."""
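# Usage sketch (illustrative; the dataset and tokenizer are assumed, not shown here):
#
#   trainer = AxolotlEBFTTrainer(
#       model="Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical model id
#       args=AxolotlEBFTConfig(output_dir="./outputs", num_generations=4),
#       train_dataset=train_ds,  # must provide a "ground_truth" column
#       processing_class=tokenizer,
#   )
#   trainer.train()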
|
||||
@@ -29,7 +29,7 @@ class GRPOStrategy:
|
||||
@classmethod
|
||||
def get_trainer_class(
|
||||
cls,
|
||||
sequence_parallel: bool = False,
|
||||
sequence_parallel: bool,
|
||||
async_grpo: bool = False,
|
||||
) -> (
|
||||
type[AxolotlGRPOTrainer]
|
||||
@@ -88,8 +88,6 @@ class GRPOStrategy:
|
||||
|
||||
if trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||
if trl.generation_batch_size is not None:
|
||||
grpo_args_kwargs["generation_batch_size"] = trl.generation_batch_size
|
||||
|
||||
if trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||
@@ -201,10 +199,6 @@ class GRPOStrategy:
|
||||
if getattr(trl, "vllm_lora_sync", None) is not None:
|
||||
grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
|
||||
|
||||
# Batch flattening (top-level config, not under trl)
|
||||
if getattr(cfg, "batch_flattening", None):
|
||||
grpo_args_kwargs["batch_flattening"] = cfg.batch_flattening
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -32,7 +32,6 @@ from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
from trl.trainer import GRPOConfig, GRPOTrainer
|
||||
@@ -130,18 +129,6 @@ class AsyncGRPOConfig(GRPOConfig):
|
||||
},
|
||||
)
|
||||
|
||||
# --- Batch flattening ---
|
||||
batch_flattening: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use batch flattening for the scoring forward pass. Removes padding tokens "
|
||||
"before the forward pass, reducing attention FLOPs proportional to the padding ratio. "
|
||||
"Requires flash_attention_2 attention implementation. Incompatible with FSDP and "
|
||||
"multimodal models. The per-token logprob results differ by bf16 precision (~0.03 mean) "
|
||||
"but produce equivalent loss and gradients."
|
||||
},
|
||||
)
|
||||
|
||||
# --- Streaming scoring ---
|
||||
streaming_partial_batch: bool = field(
|
||||
default=False,
|
||||
@@ -536,10 +523,7 @@ class GRPODataProducer(BaseDataProducer):
|
||||
def set_trainer(self, trainer) -> None:
|
||||
"""Inject the live trainer reference and create the prompt DataLoader."""
|
||||
self._trainer = trainer
|
||||
# Defer _init_prompt_dataloader if trainer.args is not yet set
|
||||
# (happens when set_trainer is called from _create_data_producer during __init__)
|
||||
if getattr(trainer, "args", None) is not None:
|
||||
self._init_prompt_dataloader()
|
||||
self._init_prompt_dataloader()
|
||||
|
||||
def _init_prompt_dataloader(self) -> None:
|
||||
from functools import partial
|
||||
@@ -596,10 +580,6 @@ class GRPODataProducer(BaseDataProducer):
|
||||
**kwargs,
|
||||
) -> RolloutDataset | None:
|
||||
"""Generate a fresh GRPO training rollout."""
|
||||
# Lazy init: create prompt DataLoader if deferred from set_trainer
|
||||
if self._prompt_dl is None and self._trainer is not None:
|
||||
self._init_prompt_dataloader()
|
||||
|
||||
is_main = self._trainer.accelerator.is_main_process
|
||||
|
||||
# FSDP rank0-only mode: non-rank-0 returns None (broadcast fills it later)
|
||||
@@ -648,21 +628,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Skip NCCL communicator init when using LoRA sync (filesystem) or HTTP-only
|
||||
# merged weight sync. NCCL is only needed for the standard update_named_param
|
||||
# path which broadcasts tensors through the communicator.
|
||||
# When using native LoRA sync, skip the NCCL communicator init in VLLMGeneration.
|
||||
# The communicator is not needed because weight sync happens via filesystem + HTTP,
|
||||
# and it fails when vLLM and a trainer rank share the same CUDA device.
|
||||
training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None)
|
||||
_skip_nccl = False
|
||||
if training_args is not None:
|
||||
if getattr(training_args, "vllm_lora_sync", False):
|
||||
_skip_nccl = True # LoRA sync uses filesystem + HTTP
|
||||
elif getattr(training_args, "async_prefetch", False):
|
||||
# Skip NCCL at init to avoid DDP param count mismatch in multi-GPU.
|
||||
# init_communicator allocates device tensors on rank 0 only, which
|
||||
# causes DDP to see different param counts across ranks.
|
||||
# The communicator is initialized lazily on first weight sync instead.
|
||||
_skip_nccl = True
|
||||
if _skip_nccl:
|
||||
if training_args is not None and getattr(
|
||||
training_args, "vllm_lora_sync", False
|
||||
):
|
||||
from trl.generation.vllm_generation import VLLMGeneration
|
||||
|
||||
_orig_init_vllm = VLLMGeneration._init_vllm
|
||||
@@ -689,12 +661,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
|
||||
VLLMGeneration._init_vllm = _init_vllm_no_communicator
|
||||
|
||||
try:
|
||||
super().__init__(*args, **kwargs)
|
||||
finally:
|
||||
# Restore original _init_vllm so other trainers aren't affected
|
||||
if _skip_nccl:
|
||||
VLLMGeneration._init_vllm = _orig_init_vllm # type: ignore[possibly-undefined]
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# FP8 models: zero out the pad token embedding so that padding
|
||||
# positions have zero hidden states throughout the network.
|
||||
@@ -813,50 +780,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
self._executor = None
|
||||
|
||||
def _submit_generation(self):
|
||||
"""Submit the next background generation job.
|
||||
|
||||
With multi-process (DDP/FSDP), only rank 0 generates to avoid
|
||||
cross-rank NCCL collectives from background threads. Non-rank-0
|
||||
processes enqueue a sentinel ``None`` that is replaced by a
|
||||
broadcast in ``_prepare_inputs_legacy_async``.
|
||||
"""
|
||||
rank0_only = self.accelerator.num_processes > 1
|
||||
if rank0_only and not self.accelerator.is_main_process:
|
||||
# Non-rank-0: nothing to generate; enqueue a resolved None future
|
||||
f: concurrent.futures.Future = concurrent.futures.Future()
|
||||
f.set_result(None)
|
||||
self._async_queue.put(f)
|
||||
return
|
||||
"""Submit the next background generation job."""
|
||||
batch = next(self._prompt_iter)
|
||||
future = self._executor.submit(self._generate_only, batch, rank0_only)
|
||||
future = self._executor.submit(self._generate_only, batch)
|
||||
self._async_queue.put(future)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Broadcast rollout (legacy async, multi-process)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _broadcast_rollout(self, rollout: dict | None) -> dict:
|
||||
"""Broadcast a rank0-only rollout dict to all ranks (main thread).
|
||||
|
||||
Rank 0 has the full rollout dict from ``_generate_only``; other ranks
|
||||
have ``None``. After broadcast, tensors are moved to each rank's
|
||||
local device.
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
|
||||
obj_list = [rollout if self.accelerator.is_main_process else None]
|
||||
dist.broadcast_object_list(obj_list, src=0)
|
||||
rollout = obj_list[0]
|
||||
assert rollout is not None, "broadcast_object_list failed to deliver rollout"
|
||||
|
||||
# Move tensors to local device (broadcast deserializes to CPU)
|
||||
device = self.accelerator.device
|
||||
for key, val in rollout.items():
|
||||
if isinstance(val, torch.Tensor) and val.device != device:
|
||||
rollout[key] = val.to(device)
|
||||
|
||||
return rollout
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Weight sync
|
||||
# ------------------------------------------------------------------
|
||||
@@ -868,18 +796,14 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
for Float8), and also safe for concurrent use since it never modifies base
|
||||
weights in-place.
|
||||
"""
|
||||
accelerator = self.vllm_generation.accelerator
|
||||
if not (self.vllm_generation.mode == "server" and accelerator.is_main_process):
|
||||
return
|
||||
|
||||
# In multi-GPU async mode, we skip NCCL communicator init to avoid
|
||||
# DDP param count mismatch and NCCL device conflicts. Weight sync
|
||||
# uses the HTTP-only fallback in batch_update_named_params instead.
|
||||
|
||||
model = self.vllm_generation.model
|
||||
accelerator = self.vllm_generation.accelerator
|
||||
vllm_client = self.vllm_generation.vllm_client
|
||||
fix_name = self.vllm_generation._fix_param_name_to_vllm
|
||||
|
||||
if not (self.vllm_generation.mode == "server" and accelerator.is_main_process):
|
||||
return
|
||||
|
||||
        # Build lookup: module_path -> (A, B, scaling) for all active LoRA layers
        lora_info = {}
        for mod_name, module in model.base_model.model.named_modules():
@@ -902,11 +826,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
                weight_name = pname.replace(".weight_scale_inv", ".weight")
                scale_inv_lookup[weight_name] = pparam.data

        # Only sync parameters that have LoRA modifications — skip unchanged
        # base weights to avoid OOM on the vLLM GPU from allocating the entire
        # model's worth of NCCL receive buffers.
        # Iterate all parameters, computing merged weights for LoRA layers.
        # Skip LoRA-specific params and FP8 scale params (scales will be
        # recomputed by vLLM when it receives the merged bf16 weight).
        params_to_sync = []
        compute_dtype = torch.bfloat16
        for name, param in model.named_parameters():
            vllm_name = name.removeprefix("base_model.model.").replace(
                ".base_layer", ""
@@ -915,58 +838,52 @@ class AsyncGRPOTrainer(GRPOTrainer):
                continue
            if "original_module" in vllm_name:
                continue
            # Skip FP8 quantization scale parameters - they are recomputed
            # on the vLLM side when we update the weight itself
            if "weight_scale_inv" in vllm_name or "input_scale" in vllm_name:
                continue
            if not vllm_name.endswith(".weight"):
                continue
            # fix_name strips modules_to_save.default. prefix
            raw_mod_path = vllm_name[: -len(".weight")]
            vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])
            mod_path = vllm_name[: -len(".weight")]

            # Sync weights that have LoRA adapters OR are modules_to_save
            is_lora = mod_path in lora_info
            is_modules_to_save = raw_mod_path != mod_path  # fix_name stripped a prefix
            if not is_lora and not is_modules_to_save:
                continue

            data = param.data
            compute_dtype = torch.bfloat16

            # Dequantize FP8 weights before merging
            if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
                scale_inv = scale_inv_lookup[name]
                fp8_bf16 = data.to(compute_dtype)
                if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
                    sr, sc = scale_inv.shape
                    br = fp8_bf16.shape[0] // sr
                    bc = fp8_bf16.shape[1] // sc
                    data = (
                        fp8_bf16.reshape(sr, br, sc, bc)
                        * scale_inv[:, None, :, None].to(compute_dtype)
                    ).reshape(fp8_bf16.shape)
                elif scale_inv.dim() <= 1:
                    data = fp8_bf16 * scale_inv.to(compute_dtype)
                else:
                    data = fp8_bf16
            elif data.dtype == torch.float8_e4m3fn:
                data = data.to(compute_dtype)
            if vllm_name.endswith(".weight"):
                # Dequantize FP8 weights before merging
                if data.dtype == torch.float8_e4m3fn and name in scale_inv_lookup:
                    scale_inv = scale_inv_lookup[name]
                    # Block dequantization: weight * scale_inv (with broadcasting)
                    fp8_bf16 = data.to(compute_dtype)
                    if scale_inv.dim() == 2 and fp8_bf16.dim() == 2:
                        # Block-quantized: scale_inv shape (rows/block, cols/block)
                        sr, sc = scale_inv.shape
                        br = fp8_bf16.shape[0] // sr  # block height
                        bc = fp8_bf16.shape[1] // sc  # block width
                        # Reshape → multiply by block scale → reshape back
                        data = (
                            fp8_bf16.reshape(sr, br, sc, bc)
                            * scale_inv[:, None, :, None].to(compute_dtype)
                        ).reshape(fp8_bf16.shape)
                    elif scale_inv.dim() <= 1:
                        # Per-tensor or per-channel scale
                        data = fp8_bf16 * scale_inv.to(compute_dtype)
                    else:
                        data = fp8_bf16
                elif data.dtype == torch.float8_e4m3fn:
                    # FP8 but no scale found - just cast (lossy)
                    data = data.to(compute_dtype)
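To make the block-dequantization shapes concrete, here is a minimal, self-contained sketch (tensor sizes and the 128×128 block size are invented for illustration; the `sr`/`sc`/`br`/`bc` names mirror the code above):

```python
import torch

# Hypothetical block-quantized weight: 256x512 FP8 tensor with 128x128 blocks,
# so scale_inv has shape (256/128, 512/128) = (2, 4).
weight_fp8 = torch.randn(256, 512).to(torch.float8_e4m3fn)
scale_inv = torch.rand(2, 4)

fp8_bf16 = weight_fp8.to(torch.bfloat16)
sr, sc = scale_inv.shape            # (2, 4)
br = fp8_bf16.shape[0] // sr        # block height: 128
bc = fp8_bf16.shape[1] // sc        # block width: 128

# Reshape into (sr, br, sc, bc) so each 128x128 block lines up with its
# scalar scale, multiply with broadcasting, then flatten back.
dequant = (
    fp8_bf16.reshape(sr, br, sc, bc)
    * scale_inv[:, None, :, None].to(torch.bfloat16)
).reshape(fp8_bf16.shape)
```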

            if is_lora:
                A, B, s = lora_info[mod_path]
                merged = data.to(compute_dtype) + s * (
                    B.to(compute_dtype) @ A.to(compute_dtype)
                )
                params_to_sync.append((vllm_name, merged))
            else:
                # modules_to_save: send raw weight (no LoRA merge needed)
                params_to_sync.append((vllm_name, data.to(compute_dtype)))
            mod_path = vllm_name[: -len(".weight")]
            if mod_path in lora_info:
                A, B, s = lora_info[mod_path]
                merged = data.to(compute_dtype) + s * (
                    B.to(compute_dtype) @ A.to(compute_dtype)
                )
                data = merged

            # Batch sync only LoRA-modified params via HTTP+NCCL
            params_to_sync.append((vllm_name, data))

        # Batch sync all params in one HTTP+NCCL call (vs individual calls)
        if params_to_sync:
            sync_mb = sum(t.numel() * t.element_size() for _, t in params_to_sync) / 1e6
            logger.info(
                f"Syncing {len(params_to_sync)} LoRA-modified params ({sync_mb:.0f} MB)"
            )
            vllm_client.batch_update_named_params(params_to_sync)

        # Reset prefix cache after weight update
@@ -1032,44 +949,26 @@ class AsyncGRPOTrainer(GRPOTrainer):
        import requests

        vllm_client = self.vllm_generation.vllm_client
        base_url = vllm_client.base_url
        base_model = getattr(self.args, "model_name_or_path", "axolotl-lora")
        sync_timeout = getattr(self.args, "vllm_server_timeout", 300) or 300

        # Try standard vLLM /v1/load_lora_adapter first, fall back to custom endpoint
        url = f"{vllm_client.base_url}/set_lora_adapter/"
        response = requests.post(
            f"{base_url}/v1/load_lora_adapter",
            url,
            json={
                "lora_name": base_model,
                "lora_name": "active_lora",
                "lora_int_id": self._lora_sync_version,
                "lora_path": adapter_path,
                "load_inplace": True,
            },
            timeout=sync_timeout,
            timeout=30,
        )
        if response.status_code != 200:
            # Fallback: try custom /set_lora_adapter/ endpoint
            response = requests.post(
                f"{base_url}/set_lora_adapter/",
                json={
                    "lora_name": "active_lora",
                    "lora_int_id": self._lora_sync_version,
                    "lora_path": adapter_path,
                },
                timeout=30,
            logger.warning(
                "Failed to set LoRA adapter: %s %s",
                response.status_code,
                response.text,
            )
            if response.status_code != 200:
                logger.warning(
                    "Failed to set LoRA adapter: %s %s",
                    response.status_code,
                    response.text,
                )
                return
            return

        # Reset prefix cache after adapter update
        try:
            vllm_client.reset_prefix_cache()
        except Exception as exc:
            logger.warning("Failed to reset prefix cache: %s", exc)
        vllm_client.reset_prefix_cache()

        # Clean up old adapter versions (keep only current)
        if self._lora_sync_version > 1:
@@ -1109,11 +1008,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
        step = self.state.global_step
        interval = self.args.vllm_sync_interval
        if step != self._last_synced_step and step % interval == 0:
            if step == 0:
                logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
                self._last_synced_step = step
                return
            if getattr(self.args, "vllm_lora_sync", False):
                if step == 0:
                    logger.info("Skipping LoRA sync at step 0 (no training yet)")
                    self._last_synced_step = step
                    return
                # Native LoRA sync: save adapter to filesystem, vLLM loads it directly
                self._sync_lora_adapter()
            else:
@@ -1189,7 +1088,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
    # Background-thread generation (no scoring)
    # ------------------------------------------------------------------

    def _generate_single_turn(self, prompts, *args, **kwargs):
    def _generate_single_turn(self, prompts, **kwargs):
        """Override to prevent weight sync from background thread and to use
        no-merge sync for PEFT models (FP8 models can't merge_adapter)."""
        is_bg = threading.current_thread() is not threading.main_thread()
@@ -1222,7 +1121,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
            self._patched_sync_weights = True

        try:
            return super()._generate_single_turn(prompts, *args, **kwargs)
            return super()._generate_single_turn(prompts, **kwargs)
        finally:
            if saved_step is not None:
                self._last_loaded_step = saved_step
@@ -1266,9 +1165,9 @@ class AsyncGRPOTrainer(GRPOTrainer):
            output = vg.vllm_client.chat(
                messages=unique_prompts,
                **sampling_params,
                chat_template_kwargs=self.chat_template_kwargs,
                tools=self.tools,
                chat_template=getattr(self, "chat_template", None),
                chat_template_kwargs=vg.chat_template_kwargs,
                tools=vg.tools,
                chat_template=vg.chat_template,
            )
        else:
            output = vg.vllm_client.generate(prompts=unique_prompts, **sampling_params)
@@ -1556,29 +1455,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
    ) -> None:
        """Called after advantages are computed. Override for replay buffer, re-roll, etc."""

    def _notify_rollouts_scored(
        self,
        prompts: list[str],
        completions: list[str],
        rewards: dict[str, list[float]],
        advantages: list[float],
    ):
        """Dispatch on_rollouts_scored to all registered plugins (rank 0 only)."""
        if not self.accelerator.is_main_process:
            return

        from axolotl.integrations.base import PluginManager

        pm = PluginManager.get_instance()
        if pm and pm.plugins:
            # Try _axolotl_cfg first (set by causal builder), fall back to
            # PluginManager's stored cfg (set during register phase).
            cfg = getattr(self, "_axolotl_cfg", None) or getattr(pm, "_cfg", None)
            if cfg is not None:
                pm.on_rollouts_scored(
                    cfg, self, prompts, completions, rewards, advantages
                )

    # ------------------------------------------------------------------
    # Main-thread scoring
    # ------------------------------------------------------------------
@@ -1630,16 +1506,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
        self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)

        # --- Policy logprobs ---
        # When batch_flattening is enabled, use the flattened (padding-free) forward
        # pass for the scoring path. This removes padding tokens before the forward
        # pass, reducing attention FLOPs proportional to the padding ratio (20-34%
        # faster in benchmarks). Requires flash_attention_2 and no multimodal inputs.
        can_flatten = (
            getattr(self.args, "batch_flattening", False)
            and not forward_kwargs  # no multimodal inputs
            and not self.is_fsdp_enabled  # FSDP needs wrapped model
        )

        logprob_batch_size = min(batch_size * 4, len(prompt_ids))
        with disable_gradient_checkpointing(
            self.model, self.args.gradient_checkpointing_kwargs
@@ -1649,25 +1515,15 @@ class AsyncGRPOTrainer(GRPOTrainer):
                self.use_vllm
                and getattr(self, "vllm_importance_sampling_correction", False)
            ):
                if can_flatten:
                    old_per_token_logps = self._get_per_token_logps_flattened(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                        batch_size=logprob_batch_size,
                        prompt_mask=prompt_mask,
                    )
                else:
                    old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model,
                        prompt_completion_ids,
                        attention_mask,
                        logits_to_keep,
                        logprob_batch_size,
                        num_images=num_images,
                        **forward_kwargs,
                    )
                old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                    self.model,
                    prompt_completion_ids,
                    attention_mask,
                    logits_to_keep,
                    logprob_batch_size,
                    num_images=num_images,
                    **forward_kwargs,
                )
                data["old_per_token_logps"] = old_per_token_logps
            else:
                old_per_token_logps = None
@@ -1728,12 +1584,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
            logps_diff = per_token_logps_diff

        is_ratio = torch.exp(logps_diff)
        is_floor = 1.0 / is_cap  # symmetric floor (e.g., cap=3.0 -> floor=0.333)
        if is_mode in ("sequence_truncate", "token_truncate"):
            is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
            is_ratio = torch.clamp(is_ratio, max=is_cap)
        elif is_mode in ("sequence_mask", "token_mask"):
            is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
            is_ratio = is_ratio.clamp(min=is_floor)
        data["importance_sampling_ratio"] = is_ratio

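To see how the two modes treat the same ratios, a small sketch with invented values (mirroring the two branches above):

```python
import torch

is_cap = 3.0
is_floor = 1.0 / is_cap                       # symmetric floor, ~0.333
ratios = torch.tensor([0.1, 0.9, 1.0, 5.0])   # exp(logps_diff), invented

# truncate: clip into [floor, cap]
torch.clamp(ratios, min=is_floor, max=is_cap)
# -> tensor([0.3333, 0.9000, 1.0000, 3.0000])

# mask: zero ratios above the cap, then apply the floor; a capped token
# therefore ends up at the floor (minimal weight) rather than at the cap
ratios.masked_fill(ratios > is_cap, 0.0).clamp(min=is_floor)
# -> tensor([0.3333, 0.9000, 1.0000, 0.3333])
```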
        # --- Collect rewards (launched before logprobs, should be done) ---
@@ -1923,10 +1777,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
                nanmax(self.accelerator.gather(torch.max(flat_isr))).item()
            )

        # Log prompt/completion texts.
        # NB: gather_object merges per-rank local texts into a full-batch list
        # matching rewards_per_func and all_advantages which are already full-batch
        # tensors (gathered/computed earlier in this method). Lengths stay aligned.
        # Log prompt/completion texts
        prompts_text = self.processing_class.batch_decode(
            prompt_ids, skip_special_tokens=True
        )
@@ -1934,25 +1785,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
            completion_ids, skip_special_tokens=True
        )
        if gather_object is not None:
            gathered_prompts = gather_object(prompts_text)
            gathered_completions = gather_object(completions_text)
            self._logs["prompt"].extend(gathered_prompts)
            self._logs["completion"].extend(gathered_completions)
        else:
            gathered_prompts = prompts_text
            gathered_completions = completions_text
        rewards_dict = {}
        self._logs["prompt"].extend(gather_object(prompts_text))
        self._logs["completion"].extend(gather_object(completions_text))
        for i, name in enumerate(self.reward_func_names):
            reward_list = rewards_per_func[:, i].tolist()  # already full-batch
            self._logs["rewards"][name].extend(reward_list)
            rewards_dict[name] = reward_list
        adv_list = all_advantages.tolist()  # already full-batch
        self._logs["advantages"].extend(adv_list)

        # Notify plugins of scored rollouts
        self._notify_rollouts_scored(
            gathered_prompts, gathered_completions, rewards_dict, adv_list
        )
            self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
        self._logs["advantages"].extend(all_advantages.tolist())

        # Remove deferred keys
        for k in list(data.keys()):
@@ -2028,11 +1865,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
# --- Policy logprobs for this chunk (GPU, overlaps with BG rewards) ---
|
||||
can_flatten = (
|
||||
getattr(self.args, "batch_flattening", False)
|
||||
and not forward_kwargs
|
||||
and not self.is_fsdp_enabled
|
||||
)
|
||||
logprob_batch_size = min(batch_size * 2, chunk_size)
|
||||
with disable_gradient_checkpointing(
|
||||
self.model, self.args.gradient_checkpointing_kwargs
|
||||
@@ -2042,25 +1874,15 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
self.use_vllm
|
||||
and getattr(self, "vllm_importance_sampling_correction", False)
|
||||
):
|
||||
if can_flatten:
|
||||
old_logps = self._get_per_token_logps_flattened(
|
||||
self.model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
batch_size=logprob_batch_size,
|
||||
prompt_mask=chunk_prompt_mask,
|
||||
)
|
||||
else:
|
||||
old_logps, _ = self._get_per_token_logps_and_entropies(
|
||||
self.model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
logprob_batch_size,
|
||||
num_images=num_images,
|
||||
**forward_kwargs,
|
||||
)
|
||||
old_logps, _ = self._get_per_token_logps_and_entropies(
|
||||
self.model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
logprob_batch_size,
|
||||
num_images=num_images,
|
||||
**forward_kwargs,
|
||||
)
|
||||
if "old_per_token_logps" not in data:
|
||||
total = len(data["prompt_ids"])
|
||||
data["old_per_token_logps"] = torch.zeros(
|
||||
@@ -2084,13 +1906,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
seq_is = is_mode in ("sequence_mask", "sequence_truncate")
|
||||
logps_diff = diff.sum(dim=-1, keepdim=True) if seq_is else diff
|
||||
is_ratio = torch.exp(logps_diff)
|
||||
# Symmetric floor clamp (matches non-streaming path at line ~1651)
|
||||
is_floor = 1.0 / is_cap
|
||||
if is_mode in ("sequence_truncate", "token_truncate"):
|
||||
is_ratio = torch.clamp(is_ratio, min=is_floor, max=is_cap)
|
||||
is_ratio = torch.clamp(is_ratio, max=is_cap)
|
||||
elif is_mode in ("sequence_mask", "token_mask"):
|
||||
is_ratio = is_ratio.masked_fill(is_ratio > is_cap, value=0.0)
|
||||
is_ratio = is_ratio.clamp(min=is_floor)
|
||||
if "importance_sampling_ratio" not in data:
|
||||
total = len(data["prompt_ids"])
|
||||
shape = (total, 1) if seq_is else (total, is_ratio.size(1))
|
||||
@@ -2409,38 +2228,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
        return super()._prepare_inputs(generation_batch)

    def _prepare_inputs_data_producer(self, generation_batch):
        """Data producer path: produce rollout, score deferred logps, split into micro-batches.

        Architecture (with async_prefetch=True):
            BG thread: produce(skip_policy_logps=True) → vLLM generation + reward computation
            Main thread: deferred scoring (policy logprobs via GPU forward pass) → training

        Why deferred scoring is necessary for stable training:
            The policy logprobs (old_per_token_logps) must come from the CURRENT
            training model, not the vLLM model (which is N steps behind). Using
            stale vLLM logprobs as old_logps causes the importance sampling ratio
            to start far from 1.0, leading to:
            - Immediate PPO clipping → wasted samples
            - High-variance gradients from IS correction
            - Compounding per-token ratio errors on long sequences
            - In extreme cases, complete training failure (exp-003: accuracy=0)

        Deferred scoring computes old_logps with the latest model weights, so
        the IS ratio starts at exactly 1.0 and drifts gradually — giving
        maximum useful gradient signal before clipping activates.

        Cost: one additional forward pass per scoring round (GPU-bound, cannot
        overlap with training on the same GPU). Use ``batch_flattening: true``
        to reduce this cost by eliminating padding tokens from the forward pass.

        Pipeline:
            [produce(BG)] → [deferred_scores(GPU)] → [train×GA(GPU)] → [weight_sync]
                                    ↑ can't overlap with train (same GPU)

        Bottleneck: the produce() wait (generation-limited) dominates when
        generation is slower than training + scoring. Async prefetch hides
        part of this by generating in the BG thread while training runs.
        """
        """Data producer path: produce rollout, score deferred logps, split into micro-batches."""
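A toy illustration of the docstring's point about stale logprobs (all numbers invented):

```python
import torch

current_logps = torch.tensor([-1.2, -0.7, -2.1])  # CURRENT training model
stale_logps = torch.tensor([-1.6, -0.4, -2.9])    # vLLM model, N steps behind

# Deferred scoring: old_logps come from the current model, so the
# importance sampling ratio starts at exactly 1.0 for every token.
torch.exp(current_logps - current_logps)          # tensor([1., 1., 1.])

# Stale logps: per-token ratios start away from 1.0 and their product
# compounds over the sequence, triggering PPO clipping immediately.
ratio = torch.exp(current_logps - stale_logps)    # tensor([1.49, 0.74, 2.23])
ratio.prod()                                      # ~2.46 for this 3-token toy sequence
```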
        # Return from buffer if available
        if self._buffered_inputs:
            return self._buffered_inputs.pop(0)
@@ -2456,8 +2244,10 @@ class AsyncGRPOTrainer(GRPOTrainer):
                args=self.args,
            )

            # Convert RolloutDataset back to a dict for scoring/splitting
            rollout = rollout_dataset._data

            # If async (skip_policy_logps=True), score deferred logps on main thread
            if rollout.get("_pending_policy_logps"):
                if self.args.streaming_partial_batch:
                    micro_batches = self._score_streaming(rollout)
@@ -2469,6 +2259,7 @@ class AsyncGRPOTrainer(GRPOTrainer):
                    micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]
                    micro_batches = micro_batches * self.num_iterations
            else:
                # Sync path: data is already fully scored
                rollout = split_pixel_values_by_grid(rollout)
                batches = split_tensor_dict(rollout, self.args.steps_per_generation)
                micro_batches = [unsplit_pixel_values_by_grid(b) for b in batches]
@@ -2489,10 +2280,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
            rollout = future.result()
            self._submit_generation()

            # With multi-process, only rank 0 generated. Broadcast to all ranks.
            if self.accelerator.num_processes > 1:
                rollout = self._broadcast_rollout(rollout)

            if self.args.streaming_partial_batch:
                micro_batches = self._score_streaming(rollout)
            else:
@@ -2511,219 +2298,6 @@ class AsyncGRPOTrainer(GRPOTrainer):

        return micro_batches[0]

    def _get_per_token_logps_flattened(
        self,
        model,
        input_ids,
        attention_mask,
        logits_to_keep,
        batch_size=None,
        prompt_mask=None,
    ) -> torch.Tensor:
        """Compute per-token log-probs using batch flattening (padding-free).

        Instead of processing padded batches where attention wastes compute on
        padding tokens, this method:
        1. Chunks the batch into sub-batches of ``batch_size`` sequences
        2. For each chunk, flattens non-padding tokens into [1, chunk_tokens]
        3. Uses FlashAttentionKwargs (cu_seq_lens) for varlen attention
        4. Computes selective_log_softmax on the flat logits
        5. Gathers completion logprobs back to (B, logits_to_keep) padded format

        Args:
            prompt_mask: (B, L) mask where 1 = prompt token, 0 = completion/padding.
                Used to determine the exact prompt length per sequence for correct
                logprob gathering. If None, inferred as seq_len - logits_to_keep.

        Chunking prevents OOM when the total flattened sequence is too long
        (e.g., 32 sequences × 2048 tokens = 65K tokens → 20GB logits tensor).

        Requires flash_attention_2 attention implementation.
        """
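A toy trace of steps 2-3 of the docstring (mask values invented):

```python
import torch

# Three sequences with 5, 3, and 4 real (non-padding) tokens.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 0],
])
seq_lens = attention_mask.sum(dim=1).to(torch.int32)  # tensor([5, 3, 4])
cu_seqlens = torch.zeros(4, dtype=torch.int32)
cu_seqlens[1:] = seq_lens.cumsum(0)                   # tensor([0, 5, 8, 12])

# The 12 real tokens are concatenated into a single [1, 12] row; cu_seqlens
# marks the sequence boundaries so the varlen attention kernel never lets
# tokens attend across sequences.
```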
if not self.is_fsdp_enabled:
|
||||
model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False)
|
||||
|
||||
device = input_ids.device
|
||||
B, L = input_ids.shape
|
||||
if batch_size is None:
|
||||
batch_size = max(1, B)
|
||||
|
||||
autocast_ctx = torch.autocast(device_type=device.type, dtype=torch.bfloat16)
|
||||
all_logps = torch.zeros(B, logits_to_keep, device=device)
|
||||
|
||||
for chunk_start in range(0, B, batch_size):
|
||||
chunk_end = min(chunk_start + batch_size, B)
|
||||
chunk_ids = input_ids[chunk_start:chunk_end]
|
||||
chunk_mask = attention_mask[chunk_start:chunk_end]
|
||||
n = chunk_end - chunk_start
|
||||
|
||||
seq_lens = chunk_mask.sum(dim=1).to(torch.int32)
|
||||
total_tokens = seq_lens.sum().item()
|
||||
cu_seqlens = torch.zeros(n + 1, dtype=torch.int32, device=device)
|
||||
cu_seqlens[1:] = seq_lens.cumsum(0)
|
||||
|
||||
valid = chunk_mask.bool()
|
||||
flat_ids = chunk_ids[valid].unsqueeze(0)
|
||||
positions = torch.arange(L, device=device).unsqueeze(0).expand(n, L)
|
||||
flat_pos = positions[valid].unsqueeze(0)
|
||||
|
||||
with autocast_ctx:
|
||||
logits = model(
|
||||
input_ids=flat_ids,
|
||||
position_ids=flat_pos,
|
||||
use_cache=False,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=seq_lens.max().item(),
|
||||
max_length_k=seq_lens.max().item(),
|
||||
).logits
|
||||
logits = torch.nan_to_num(logits, nan=0.0)
|
||||
|
||||
# Compute logprobs on the flat shifted tensor
|
||||
flat_logits = logits[0, :-1, :] / self.temperature
|
||||
flat_targets = flat_ids[0, 1:]
|
||||
flat_logps = selective_log_softmax(
|
||||
flat_logits.unsqueeze(0), flat_targets.unsqueeze(0)
|
||||
)[0]
|
||||
|
||||
# Mask out cross-sequence boundary positions. In the shifted
|
||||
# tensor, position cu_seqlens[i]-1 (for i>0) is where sequence
|
||||
# i-1's last token "predicts" sequence i's first token — garbage.
|
||||
for boundary in cu_seqlens[1:-1]:
|
||||
idx = boundary.item() - 1
|
||||
if 0 <= idx < flat_logps.size(0):
|
||||
flat_logps[idx] = 0.0
|
||||
|
||||
# Gather completion logprobs per sequence.
|
||||
# Use prompt_mask to determine exact prompt length (not logits_to_keep,
|
||||
# which is the padded completion dimension and may exceed the actual
|
||||
# completion length for shorter sequences).
|
||||
for i in range(n):
|
||||
slen = seq_lens[i].item()
|
||||
abs_i = chunk_start + i # absolute index in the full batch
|
||||
if prompt_mask is not None:
|
||||
plen = int(prompt_mask[abs_i].sum().item())
|
||||
else:
|
||||
plen = max(1, slen - logits_to_keep)
|
||||
n_compl = slen - plen
|
||||
start = cu_seqlens[i].item() + plen - 1
|
||||
start = max(0, start)
|
||||
actual = min(n_compl, total_tokens - 1 - start)
|
||||
if actual > 0:
|
||||
all_logps[chunk_start + i, :actual] = flat_logps[
|
||||
start : start + actual
|
||||
]
|
||||
|
||||
del logits, flat_logits, flat_logps, flat_ids
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return all_logps
|
||||
|
||||
def _get_per_token_logps_and_entropies_flattened(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
batch_size=None,
|
||||
prompt_mask=None,
|
||||
compute_entropy=True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Flattened forward pass for training (with gradients).
|
||||
|
||||
Same padding removal as the scoring path, but:
|
||||
- Gradients flow through for backward pass
|
||||
- Computes entropy alongside logprobs
|
||||
- Per-sequence logprob/entropy extraction preserves grad graph
|
||||
"""
|
||||
device = input_ids.device
|
||||
B, L = input_ids.shape
|
||||
if batch_size is None:
|
||||
batch_size = max(1, B)
|
||||
|
||||
autocast_ctx = torch.autocast(device_type=device.type, dtype=torch.bfloat16)
|
||||
|
||||
# Pre-allocate output containers (will be filled with grad-carrying slices)
|
||||
all_logps_list: list[torch.Tensor] = []
|
||||
all_entropy_list: list[torch.Tensor] = []
|
||||
|
||||
for chunk_start in range(0, B, batch_size):
|
||||
chunk_end = min(chunk_start + batch_size, B)
|
||||
chunk_ids = input_ids[chunk_start:chunk_end]
|
||||
chunk_mask = attention_mask[chunk_start:chunk_end]
|
||||
n = chunk_end - chunk_start
|
||||
|
||||
seq_lens = chunk_mask.sum(dim=1).to(torch.int32)
|
||||
cu_seqlens = torch.zeros(n + 1, dtype=torch.int32, device=device)
|
||||
cu_seqlens[1:] = seq_lens.cumsum(0)
|
||||
|
||||
valid = chunk_mask.bool()
|
||||
flat_ids = chunk_ids[valid].unsqueeze(0)
|
||||
positions = torch.arange(L, device=device).unsqueeze(0).expand(n, L)
|
||||
flat_pos = positions[valid].unsqueeze(0)
|
||||
|
||||
with autocast_ctx:
|
||||
logits = model(
|
||||
input_ids=flat_ids,
|
||||
position_ids=flat_pos,
|
||||
use_cache=False,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=seq_lens.max().item(),
|
||||
max_length_k=seq_lens.max().item(),
|
||||
).logits
|
||||
logits = torch.nan_to_num(logits, nan=0.0)
|
||||
|
||||
# Extract logprobs and entropy per-sequence (avoids cross-sequence targets,
|
||||
# preserves gradient graph through selective_log_softmax → logits → model)
|
||||
for i in range(n):
|
||||
slen = seq_lens[i].item()
|
||||
abs_i = chunk_start + i
|
||||
if prompt_mask is not None:
|
||||
plen = int(prompt_mask[abs_i].sum().item())
|
||||
else:
|
||||
plen = max(1, slen - logits_to_keep)
|
||||
n_compl = slen - plen
|
||||
s = cu_seqlens[i].item()
|
||||
|
||||
if n_compl <= 0:
|
||||
# No completion tokens — append zeros
|
||||
all_logps_list.append(torch.zeros(logits_to_keep, device=device))
|
||||
if compute_entropy:
|
||||
all_entropy_list.append(
|
||||
torch.zeros(logits_to_keep, device=device)
|
||||
)
|
||||
continue
|
||||
|
||||
with autocast_ctx:
|
||||
# Shifted logits and targets for this sequence only
|
||||
seq_logits = logits[0, s + plen - 1 : s + slen - 1, :]
|
||||
seq_logits = seq_logits / self.temperature
|
||||
seq_targets = flat_ids[0, s + plen : s + slen]
|
||||
|
||||
# Log probs (differentiable)
|
||||
lps = selective_log_softmax(
|
||||
seq_logits.unsqueeze(0), seq_targets.unsqueeze(0)
|
||||
)[0] # (n_compl,)
|
||||
|
||||
# Pad to logits_to_keep
|
||||
if n_compl < logits_to_keep:
|
||||
lps = F.pad(lps, (0, logits_to_keep - n_compl))
|
||||
all_logps_list.append(lps[:logits_to_keep])
|
||||
|
||||
if compute_entropy:
|
||||
ent = entropy_from_logits(seq_logits) # (n_compl,)
|
||||
if n_compl < logits_to_keep:
|
||||
ent = F.pad(ent, (0, logits_to_keep - n_compl))
|
||||
all_entropy_list.append(ent[:logits_to_keep])
|
||||
|
||||
# Stack per-sequence results into (B, logits_to_keep) tensors
|
||||
all_logps = torch.stack(all_logps_list, dim=0)
|
||||
all_entropies = (
|
||||
torch.stack(all_entropy_list, dim=0) if compute_entropy else None
|
||||
)
|
||||
return all_logps, all_entropies
|
||||
|
||||
@profiling_decorator
|
||||
def _get_per_token_logps_and_entropies(
|
||||
self,
|
||||
@@ -2839,9 +2413,6 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
logits, completion_ids, self.temperature
|
||||
)
|
||||
all_logps.append(logps)
|
||||
# Liger fused path doesn't compute entropy — append zeros
|
||||
if compute_entropy:
|
||||
all_entropies.append(torch.zeros_like(logps))
|
||||
else:
|
||||
logits = logits[:, :-1, :]
|
||||
logits = logits[:, -logits_to_keep:, :]
|
||||
@@ -2895,47 +2466,20 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
else completion_mask * inputs["tool_mask"]
|
||||
)
|
||||
|
||||
# Check for multimodal inputs
|
||||
forward_kwargs = {
|
||||
k: inputs[k]
|
||||
for k in (
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"num_images",
|
||||
"pixel_attention_mask",
|
||||
"image_sizes",
|
||||
"token_type_ids",
|
||||
"mm_token_type_ids",
|
||||
)
|
||||
if k in inputs and inputs[k] is not None
|
||||
}
|
||||
|
||||
can_flatten = (
|
||||
getattr(self.args, "batch_flattening", False)
|
||||
and not forward_kwargs
|
||||
and not self.is_fsdp_enabled
|
||||
per_token_logps, entropies = self._get_per_token_logps_and_entropies(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
compute_entropy=True,
|
||||
pixel_values=inputs.get("pixel_values"),
|
||||
image_grid_thw=inputs.get("image_grid_thw"),
|
||||
num_images=inputs.get("num_images"),
|
||||
pixel_attention_mask=inputs.get("pixel_attention_mask"),
|
||||
image_sizes=inputs.get("image_sizes"),
|
||||
token_type_ids=inputs.get("token_type_ids"),
|
||||
mm_token_type_ids=inputs.get("mm_token_type_ids"),
|
||||
)
|
||||
|
||||
if can_flatten:
|
||||
per_token_logps, entropies = (
|
||||
self._get_per_token_logps_and_entropies_flattened(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
prompt_mask=prompt_mask,
|
||||
compute_entropy=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
per_token_logps, entropies = self._get_per_token_logps_and_entropies(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
compute_entropy=True,
|
||||
**forward_kwargs,
|
||||
)
|
||||
if self.top_entropy_quantile < 1.0:
|
||||
entropy_mask = self.get_high_entropy_mask(
|
||||
entropies, mask, 1 - self.top_entropy_quantile
|
||||
|
||||
@@ -1,3 +0,0 @@
import pkgutil

__path__ = pkgutil.extend_path(__path__, __name__)

@@ -242,30 +242,6 @@ class BasePlugin:
        """
        return []

    def on_rollouts_scored(
        self,
        cfg: DictDefault,
        trainer,
        prompts: list[str],
        completions: list[str],
        rewards: dict[str, list[float]],
        advantages: list[float],
    ):
        """Called after rollouts are scored during online RL (GRPO/PPO).

        Provides access to the full scored rollout data for logging, trace
        storage, or analysis. Called once per scoring step with all samples
        from that step.

        Args:
            cfg: The axolotl configuration.
            trainer: The trainer instance.
            prompts: List of prompt texts (one per sample).
            completions: List of completion texts (one per sample).
            rewards: Dict mapping reward function name to list of reward values.
            advantages: List of advantage values (one per sample).
        """

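For reference, a minimal sketch of a plugin implementing this hook (the class name and output file are invented for illustration):

```python
import json

from axolotl.integrations.base import BasePlugin


class RolloutLoggerPlugin(BasePlugin):
    """Illustrative only: append each scoring step's rollouts to a JSONL file."""

    def on_rollouts_scored(self, cfg, trainer, prompts, completions, rewards, advantages):
        with open("rollouts.jsonl", "a", encoding="utf-8") as fh:
            for i, (prompt, completion, adv) in enumerate(
                zip(prompts, completions, advantages)
            ):
                fh.write(json.dumps({
                    "step": trainer.state.global_step,
                    "prompt": prompt,
                    "completion": completion,
                    "rewards": {name: vals[i] for name, vals in rewards.items()},
                    "advantage": adv,
                }) + "\n")
```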
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
"""Performs actions after training is complete.
|
||||
|
||||
@@ -637,36 +613,6 @@ class PluginManager:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train(cfg, model)
|
||||
|
||||
def on_rollouts_scored(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
trainer,
|
||||
prompts: list[str],
|
||||
completions: list[str],
|
||||
rewards: dict[str, list[float]],
|
||||
advantages: list[float],
|
||||
):
|
||||
"""Calls the on_rollouts_scored method of all registered plugins.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugins.
|
||||
trainer: The trainer instance.
|
||||
prompts: List of prompt texts.
|
||||
completions: List of completion texts.
|
||||
rewards: Dict mapping reward function name to list of rewards.
|
||||
advantages: List of advantage values.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
try:
|
||||
plugin.on_rollouts_scored(
|
||||
cfg, trainer, prompts, completions, rewards, advantages
|
||||
)
|
||||
except Exception:
|
||||
LOG.warning(
|
||||
f"Plugin {plugin.__class__.__name__}.on_rollouts_scored failed",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def post_train_unload(self, cfg: DictDefault):
|
||||
"""Calls the post_train_unload method of all registered plugins.
|
||||
|
||||
|
||||
@@ -36,9 +36,7 @@ class DiffusionGenerationCallback(TrainerCallback):
|
||||
"""Generate samples at specified intervals."""
|
||||
if (
|
||||
state.global_step > 0
|
||||
and state.global_step
|
||||
% self.trainer.axolotl_cfg.diffusion.generation_interval
|
||||
== 0
|
||||
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
|
||||
):
|
||||
if not self.trainer.state.is_world_process_zero:
|
||||
return
|
||||
@@ -54,7 +52,7 @@ class DiffusionGenerationCallback(TrainerCallback):
|
||||
dataloader = self.trainer.get_train_dataloader()
|
||||
|
||||
# Generate samples
|
||||
diffusion_cfg = self.trainer.axolotl_cfg.diffusion
|
||||
diffusion_cfg = self.trainer.cfg.diffusion
|
||||
samples = generate_samples(
|
||||
model=self.trainer.model,
|
||||
tokenizer=self.trainer.processing_class,
|
||||
@@ -144,11 +142,11 @@ class DiffusionGenerationCallback(TrainerCallback):
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.trainer.axolotl_cfg.use_wandb:
|
||||
if wandb.run is not None: # type: ignore[attr-defined]
|
||||
wandb.log( # type: ignore[attr-defined]
|
||||
if self.trainer.cfg.use_wandb:
|
||||
if wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"generated_samples": wandb.Table( # type: ignore[attr-defined]
|
||||
"generated_samples": wandb.Table(
|
||||
columns=[
|
||||
"step",
|
||||
"original",
|
||||
|
||||
@@ -38,6 +38,4 @@ class DiffusionPlugin(BasePlugin):
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
|
||||
"""Configure trainer after creation."""
|
||||
if hasattr(trainer, "axolotl_cfg"):
|
||||
trainer.axolotl_cfg = cfg
|
||||
trainer.post_set_axolotl_cfg()
|
||||
trainer.set_config(cfg)
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .callbacks import DiffusionGenerationCallback
|
||||
@@ -20,17 +21,19 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cfg = None
|
||||
self._special_token_ids = None
|
||||
|
||||
def post_set_axolotl_cfg(self):
|
||||
def set_config(self, config: DictDefault):
|
||||
"""Set config for diffusion training."""
|
||||
self.cfg = config
|
||||
self._cache_special_token_ids()
|
||||
self._resolve_mask_token_id()
|
||||
|
||||
token_id = int(getattr(self.axolotl_cfg.diffusion, "mask_token_id", 0))
|
||||
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
|
||||
LOG.info(f"Diffusion: using mask_token_id={token_id}")
|
||||
|
||||
if getattr(self.axolotl_cfg.diffusion, "generate_samples", True):
|
||||
if getattr(config.diffusion, "generate_samples", True):
|
||||
generation_callback = DiffusionGenerationCallback(self)
|
||||
self.add_callback(generation_callback)
|
||||
|
||||
@@ -38,20 +41,18 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
"""Ensure mask_token_id is valid for the current tokenizer."""
|
||||
from .utils import resolve_mask_token_id
|
||||
|
||||
assert self.axolotl_cfg is not None, "axolotl_cfg is not set yet"
|
||||
|
||||
tokenizer = getattr(self, "processing_class", None)
|
||||
if tokenizer is None:
|
||||
return
|
||||
|
||||
mid = resolve_mask_token_id(
|
||||
tokenizer,
|
||||
self.axolotl_cfg,
|
||||
self.cfg,
|
||||
allow_add=True,
|
||||
model=getattr(self, "model", None),
|
||||
)
|
||||
try:
|
||||
self.axolotl_cfg.diffusion.mask_token_id = int(mid)
|
||||
self.cfg.diffusion.mask_token_id = int(mid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -149,7 +150,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
masked_indices = masked_indices & answer_mask
|
||||
|
||||
# Create masked input
|
||||
mask_token_id = int(self.axolotl_cfg.diffusion.mask_token_id)
|
||||
mask_token_id = int(self.cfg.diffusion.mask_token_id)
|
||||
mask_value = torch.full_like(input_ids, mask_token_id)
|
||||
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
|
||||
|
||||
@@ -193,12 +194,12 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
|
||||
# Apply forward process
|
||||
noisy_batch, masked_indices, p_mask = self._forward_process(
|
||||
input_ids, attention_mask, labels, self.axolotl_cfg.diffusion.eps
|
||||
input_ids, attention_mask, labels, self.cfg.diffusion.eps
|
||||
)
|
||||
|
||||
# Create bidirectional attention mask
|
||||
bidirectional_mask = create_bidirectional_attention_mask(
|
||||
input_ids, attention_mask, sample_packing=self.axolotl_cfg.sample_packing
|
||||
input_ids, attention_mask, sample_packing=self.cfg.sample_packing
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
@@ -221,7 +222,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
masked_logits.float(), masked_targets, reduction="none"
|
||||
)
|
||||
|
||||
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
masked_p_mask = masked_p_mask.float()
|
||||
weighted_loss = token_loss / masked_p_mask
|
||||
else:
|
||||
@@ -250,7 +251,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
# Non-SFT: when importance weighting is enabled, use unbiased estimator
|
||||
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
|
||||
# for stable scaling across varying mask ratios.
|
||||
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
loss = weighted_loss.sum() / (
|
||||
input_ids.shape[0] * input_ids.shape[1]
|
||||
)
|
||||
@@ -282,7 +283,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
}
|
||||
|
||||
# If doing SFT training, log answer-specific metrics
|
||||
if self.axolotl_cfg.datasets is not None:
|
||||
if self.cfg.datasets is not None:
|
||||
with torch.no_grad():
|
||||
answer_mask = labels != -100
|
||||
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
|
||||
@@ -291,7 +292,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
|
||||
metrics["avg_answer_length"] = answer_lengths.mean().item()
|
||||
|
||||
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
||||
|
||||
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
|
||||
|
||||
@@ -28,7 +28,7 @@ use_scattermoe: true
use_sonicmoe: true
```

**Important:** Setting `experts_implementation` to `batched_mm` or `grouped_mm` is incompatible with custom kernel options. The exception is `experts_implementation: scattermoe`, which is used for models like Gemma 4 that embed MoE directly in the decoder layer (no SparseMoeBlock) and dispatch through the transformers `ExpertsInterface`.
**Important:** Setting `experts_implementation` is incompatible with custom kernel options.

### SonicMoE installation

@@ -52,110 +52,26 @@ The `KernelsPlugin` runs before model loading and:

### ScatterMoE
1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.

### SonicMoE
1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports pluggable routing strategies (see routing table below).
2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports both softmax->topk and sigmoid->topk routing strategies.

Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.

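A hypothetical usage sketch (it assumes the resolver returns a list of block classes and that a `scattermoe_block_forward` patch function exists; both are assumptions for illustration, not the plugin's actual API):

```python
from axolotl.integrations.kernels.constants import resolve_moe_block_classes

for moe_cls in resolve_moe_block_classes("qwen3_moe"):
    moe_cls.forward = scattermoe_block_forward  # swap in the optimized kernel
```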
## Model Support Matrix
#### Supported Models

Most models use the **SwiGLU** activation (`silu(gate) * up`). Gemma 4 uses **GEGLU** (`gelu(gate) * up`). ScatterMoE supports any gated activation (activation is applied in Python between kernel calls). SonicMoE supports SwiGLU, GEGLU, and REGLU via its `ActivationType` enum.

### Routing strategies

| Routing Strategy | Description | ScatterMoE | SonicMoE |
|---|---|:---:|:---:|
| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes |
| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes |
| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes |
| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes |
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
| softmax → topk + per_expert_scale | RMSNorm → scale → proj → softmax → topk → renorm → per-expert learned scales | Yes | Yes |
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |

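As a concrete reference for the most common strategy above (softmax → topk), a minimal sketch:

```python
import torch
import torch.nn.functional as F

def softmax_topk_routing(router_logits, top_k, renormalize=True):
    # (tokens, experts) -> per-token expert indices and mixture weights
    probs = F.softmax(router_logits, dim=-1)
    weights, idx = probs.topk(top_k, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return idx, weights
```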
### Per-model support

| Model Type | Architecture | Routing | ScatterMoE | SonicMoE |
|---|---|---|:---:|:---:|
| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** |
| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** |
| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** |
| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** |
| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** |
| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** |
| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** |
| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** |
| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** |
| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** |
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
| `gemma4_text` | Gemma 4 (26B-A4B) | softmax → topk + per_expert_scale | **Yes**\*\* | **Yes**\*\* |
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |

\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.

\*\* Gemma 4 uses the `experts_implementation: scattermoe` path (registered via `ExpertsInterface`) instead of SparseMoeBlock patching, since Gemma 4 embeds MoE directly in its decoder layer (no separate SparseMoeBlock). See the [Gemma 4 section](#gemma-4) below.

### Feature comparison

| Feature | ScatterMoE | SonicMoE |
|---|:---:|:---:|
| Kernel backend | Triton | CUTLASS |
| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) |
| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd |
| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) |
| Gate/router LoRA | Yes | Yes |
| Expert LoRA | Yes (fused) | Yes (materialized) |
| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) |
| Selective expert dequantization | Yes (~97% memory savings) | No |
| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` |
| torch.compile routing | No | Yes (optional) |

## Shared Expert Handling

Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority:

1. `shared_expert` (Qwen2-MoE)
2. `shared_experts` (GLM-MoE, DeepSeek-V3)
3. `shared_mlp` (HunYuan V1 MoE)

If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.

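A minimal sketch of that combination (attribute names follow the priority list above; the helper name and shapes are illustrative):

```python
import torch

def add_shared_expert_output(block, hidden_states, routed_output):
    shared_out = block.shared_expert(hidden_states)
    if getattr(block, "shared_expert_gate", None) is not None:
        # Sigmoid-gate the shared expert before adding it to the routed path
        shared_out = torch.sigmoid(block.shared_expert_gate(hidden_states)) * shared_out
    return routed_output + shared_out
```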
## Gemma 4

Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:

- **No SparseMoeBlock**: MoE is embedded directly in the decoder layer alongside a dense MLP. Both run in parallel and their outputs are summed.
- **Custom router** (`Gemma4TextRouter`): RMSNorm → learned scale → linear projection → softmax → top-k → renormalization → per-expert learned scales.
- **GEGLU activation**: Uses `gelu_pytorch_tanh` (not SiLU/SwiGLU like most other MoE models).
- **128 experts, top-k=8** for the 26B-A4B variant.

Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.

**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).

## Limitations

- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant).
- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed.
- **FSDP + fused gate LoRA (SonicMoE)**: The fused topk→softmax path materializes a local tensor when LoRA delta is present to avoid DTensor + Tensor mixing under FSDP.
ScatterMoE uses softmax -> topk routing, so baseline results may differ for some model architectures (GPT-OSS, etc.). It is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash).

SonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures.

ScatterMoE does not currently work for GLM 4.7 Flash (`glm4_moe_lite`).

## Note on MegaBlocks


@@ -34,44 +34,14 @@ class KernelsArgs(BaseModel):
    @classmethod
    def check_experts_implementation(cls, data):
        experts_implementation = data.get("experts_implementation")
        use_scattermoe = data.get("use_scattermoe", False)
        if experts_implementation is None:
            # transformers may default to batched_mm when unset
            data["experts_implementation"] = "eager"
        elif experts_implementation == "scattermoe" and not use_scattermoe:
        elif experts_implementation != "eager":
            LOG.warning(
                "`experts_implementation='scattermoe'` requires `use_scattermoe: true`. "
                "Automatically setting to 'eager'."
                "`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
            )
            data["experts_implementation"] = "eager"
        elif experts_implementation not in ("eager", "scattermoe"):
            LOG.warning(
                f"`experts_implementation={experts_implementation!r}` is not compatible with "
                f"custom MoE kernels. Automatically setting to 'eager'."
            )
            data["experts_implementation"] = "eager"

        return data

    @model_validator(mode="before")
    @classmethod
    def warn_sonicmoe_lora_overhead(cls, data):
        if data.get("use_sonicmoe") is True and data.get("adapter") in (
            "lora",
            "qlora",
        ):
            lora_target = data.get("lora_target_modules") or []
            lora_linear = data.get("lora_target_linear_modules") or []
            targets = (
                lora_target if isinstance(lora_target, list) else [lora_target]
            ) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
            expert_keywords = ("gate_up_proj", "down_proj", "experts")
            if any(kw in t for t in targets for kw in expert_keywords):
                LOG.info(
                    "SonicMoE + LoRA on expert modules uses runtime weight materialization "
                    "(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
                    "than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
                )

        return data

@@ -6,12 +6,6 @@ Used by both ScatterMoE and SonicMoE kernel paths.

Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).

Models with custom routing (see sonicmoe/routing.py for implementations):
- ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing)
- deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing)
- hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing)
- gemma4_text: RMSNorm→scale→proj→softmax→topk→renorm→per_expert_scale (experts-level patch)
"""

import importlib
|
||||
@@ -42,61 +36,16 @@ SPARSE_MOE_BLOCK = {
|
||||
"glm4v_moe": "Glm4vMoeTextMoE",
|
||||
# sigmoid -> topk routing (no group selection)
|
||||
"minimax_m2": "MiniMaxM2SparseMoeBlock",
|
||||
# softmax->topk, e_score_correction_bias between softmax and topk
|
||||
"ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock",
|
||||
# softmax->topk, group_limited_greedy, different attr names (num_group)
|
||||
"deepseek_v2": "DeepseekV2Moe",
|
||||
# softmax->topk, gate.wg (not gate.weight)
|
||||
"hunyuan_v1_moe": "HunYuanMoEV1Moe",
|
||||
# TODO: gpt_oss deferred — transposed weight layout [E,H,2*I], expert biases,
|
||||
# and custom GLU activation require a dedicated forward path in patch.py.
|
||||
# "gpt_oss": "GptOssMLP",
|
||||
# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
|
||||
"nemotron_h": "NemotronHMoE",
|
||||
# Models below need custom routing (not yet implemented):
|
||||
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
|
||||
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
|
||||
# "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing
|
||||
# "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
|
||||
}
|
||||
|
||||
|
||||
# Models where MoE is NOT in a separate SparseMoeBlock but embedded in the
|
||||
# decoder layer. For these, we patch the Experts class forward directly
|
||||
# (same signature: hidden_states, top_k_index, top_k_weights -> Tensor).
|
||||
# Routing stays untouched — the original model router runs as-is.
|
||||
EXPERTS_ONLY_BLOCK = {
|
||||
# gemma4: hybrid MLP+MoE in decoder layer, custom Gemma4TextRouter,
|
||||
# no SparseMoeBlock. Experts use @use_experts_implementation with
|
||||
# standard 3D param layout (gate_up_proj [E, 2*I, H], down_proj [E, H, I]).
|
||||
"gemma4_text": "Gemma4TextExperts",
|
||||
}
|
||||
|
||||
|
||||
def resolve_experts_class(model_type: str):
|
||||
"""Resolve the Experts class for models that need experts-level patching.
|
||||
|
||||
Returns the class, or None if the model uses SparseMoeBlock-level patching.
|
||||
"""
|
||||
entry = EXPERTS_ONLY_BLOCK.get(model_type)
|
||||
if entry is None:
|
||||
return None
|
||||
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
except ModuleNotFoundError:
|
||||
if model_type.endswith("_text"):
|
||||
parent_type = model_type.removesuffix("_text")
|
||||
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
|
||||
module = importlib.import_module(module_path)
|
||||
else:
|
||||
raise
|
||||
|
||||
cls = getattr(module, entry, None)
|
||||
if cls is None:
|
||||
raise ValueError(f"Could not find class '{entry}' in '{module_path}'")
|
||||
return cls
|
||||
|
||||
|
||||
def is_experts_only_model(model_type: str) -> bool:
|
||||
"""Check if a model type requires experts-level (not block-level) patching."""
|
||||
return model_type in EXPERTS_ONLY_BLOCK
|
||||
|
||||
|
||||
def resolve_moe_block_classes(model_type: str):
|
||||
"""Resolve all MoE block classes from transformers for the given model type.
|
||||
|
||||
|
||||
@@ -1,235 +0,0 @@
"""
ScatterMoE-accelerated experts forward for Gemma4.

Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer.
The decoder layer handles routing (Gemma4TextRouter) and calls
``experts(hidden_states, top_k_index, top_k_weights)`` directly.

This module registers a ``"scattermoe"`` implementation in the transformers
``ExpertsInterface``, which the ``@use_experts_implementation`` decorator
dispatches to when ``config._experts_implementation == "scattermoe"``.

This is the clean way to hook into transformers' MoE dispatch — no
monkeypatching required. Works for Gemma4 and any future model that uses
``@use_experts_implementation`` with the standard forward signature
``(hidden_states, top_k_index, top_k_weights) -> Tensor``.
"""

import torch
|
||||
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
||||
|
||||
|
||||
def _has_peft_wrapper(module):
|
||||
"""Check if a module's parameter has been wrapped by PEFT ParamWrapper."""
|
||||
try:
|
||||
from peft.tuners.param_wrapper import ParamWrapper
|
||||
|
||||
for attr in ("gate_up_proj", "down_proj"):
|
||||
param = getattr(module, attr, None)
|
||||
if isinstance(param, ParamWrapper):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_experts_lora(experts):
|
||||
"""Extract base weights and LoRA params from a PEFT-wrapped Experts module.
|
||||
|
||||
Returns:
|
||||
(base_experts, gup_lora, down_lora) where each lora is
|
||||
(lora_A, lora_B, scaling) or None.
|
||||
"""
|
||||
try:
|
||||
from peft.tuners.param_wrapper import ParamWrapper
|
||||
except ImportError:
|
||||
return experts, None, None
|
||||
|
||||
if not isinstance(getattr(experts, "gate_up_proj", None), ParamWrapper):
|
||||
return experts, None, None
|
||||
|
||||
base_experts = experts
|
||||
gup_lora = None
|
||||
down_lora = None
|
||||
|
||||
gup_param = experts.gate_up_proj
|
||||
if isinstance(gup_param, ParamWrapper):
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_param)
|
||||
if lora_A is not None:
|
||||
num_experts = experts.num_experts
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
from .layers import peft_lora_to_scattermoe
|
||||
|
||||
sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
|
||||
gup_lora = (sm_A, sm_B, scaling)
|
||||
|
||||
down_param = experts.down_proj
|
||||
if isinstance(down_param, ParamWrapper):
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_param)
|
||||
if lora_A is not None:
|
||||
num_experts = experts.num_experts
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
from .layers import peft_lora_to_scattermoe
|
||||
|
||||
sm_A, sm_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
|
||||
down_lora = (sm_A, sm_B, scaling)
|
||||
|
||||
return base_experts, gup_lora, down_lora
|
||||
|
||||
|
||||
def _get_base_param(param):
|
||||
"""Get the base tensor from a PEFT ParamWrapper or regular Parameter."""
|
||||
try:
|
||||
from peft.tuners.param_wrapper import ParamWrapper
|
||||
|
||||
while isinstance(param, ParamWrapper):
|
||||
param = param.original_parameter
|
||||
except ImportError:
|
||||
pass
|
||||
return param
|
||||
|
||||
|
||||
def _parallel_linear_maybe_lora(
|
||||
x,
|
||||
weight,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
lora_tuple,
|
||||
grouped_in,
|
||||
grouped_out,
|
||||
gates=None,
|
||||
):
|
||||
"""Call parallel_linear or parallel_linear_lora depending on whether LoRA is active."""
|
||||
if lora_tuple is not None:
|
||||
lora_A, lora_B, scaling = lora_tuple
|
||||
return parallel_linear_lora(
|
||||
x,
|
||||
weight,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
lora_A,
|
||||
lora_B,
|
||||
scaling,
|
||||
grouped_in=grouped_in,
|
||||
grouped_out=grouped_out,
|
||||
gates=gates,
|
||||
)
|
||||
return parallel_linear(
|
||||
x,
|
||||
weight,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
grouped_in=grouped_in,
|
||||
grouped_out=grouped_out,
|
||||
gates=gates,
|
||||
)
|
||||
|
||||
|
||||
def scattermoe_experts_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
top_k_index: torch.Tensor,
|
||||
top_k_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""ScatterMoE-accelerated experts forward.
|
||||
|
||||
Drop-in replacement for the standard Experts forward signature used by
|
||||
``@use_experts_implementation``-decorated classes (Gemma4, Mixtral, etc.):
|
||||
``(hidden_states [T, H], top_k_index [T, K], top_k_weights [T, K]) -> [T, H]``
|
||||
"""
|
||||
K = top_k_index.shape[1]
|
||||
|
||||
routing_weights = top_k_weights.to(hidden_states.dtype)
|
||||
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
|
||||
top_k_index, num_experts=self.num_experts
|
||||
)
|
||||
|
||||
# Get base weights (unwrap PEFT if needed)
|
||||
gate_up_weight = _get_base_param(self.gate_up_proj).transpose(2, 1)
|
||||
down_weight = _get_base_param(self.down_proj).transpose(2, 1)
|
||||
|
||||
# Extract LoRA params if PEFT is active
|
||||
gup_lora, down_lora = None, None
|
||||
if _has_peft_wrapper(self):
|
||||
_, gup_lora, down_lora = _unwrap_experts_lora(self)
|
||||
|
||||
# Gate-up projection (with optional LoRA)
|
||||
gates_h = _parallel_linear_maybe_lora(
|
||||
hidden_states,
|
||||
gate_up_weight,
|
||||
K,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
gup_lora,
|
||||
grouped_in=False,
|
||||
grouped_out=True,
|
||||
)
|
||||
gates, h = gates_h.chunk(2, dim=-1)
|
||||
h = self.act_fn(gates) * h
|
||||
|
||||
# Down projection (with optional LoRA + routing weights)
|
||||
output = _parallel_linear_maybe_lora(
|
||||
h,
|
||||
down_weight,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
down_lora,
|
||||
grouped_in=True,
|
||||
grouped_out=False,
|
||||
gates=routing_weights,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
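# Shape sketch for the forward above (illustrative; the toy sizes are
# assumptions, not taken from this module): with T=4 tokens, H=8 hidden,
# I=16 intermediate, E=4 experts, K=2 experts per token:
#   hidden_states [4, 8], top_k_index [4, 2], top_k_weights [4, 2]
#   -> flatten_sort_count sorts the T*K=8 assignments by expert id so each
#      expert's tokens are contiguous for the grouped GEMMs
#   -> the gate-up projection yields grouped rows [8, 2*16]; chunk + act_fn
#      reduces them to [8, 16]
#   -> the down projection scatters back, weighted by routing_weights, to [4, 8].
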
def register_scattermoe_experts():
    """Register ``"scattermoe"`` in the transformers ExpertsInterface.

    After calling this, any model with ``@use_experts_implementation`` will
    dispatch to ScatterMoE when ``config._experts_implementation == "scattermoe"``.

    Also patches ``get_correct_experts_implementation`` to accept ``"scattermoe"``
    as a valid value (transformers hardcodes an allowlist).
    """
    from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS
    from transformers.modeling_utils import PreTrainedModel

    # 1. Register the forward function in the global interface
    ALL_EXPERTS_FUNCTIONS.register("scattermoe", scattermoe_experts_forward)

    # 2. Patch the validation to accept "scattermoe"
    _original_get_correct = PreTrainedModel.get_correct_experts_implementation

    def _patched_get_correct(self_model, requested_experts: str | None) -> str:
        if requested_experts == "scattermoe":
            return "scattermoe"
        return _original_get_correct(self_model, requested_experts)

    PreTrainedModel.get_correct_experts_implementation = _patched_get_correct

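# Hypothetical usage sketch (not part of this diff; the model id is a
# placeholder and the private config attribute mirrors the docstring above):
#
#   from transformers import AutoConfig, AutoModelForCausalLM
#
#   register_scattermoe_experts()
#   config = AutoConfig.from_pretrained("org/some-moe-model")
#   config._experts_implementation = "scattermoe"
#   model = AutoModelForCausalLM.from_config(config)
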
# Legacy monkeypatch approach (kept for backward compat with existing tests)
def patch_gemma4_scattermoe():
    """Monkeypatch Gemma4TextExperts.forward with ScatterMoE kernel."""
    from axolotl.integrations.kernels.constants import resolve_experts_class

    experts_cls = resolve_experts_class("gemma4_text")
    if experts_cls is None:
        raise ValueError("Could not resolve Gemma4TextExperts class")

    if hasattr(experts_cls, "_original_forward"):
        return  # already patched

    experts_cls._original_forward = experts_cls.forward
    experts_cls.forward = scattermoe_experts_forward
@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
          -> base_layer: ParamWrapper(gate_up_proj)
            -> base_layer: OlmoeExperts (the real module)

    For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
    instead of ``gate_up_proj``.

    This function walks the chain, collects LoRA params keyed by
    ``parameter_name``, and returns the base experts module.

@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):

    Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
    A/B are already in scattermoe layout.
    For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
    """
    # Collect ParamWrapper layers by their parameter_name
    wrappers = {}
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
    num_experts = getattr(base_experts, "num_experts", None)
    if num_experts is None:
        # Fallback: infer from parameter shape
        gup = getattr(base_experts, "gate_up_proj", None)
        if gup is not None:
            num_experts = gup.shape[0]
        for attr in ("gate_up_proj", "up_proj"):
            param = getattr(base_experts, attr, None)
            if param is not None:
                num_experts = param.shape[0]
                break

    # Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
    # Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
    gup_lora = None
    gup_wrapper = wrappers.get("gate_up_proj")
    gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
    if gup_wrapper is not None:
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
        if lora_A is not None:
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
    Supports:

    * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
    * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
    * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
      extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
    * **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
    * **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
    """

    @staticmethod
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
        hidden_states_flat = layer_input.view(-1, hidden_dim)

        # ====================================================================
        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
        # ====================================================================
        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)

@@ -489,6 +497,22 @@ class HFScatterMoEGatedMLP(nn.Module):
        # ====================================================================
        experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)

        # ====================================================================
        # Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
        # ====================================================================
        is_gated = hasattr(experts, "gate_up_proj")
        up_proj_attr = "gate_up_proj" if is_gated else "up_proj"

        # ====================================================================
        # Optional latent projection (NemotronH: fc1/fc2_latent_proj)
        # ====================================================================
        fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
        fc2_latent_proj = getattr(self, "fc2_latent_proj", None)

        expert_input = hidden_states_flat
        if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
            expert_input = fc1_latent_proj(hidden_states_flat)

        # ====================================================================
        # Selective expert weight dequantization
        # ====================================================================
@@ -498,7 +522,7 @@ class HFScatterMoEGatedMLP(nn.Module):
        use_selective = (
            getattr(self, "_use_selective_dequant", False)
            and hasattr(experts, "parametrizations")
            and "gate_up_proj" in experts.parametrizations
            and up_proj_attr in experts.parametrizations
        )

        if use_selective:
@@ -517,11 +541,11 @@ class HFScatterMoEGatedMLP(nn.Module):
                num_experts,
            )
            # Dequantize only active experts' weights
            gate_up_W = selective_expert_weights(
            up_W = selective_expert_weights(
                experts,
                "gate_up_proj",
                up_proj_attr,
                active_experts,
            ).transpose(2, 1)  # [num_active, hidden, 2*inter]
            ).transpose(2, 1)

            # Remap LoRA weights to match compact expert indices
            if gup_lora is not None:
@@ -538,18 +562,18 @@ class HFScatterMoEGatedMLP(nn.Module):
            sei_gup = remapped_expert_idxs
            eo_gup = compact_offsets
        else:
            gate_up_W = experts.gate_up_proj.transpose(2, 1)  # [E, hidden, 2*inter]
            up_W = getattr(experts, up_proj_attr).transpose(2, 1)
            sei_gup = sorted_expert_idxs
            eo_gup = expert_offsets

        # ====================================================================
        # Gate + Up projection
        # Up projection (gated: gate_up_proj; non-gated: up_proj)
        # ====================================================================
        if gup_lora is not None:
            gup_A, gup_B, gup_scaling = gup_lora
            gup = parallel_linear_lora(
                hidden_states_flat,
                gate_up_W,
            up_out = parallel_linear_lora(
                expert_input,
                up_W,
                top_k,
                sei_gup,
                sorted_scattered_idxs,
@@ -563,9 +587,9 @@ class HFScatterMoEGatedMLP(nn.Module):
                use_fused_gather=True,
            )
        else:
            gup = parallel_linear(
                hidden_states_flat,
                gate_up_W,
            up_out = parallel_linear(
                expert_input,
                up_W,
                top_k,
                sei_gup,
                sorted_scattered_idxs,
@@ -574,8 +598,14 @@ class HFScatterMoEGatedMLP(nn.Module):
                grouped_out=True,
            )

        gates, h = gup.chunk(2, dim=-1)
        h = experts.act_fn(gates) * h
        # ====================================================================
        # Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
        # ====================================================================
        if is_gated:
            gates, h = up_out.chunk(2, dim=-1)
            h = experts.act_fn(gates) * h
        else:
            h = experts.act_fn(up_out)

        # ====================================================================
        # Down projection
@@ -635,6 +665,12 @@ class HFScatterMoEGatedMLP(nn.Module):
            gates=routing_weights,
        )

        # ====================================================================
        # Optional latent projection back to hidden_size (NemotronH)
        # ====================================================================
        if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
            expert_output = fc2_latent_proj(expert_output)

        # ====================================================================
        # Combine with shared expert and reshape
        # ====================================================================

@@ -49,11 +49,6 @@ class ParallelLinear(torch.autograd.Function):
        grouped_in: bool = False,
        grouped_out: bool = False,
    ):
        # Cast weights to match input dtype (e.g. 8-bit LoRA)
        if expert_weights.dtype != x.dtype:
            expert_weights = expert_weights.to(x.dtype)
        if expert_biases is not None and expert_biases.dtype != x.dtype:
            expert_biases = expert_biases.to(x.dtype)
        with torch.device(x.device):
            output = kernels.ops.scatter2scatter(
                X=x,

@@ -65,11 +65,6 @@ class ScatterMoELoRA(torch.autograd.Function):
        use_fused_dX: bool = False,
        use_fused_gather: bool = False,
    ):
        # Cast weights to match input dtype (e.g. 8-bit LoRA)
        if expert_weights.dtype != x.dtype:
            expert_weights = expert_weights.to(x.dtype)
        if expert_biases is not None and expert_biases.dtype != x.dtype:
            expert_biases = expert_biases.to(x.dtype)
        with torch.device(x.device):
            # Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
            output = scatter2scatter_lora(

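# Note on the fused form above: rather than materializing W + scaling * B @ A
# per expert, scatter2scatter_lora computes Y = X @ W + scaling * (X @ A^T) @ B^T,
# applying the low-rank update as two thin GEMMs ([n, H] @ [H, r], then
# [n, r] @ [r, out]) on each expert's grouped rows.
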
@@ -1,106 +0,0 @@
"""
SonicMoE-accelerated experts forward for Gemma4.

Gemma4 has no separate SparseMoeBlock — MoE is embedded in the decoder layer.
This module provides a drop-in replacement for ``Gemma4TextExperts.forward``
that uses SonicMoE kernels while preserving the original call signature.
"""

import torch

from .lora import has_lora, materialize_expert_lora, unwrap_experts_lora


def _get_expert_weights_gemma4(experts_module):
    """Extract expert weights from Gemma4TextExperts, applying LoRA if active.

    Returns:
        (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
    """
    if has_lora(experts_module):
        base_experts, lora_dict = unwrap_experts_lora(experts_module)
        gate_up = materialize_expert_lora(
            base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
        )
        down = materialize_expert_lora(
            base_experts.down_proj, lora_dict.get("down_proj")
        )
    else:
        gate_up = experts_module.gate_up_proj
        down = experts_module.down_proj

    # Permute to SonicMoE layout:
    #   gate_up: [E, 2*I, H] -> [2*I, H, E]
    #   down:    [E, H, I]   -> [H, I, E]
    return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)


def gemma4_sonicmoe_experts_forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """SonicMoE-accelerated replacement for Gemma4TextExperts.forward.

    Same signature as the original: (hidden_states [T, H], top_k_index [T, K],
    top_k_weights [T, K]) -> output [T, H].
    """
    from sonicmoe import moe_general_routing_inputs
    from sonicmoe.enums import ActivationType

    T, _ = hidden_states.shape
    K = top_k_index.shape[1]
    E = self.num_experts

    # Convert routing outputs to SonicMoE's flat format
    # Token indices sorted ascending (required by SonicMoE)
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )
    flat_scores = top_k_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_k_index.to(torch.int32).reshape(-1)  # [T*K]

    # Get weights (with LoRA materialization if needed)
    gate_up_weight, down_weight = _get_expert_weights_gemma4(self)
    gate_up_weight = gate_up_weight.to(hidden_states.dtype)
    down_weight = down_weight.to(hidden_states.dtype)

    if not torch.cuda.is_available():
        raise RuntimeError("SonicMoE requires CUDA. No CUDA device available.")
    cuda_stream = torch.cuda.current_stream().cuda_stream

    output, _ = moe_general_routing_inputs(
        hidden_states,
        flat_scores,
        flat_token_idx,
        flat_expert_idx,
        gate_up_weight,
        None,  # b1 (no gate/up bias)
        down_weight,
        None,  # b2 (no down bias)
        E,
        cuda_stream,
        ActivationType.GEGLU,
        False,  # is_inference_mode
    )

    return output


def patch_gemma4_sonicmoe():
    """Monkeypatch Gemma4TextExperts.forward with SonicMoE kernel."""
    from axolotl.integrations.kernels.constants import resolve_experts_class

    experts_cls = resolve_experts_class("gemma4_text")
    if experts_cls is None:
        raise ValueError("Could not resolve Gemma4TextExperts class")

    if hasattr(experts_cls, "_original_forward"):
        return  # already patched

    experts_cls._original_forward = experts_cls.forward
    experts_cls.forward = gemma4_sonicmoe_experts_forward
@@ -1,220 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""
SonicMoE LoRA support via runtime weight materialization.

SonicMoE uses opaque CUTLASS kernels that cannot be modified to fuse LoRA.
Instead, we materialize the effective weight W_eff = W + scaling * (B @ A)
before each CUTLASS call, and use a custom autograd.Function to route
gradients back to the LoRA A and B parameters.

PEFT unwrapping utilities are also provided to handle the ParamWrapper
chain that PEFT creates when targeting expert parameters.
"""

from typing import Optional

import torch

# =============================================================================
# PEFT unwrapping utilities
# =============================================================================


def has_lora(module) -> bool:
    """Check if a module is wrapped by PEFT with LoRA."""
    return hasattr(module, "base_layer") and hasattr(module, "lora_A")


def get_lora_params_from_wrapper(module) -> tuple:
    """Extract LoRA parameters from a PEFT ParamWrapper.

    Returns:
        (lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
    """
    if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
        return None, None, None

    active_adapters = getattr(module, "active_adapters", ["default"])
    if not active_adapters:
        return None, None, None

    adapter_name = active_adapters[0]

    lora_A_dict = getattr(module, "lora_A", {})
    lora_B_dict = getattr(module, "lora_B", {})
    scaling_dict = getattr(module, "scaling", {})

    if (
        adapter_name not in lora_A_dict
        or adapter_name not in lora_B_dict
        or adapter_name not in scaling_dict
    ):
        return None, None, None

    lora_A = lora_A_dict[adapter_name].weight
    lora_B = lora_B_dict[adapter_name].weight
    scaling = scaling_dict[adapter_name]

    return lora_A, lora_B, scaling


def unwrap_gate_lora(gate_module):
    """Unwrap PEFT ParamWrapper on the router gate.

    When PEFT targets ``gate.weight``, ``self.gate`` becomes::

        ParamWrapper(weight)
          -> base_layer: Router (the real module)

    Returns:
        (base_gate, gate_weight, gate_lora_delta_or_None)

    ``base_gate`` is the original router module (with ``.top_k``, etc.).
    ``gate_weight`` is the base router weight tensor.
    ``gate_lora_delta_or_None`` is the LoRA delta if active, else None.
    Kept separate to avoid mixing DTensor + Tensor under FSDP.
    """
    if has_lora(gate_module):
        base_gate = gate_module.base_layer
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
        if lora_A is not None:
            delta = scaling * (lora_B @ lora_A)
            return base_gate, base_gate.weight, delta
        return base_gate, base_gate.weight, None

    return gate_module, gate_module.weight, None

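# Shape sketch for the gate delta above (toy sizes assumed, illustrative only):
# for a router nn.Linear(H=16, E=8) with LoRA rank r=4, PEFT stores
# lora_A [4, 16] and lora_B [8, 4], so delta = scaling * (lora_B @ lora_A)
# is [8, 16], the same shape as the router weight, and can be applied as a
# second F.linear on the hidden states.
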
def unwrap_experts_lora(experts_module):
    """Walk a PEFT ParamWrapper chain on ``self.experts``.

    When PEFT targets ``experts.gate_up_proj`` and ``experts.down_proj``
    via ``target_parameters``, ``self.experts`` becomes::

        ParamWrapper(down_proj)
          -> base_layer: ParamWrapper(gate_up_proj)
            -> base_layer: Experts (the real module)

    Returns:
        (base_experts, lora_dict)

    ``lora_dict`` maps parameter names to ``(lora_A, lora_B, scaling)``
    tuples, or is empty if no LoRA is active.
    """
    wrappers = {}
    module = experts_module
    while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
        param_name = getattr(module, "parameter_name", None)
        if param_name is not None:
            wrappers[param_name] = module
        module = module.base_layer

    base_experts = module
    lora_dict = {}

    for param_name, wrapper in wrappers.items():
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
        if lora_A is not None:
            lora_dict[param_name] = (lora_A, lora_B, scaling)

    return base_experts, lora_dict


# =============================================================================
# LoRA weight materialization autograd function
# =============================================================================


class MoELoRAMaterialize(torch.autograd.Function):
    """Materialize effective weight W_eff = W + scaling * (B @ A) per expert.

    Inserts into the autograd graph between PEFT's LoRA parameters and
    SonicMoE's CUTLASS kernels. The CUTLASS backward computes dW_eff,
    which this function decomposes into dA and dB via the chain rule.

    Weight layouts (PEFT rank-major):
        base_weight: [E, dim1, dim2] (frozen expert parameter)
        lora_A: [r*E, dim2]  (rows [e*r:(e+1)*r] = A_e)
        lora_B: [dim1, r*E]  (cols [:, e*r:(e+1)*r] = B_e)

    Per-expert: delta_e = B_e @ A_e = [dim1, r] @ [r, dim2] = [dim1, dim2]
    """

    @staticmethod
    def forward(
        ctx,
        base_weight: torch.Tensor,
        lora_A: torch.Tensor,
        lora_B: torch.Tensor,
        scaling: float,
    ) -> torch.Tensor:
        E, dim1, dim2 = base_weight.shape
        r = lora_A.shape[0] // E
        assert lora_A.shape[0] == r * E, (
            f"lora_A rows ({lora_A.shape[0]}) must be divisible by num_experts ({E})"
        )

        # Reshape PEFT rank-major to per-expert batched format
        A_3d = lora_A.reshape(E, r, dim2)
        B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous()  # [E, dim1, r]

        # Batched matmul: [E, dim1, r] @ [E, r, dim2] = [E, dim1, dim2]
        delta = torch.bmm(B_3d, A_3d)

        W_eff = base_weight + scaling * delta

        ctx.save_for_backward(lora_A, lora_B)
        ctx.scaling = scaling
        ctx.E = E
        ctx.r = r

        return W_eff

    @staticmethod
    def backward(ctx, grad_W_eff: torch.Tensor):
        lora_A, lora_B = ctx.saved_tensors
        scaling = ctx.scaling
        E = ctx.E
        r = ctx.r

        _, dim1, dim2 = grad_W_eff.shape

        # Reshape to per-expert (same as forward)
        A_3d = lora_A.reshape(E, r, dim2)
        B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous()  # [E, dim1, r]

        # dA_e = scaling * B_e^T @ dW_e
        #   [E, r, dim1] @ [E, dim1, dim2] = [E, r, dim2]
        d_A_3d = scaling * torch.bmm(B_3d.transpose(1, 2), grad_W_eff)

        # dB_e = scaling * dW_e @ A_e^T
        #   [E, dim1, dim2] @ [E, dim2, r] = [E, dim1, r]
        d_B_3d = scaling * torch.bmm(grad_W_eff, A_3d.transpose(1, 2))

        # Reshape back to PEFT rank-major layout
        d_lora_A = d_A_3d.reshape(E * r, dim2)
        d_lora_B = d_B_3d.permute(1, 2, 0).contiguous().reshape(dim1, E * r)

        return None, d_lora_A, d_lora_B, None


def materialize_expert_lora(
    base_weight: torch.Tensor,
    lora_params: Optional[tuple],
) -> torch.Tensor:
    """Materialize effective expert weight with optional LoRA delta.

    Args:
        base_weight: [E, dim1, dim2] frozen expert parameter
        lora_params: (lora_A, lora_B, scaling) or None

    Returns:
        W_eff if lora_params is not None, else base_weight unchanged.
    """
    if lora_params is None:
        return base_weight
    lora_A, lora_B, scaling = lora_params
    return MoELoRAMaterialize.apply(base_weight, lora_A, lora_B, scaling)
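# Minimal gradient check sketch (not part of this diff; toy sizes assumed):
# compares the custom backward against autograd on the explicit construction.
#
#   E, d1, d2, r, s = 2, 4, 3, 2, 0.5
#   W = torch.randn(E, d1, d2)
#   A = torch.randn(r * E, d2, requires_grad=True)
#   B = torch.randn(d1, r * E, requires_grad=True)
#   MoELoRAMaterialize.apply(W, A, B, s).sum().backward()
#
#   A2 = A.detach().clone().requires_grad_(True)
#   B2 = B.detach().clone().requires_grad_(True)
#   ref = W + s * torch.bmm(
#       B2.reshape(d1, r, E).permute(2, 0, 1), A2.reshape(E, r, d2)
#   )
#   ref.sum().backward()
#   assert torch.allclose(A.grad, A2.grad) and torch.allclose(B.grad, B2.grad)
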
@@ -1,576 +0,0 @@
"""
Routing functions for SonicMoE integration.

Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection)
- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing)
- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing)
- gemma4_text: RMSNorm -> scale -> proj -> softmax -> topk -> renorm -> per_expert_scale (gemma4_routing)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED]

Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
"""

import torch
import torch.nn.functional as F

from .lora import unwrap_gate_lora


def get_model_moe_config(model_type: str):
    """Returns (routing_fn, activation, router_attr) for a given model type.

    Args:
        model_type: HuggingFace model type string.

    Returns:
        routing_fn: Callable or None. None signals the fused
            moe_TC_softmax_topk_layer path (topk -> softmax models).
        activation: SonicMoE ActivationType enum value.
        router_attr: Name of the router module attribute on the MoE block
            (e.g. "gate" or "router").

    The activation type cannot be derived from config.hidden_act because
    e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU
    (act_fn(gate) * up pattern). So we specify it per model type.
    """
    from sonicmoe.enums import ActivationType

    if model_type in (
        "qwen2_moe",
        "qwen3_moe",
        "qwen3_5_moe",
        "qwen3_5_moe_text",
        "qwen3_next",
        "qwen3_vl_moe",
        "qwen3_omni_moe",
        "olmoe",
        "mixtral",
        "minimax",
    ):
        return softmax_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in ("mistral4",):
        return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in (
        "glm_moe_dsa",
        "deepseek_v3",
        "glm4_moe",
        "glm4_moe_lite",
        "glm4v_moe",
        "minimax_m2",
    ):
        return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in ("ernie4_5_moe",):
        return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in ("hunyuan_v1_moe",):
        return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate"
    elif model_type in ("gemma4_text",):
        return gemma4_routing, ActivationType.GEGLU, "router"
    # Fused topk -> softmax path (routing_fn=None):
    # elif model_type in ("gpt_oss",):
    #     # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
    #     # ignores (it only takes router_w, not bias). Also has transposed
    #     # weight layout [E, H, 2*I] and custom GLU activation.
    #     return None, ActivationType.SWIGLU, "router"
    else:
        raise ValueError(f"SonicMoE: unsupported model type '{model_type}'")

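# Dispatch sketch for the mapping above (illustrative):
#
#   routing_fn, activation, router_attr = get_model_moe_config("qwen3_moe")
#   # -> (softmax_topk_routing, ActivationType.SWIGLU, "gate")
#   # The caller reads the router via getattr(moe_block, router_attr) and
#   # feeds routing_fn's flat outputs into moe_general_routing_inputs.
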
def softmax_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = base_gate.top_k

    # Compute router logits and softmax over all experts.
    # Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP.
    # Cast to float32 to match LoRA delta dtype (PEFT computes in fp32).
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts per token
    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K] each

    # Renormalize if configured (default True for models without the attribute,
    # e.g. Mixtral/MiniMax which always normalize)
    if getattr(base_gate, "norm_topk_prob", True):
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # no-op: matches transformers which casts to softmax output dtype (float32).
    # top_values = top_values.to(router_probs.dtype)

    # Flatten for moe_general_routing_inputs.
    # Token indices are naturally sorted ascending from the [T, K] layout:
    #   [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
    # Expert sorting is handled internally by general_routing_router_metadata.
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

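# Worked toy example of the flat routing format above (values illustrative):
# with T=2 tokens and K=2, a top-k selection of experts (1, 3) for token 0
# and (0, 1) for token 1 flattens to
#   flat_token_idx  = [0, 0, 1, 1]   (sorted ascending, as SonicMoE requires)
#   flat_expert_idx = [1, 3, 0, 1]
#   flat_scores     = [0.6, 0.4, 0.7, 0.3]   (renormalized per token)
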
def softmax_group_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    scores_for_choice = router_probs

    # Group selection: pick top groups, mask the rest
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, E // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
    topk_weights = router_probs.gather(1, topk_indices)

    # Renormalization + scaling
    norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
    if norm_topk_prob:
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

def sigmoid_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sigmoid-based routing: sigmoid -> optional group selection -> topk.

    Supports two variants:
    - **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,
      bias on gate, group-based masking before topk.
    - **No group selection** (minimax_m2): n_group == 1 (or absent),
      bias on moe_block, straight topk from all experts.

    Final routing weights come from the original sigmoid scores (not
    bias-corrected), with optional renormalization and scaling.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.* and
            optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,
            .routed_scaling_factor, .n_routed_experts)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    # Compute router logits and sigmoid probabilities
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = router_logits.sigmoid()  # [T, E]

    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
    e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        raise AttributeError(
            f"sigmoid_topk_routing requires e_score_correction_bias on "
            f"gate ({type(base_gate)}) or moe_block ({type(moe_block)}), but neither has it"
        )
    scores_for_choice = router_probs + e_score_correction_bias

    # Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, E // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    # Final topk from (possibly masked) scores
    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]

    # Gather weights from original sigmoid scores (not bias-corrected)
    topk_weights = router_probs.gather(1, topk_indices)

    # Optional renormalization + scaling
    norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
    if norm_topk_prob:
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs.
    # Token indices are naturally sorted ascending from the [T, K] layout.
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

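# Toy illustration of biased selection vs. unbiased weights (values assumed):
# with sigmoid scores [0.30, 0.29] and correction bias [0.00, 0.05], expert 1
# wins top-1 selection (0.29 + 0.05 > 0.30 + 0.00), but its routing weight
# stays the unbiased 0.29 before renormalization and scaling.
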
def softmax_bias_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Ernie 4.5 MoE routing: softmax → bias correction → topk → gather → renorm.

    Differs from standard softmax_topk_routing in three ways:
    1. A learned e_score_correction_bias is added to softmax probs *before* topk
       (selection uses biased scores, but final weights use original probs).
    2. The bias is applied via gate.moe_statics module (not a raw tensor).
    3. Renormalization uses clamp(min=norm_min) instead of sum+epsilon.

    Reference: Ernie4_5_MoeTopKRouter.forward in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = base_gate.top_k

    # Compute router logits and softmax (force float32 for numerical stability)
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Bias-corrected scores for expert selection (via moe_statics module)
    scores_for_choice = base_gate.moe_statics(router_probs)  # [T, E]

    # Select top-k experts using biased scores
    _, selected_experts = torch.topk(scores_for_choice, K, dim=-1)  # [T, K]

    # Gather weights from *original* (unbiased) softmax probs
    top_values = torch.gather(router_probs, dim=-1, index=selected_experts)  # [T, K]

    # Renormalize with clamp(min=norm_min) instead of sum+epsilon
    norm_min = getattr(base_gate, "norm_min", 1e-20)
    top_values = top_values / torch.clamp(
        top_values.sum(dim=-1, keepdim=True), min=norm_min
    )

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = selected_experts.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

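# Renormalization contrast (illustrative): the sum+epsilon form divides by
# (sum + 1e-20), while Ernie's clamp form divides by max(sum, norm_min); the
# two differ only when the top-k probability mass falls below norm_min.
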
def softmax_group_limited_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """DeepSeek V2 routing: softmax → group_limited_greedy/greedy → topk → scale.

    Differs from softmax_group_topk_routing (Mistral4) in several ways:
    1. Uses ``num_group`` attribute (not ``n_group``).
    2. Group score = max per group (not sum of top-2).
    3. Supports ``greedy`` method (plain topk without groups).
    4. No renormalization — just ``topk_weight * routed_scaling_factor``.
    5. Gate is ``nn.Linear`` (access weight via ``gate.weight``).

    Reference: DeepseekV2Moe.route_tokens_to_experts in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate, .num_group,
            .topk_group, .top_k, .topk_method, .routed_scaling_factor)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = moe_block.top_k
    num_group = getattr(moe_block, "num_group", 1)
    num_experts = gate_weight.shape[0]
    topk_method = getattr(moe_block, "topk_method", "greedy")

    # Compute logits in float32 and softmax
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    if topk_method == "greedy" or num_group == 1:
        topk_weights, topk_indices = torch.topk(router_probs, k=K, dim=-1, sorted=False)
    elif topk_method == "group_limited_greedy":
        # Guard: selected groups must contain enough experts for topk
        group_size = num_experts // num_group
        if moe_block.topk_group * group_size < K:
            raise ValueError(
                f"DeepSeek V2: topk_group ({moe_block.topk_group}) * group_size "
                f"({group_size}) = {moe_block.topk_group * group_size} < top_k ({K}). "
                f"Not enough experts in selected groups for topk selection."
            )
        # Group selection: pick top groups by max score per group
        group_scores = (
            router_probs.view(T, num_group, num_experts // num_group).max(dim=-1).values
        )  # [T, num_group]
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(T, num_group, num_experts // num_group)
            .reshape(T, -1)
        )
        tmp_scores = router_probs.masked_fill(~score_mask.bool(), 0.0)
        topk_weights, topk_indices = torch.topk(tmp_scores, k=K, dim=-1, sorted=False)
    else:
        raise ValueError(
            f"DeepSeek V2: unsupported topk_method '{topk_method}'. "
            f"Expected 'greedy' or 'group_limited_greedy'."
        )

    # Scale only — no renormalization (weights won't sum to 1.0 per token).
    # This matches the reference DeepseekV2Moe.route_tokens_to_experts behavior.
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

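# Guard arithmetic example (numbers assumed): with 64 experts in num_group=8
# groups, group_size = 8; selecting topk_group=1 group exposes 8 experts, so
# any top_k <= 8 passes the check, while top_k = 12 raises the ValueError above.
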
def softmax_topk_wg_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """HunYuan V1 MoE routing: softmax → topk → renorm (gate weight via gate.wg).

    Differs from standard softmax_topk_routing in:
    1. Gate weight lives at ``gate.wg.weight`` (not ``gate.weight``).
    2. ``top_k`` is on ``moe_block`` (not ``gate``).
    3. Always renormalizes (no ``norm_topk_prob`` flag).

    Reference: HunYuanMoEV1Moe.route_tokens_to_experts and
    HunYuanMoEV1Gate.forward in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.wg, moe_block.top_k)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = moe_block.top_k

    # Gate computes logits via gate.wg (nn.Linear, float32).
    # Unwrap at gate.wg level since PEFT targets the wg Linear, not the gate container.
    base_wg, wg_weight, wg_lora_delta = unwrap_gate_lora(gate.wg)
    router_logits = F.linear(hidden_states.float(), wg_weight.float())  # [T, E]
    if wg_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), wg_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts
    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K] each

    # Always renormalize (HunYuan V1 has no norm_topk_prob flag)
    top_values = top_values / (top_values.sum(dim=-1, keepdim=True) + 1e-20)

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

def gemma4_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Gemma4-style routing: RMSNorm → scale → proj → softmax → topk → renorm → per_expert_scale.

    Gemma4's router (``Gemma4TextRouter``) has a unique structure:
    1. RMSNorm (without learnable scale) on hidden states
    2. Multiply by ``scale * hidden_size**-0.5``
    3. Linear projection to expert scores
    4. Softmax → topk
    5. Normalize top-k weights to sum to 1
    6. Multiply by per-expert learned scales

    The router lives at ``moe_block.router`` (not ``moe_block.gate``).
    LoRA on the router targets ``router.proj`` (nn.Linear).

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.router)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    router = moe_block.router

    # Unwrap PEFT LoRA on router.proj (the nn.Linear)
    _, proj_weight, proj_lora_delta = unwrap_gate_lora(router.proj)

    T, _ = hidden_states.shape
    K = router.top_k if hasattr(router, "top_k") else router.config.top_k_experts

    # Reproduce Gemma4TextRouter.forward:
    # 1. RMSNorm (no scale) + scale param * hidden_size**-0.5
    normed = router.norm(hidden_states)
    scaled = normed * router.scale * router.scalar_root_size

    # 2. Project to expert scores
    router_logits = F.linear(scaled.float(), proj_weight.float())  # [T, E]
    if proj_lora_delta is not None:
        router_logits = router_logits + F.linear(
            scaled.float(), proj_lora_delta.float()
        )

    # 3. Softmax → topk
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]
    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K]

    # 4. Normalize top-k weights
    top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # 5. Per-expert scale
    top_values = top_values * router.per_expert_scale[top_indices]

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits

@@ -61,31 +61,20 @@ class KernelsPlugin(BasePlugin):
|
||||
return "axolotl.integrations.kernels.KernelsArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from axolotl.integrations.kernels.constants import (
|
||||
SPARSE_MOE_BLOCK,
|
||||
is_experts_only_model,
|
||||
)
|
||||
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
|
||||
|
||||
# Prefer text backbone type for VLMs, but fall back to base type
|
||||
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
|
||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||
if (
|
||||
moe_model_type not in SPARSE_MOE_BLOCK
|
||||
and not is_experts_only_model(moe_model_type)
|
||||
and cfg.model_config_type in SPARSE_MOE_BLOCK
|
||||
):
|
||||
moe_model_type = cfg.model_config_type
|
||||
|
||||
if cfg.use_scattermoe:
|
||||
self._register_kernels()
|
||||
if is_experts_only_model(moe_model_type):
|
||||
# Models like Gemma4 where MoE is embedded in the decoder layer
|
||||
# — register ScatterMoE in the ExpertsInterface so that
|
||||
# @use_experts_implementation dispatches to it.
|
||||
self._register_experts_interface()
|
||||
cfg.experts_implementation = "scattermoe"
|
||||
else:
|
||||
self._kernelize_model(moe_model_type)
|
||||
self._kernelize_model(moe_model_type)
|
||||
elif cfg.use_sonicmoe:
|
||||
if not importlib.util.find_spec("sonicmoe"):
|
||||
raise RuntimeError(
|
||||
@@ -95,24 +84,13 @@ class KernelsPlugin(BasePlugin):
|
||||
|
||||
_check_sonicmoe_gpu_compat()
|
||||
|
||||
if is_experts_only_model(moe_model_type):
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.gemma4_experts import (
|
||||
patch_gemma4_sonicmoe,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
|
||||
|
||||
LOG.info(
|
||||
f"Applying SonicMoE experts-level patch for model type: {moe_model_type}"
|
||||
)
|
||||
patch_gemma4_sonicmoe()
|
||||
else:
|
||||
from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe
|
||||
|
||||
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
|
||||
patch_sonicmoe(
|
||||
moe_model_type,
|
||||
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
||||
base_model_type=cfg.model_config_type,
|
||||
)
|
||||
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
|
||||
patch_sonicmoe(
|
||||
moe_model_type,
|
||||
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
||||
)
|
||||
|
||||
def _register_kernels(self):
|
||||
from kernels import (
|
||||
@@ -160,16 +138,3 @@ class KernelsPlugin(BasePlugin):
|
||||
replace_kernel_forward_from_hub(
|
||||
model_moe_cls, "HFScatterMoEParallelExperts"
|
||||
)
|
||||
|
||||
def _register_experts_interface(self):
|
||||
"""Register ScatterMoE in the transformers ExpertsInterface.
|
||||
|
||||
This allows @use_experts_implementation-decorated Experts classes
|
||||
to dispatch to ScatterMoE when config._experts_implementation == "scattermoe".
|
||||
"""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.gemma4_experts import (
|
||||
register_scattermoe_experts,
|
||||
)
|
||||
|
||||
register_scattermoe_experts()
|
||||
LOG.info("Registered 'scattermoe' in transformers ExpertsInterface")
|
||||
|
@@ -28,72 +28,20 @@ import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger

from .lora import (
    has_lora,
    materialize_expert_lora,
    unwrap_experts_lora,
    unwrap_gate_lora,
)

LOG = get_logger(__name__)


def _get_expert_weights(experts_module):
    """Extract expert weights, applying LoRA materialization if PEFT is active.
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
    """Main entry point: patch SparseMoeBlock for SonicMoE support.

    Returns:
        (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
    Args:
        model_type: The HuggingFace model type (e.g. "qwen3_moe").
        torch_compile: If True, wrap routing functions with torch.compile
            for kernel fusion (fuses softmax+topk+renorm into fewer launches).
    """
    if has_lora(experts_module):
        base_experts, lora_dict = unwrap_experts_lora(experts_module)
        gate_up = materialize_expert_lora(
            base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
        )
        down = materialize_expert_lora(
            base_experts.down_proj, lora_dict.get("down_proj")
        )
    else:
        gate_up = experts_module.gate_up_proj
        down = experts_module.down_proj

    # Permute to SonicMoE layout:
    # gate_up: [E, 2*I, H] -> [2*I, H, E]
    # down: [E, H, I] -> [H, I, E]
    return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)


def _fix_qwen3_5_moe_text_weight_renaming(model_type: str, base_model_type: str):
    """Strip qwen3_5_moe_text WeightRenaming in VLM mode to preserve custom loaders."""
    if model_type != "qwen3_5_moe_text" or base_model_type == "qwen3_5_moe_text":
        return

    try:
        from transformers.conversion_mapping import (
            get_checkpoint_conversion_mapping,
            register_checkpoint_conversion_mapping,
        )
        from transformers.core_model_loading import WeightRenaming
    except ImportError:
        return

    text_mapping = get_checkpoint_conversion_mapping(model_type)
    if text_mapping and isinstance(text_mapping[0], WeightRenaming):
        text_mapping.pop(0)
        register_checkpoint_conversion_mapping(model_type, text_mapping, overwrite=True)
        LOG.info("Stripped qwen3_5_moe_text WeightRenaming for VLM mode")


def patch_sonicmoe(
    model_type: str,
    torch_compile: bool = False,
    base_model_type: str | None = None,
):
    """Patch SparseMoeBlock for SonicMoE support."""
    from .routing import get_model_moe_config
    from .weight_converter import register_sonicmoe_weight_converter

    _fix_qwen3_5_moe_text_weight_renaming(model_type, base_model_type or model_type)

    routing_fn, activation, router_attr = get_model_moe_config(model_type)

    if torch_compile and routing_fn is not None:
@@ -165,10 +113,11 @@ def _make_general_forward(moe_cls, routing_fn, activation):
            hidden_states_flat, self
        )

        # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
        gate_up_weight, down_weight = _get_expert_weights(self.experts)
        gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
        down_weight = down_weight.to(hidden_states_flat.dtype)
        # Permute weights to SonicMoE layout:
        # gate_up: [E, 2*I, H] -> [2*I, H, E]
        # down: [E, H, I] -> [H, I, E]
        gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
        down_weight = self.experts.down_proj.permute(1, 2, 0)
        E = gate_up_weight.shape[-1]

        output, _ = moe_general_routing_inputs(
@@ -212,30 +161,22 @@ def _make_fused_forward(moe_cls, activation, router_attr):
        # Shared expert (computed early, matching original model ordering)
        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)

        # Unwrap router for attribute access + optional LoRA delta
        raw_router = getattr(self, router_attr)
        base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router)
        if router_lora_delta is not None:
            # Materialize local tensor to avoid DTensor + Tensor add under FSDP
            if hasattr(router_weight, "to_local"):
                router_weight = router_weight.to_local()
            effective_router_weight = router_weight + router_lora_delta
        else:
            effective_router_weight = router_weight
        router = getattr(self, router_attr)

        # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
        gate_up_weight, down_weight = _get_expert_weights(self.experts)
        gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
        down_weight = down_weight.to(hidden_states_flat.dtype)
        # Permute weights to SonicMoE layout:
        # gate_up: [E, 2*I, H] -> [2*I, H, E]
        # down: [E, H, I] -> [H, I, E]
        gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
        down_weight = self.experts.down_proj.permute(1, 2, 0)

        output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
            hidden_states_flat,
            effective_router_weight,
            router.weight,
            gate_up_weight,
            None,  # b1 (no gate/up bias)
            down_weight,
            None,  # b2 (no down bias)
            base_router.top_k,
            router.top_k,
            torch.cuda.current_stream().cuda_stream,
            activation,
            False,  # is_inference_mode
278
src/axolotl/integrations/kernels/sonicmoe/routing.py
Normal file
@@ -0,0 +1,278 @@
"""
Routing functions for SonicMoE integration.

Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)

Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
"""

import torch
import torch.nn.functional as F


def get_model_moe_config(model_type: str):
    """Returns (routing_fn, activation, router_attr) for a given model type.

    Args:
        model_type: HuggingFace model type string.

    Returns:
        routing_fn: Callable or None. None signals the fused
            moe_TC_softmax_topk_layer path (topk -> softmax models).
        activation: SonicMoE ActivationType enum value.
        router_attr: Name of the router module attribute on the MoE block
            (e.g. "gate" or "router").

    The activation type cannot be derived from config.hidden_act because
    e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU
    (act_fn(gate) * up pattern). So we specify it per model type.
    """
    from sonicmoe.enums import ActivationType

    if model_type in (
        "qwen2_moe",
        "qwen3_moe",
        "qwen3_5_moe",
        "qwen3_next",
        "qwen3_vl_moe",
        "qwen3_omni_moe",
        "olmoe",
        "mixtral",
        "minimax",
    ):
        return softmax_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in ("mistral4",):
        return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in (
        "glm_moe_dsa",
        "deepseek_v3",
        "glm4_moe",
        "glm4_moe_lite",
        "glm4v_moe",
        "minimax_m2",
    ):
        return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
    # elif model_type in ("ernie4_5_moe",):
    #     # Softmax→topk with e_score_correction_bias applied between softmax and topk.
    #     return ..., ActivationType.SWIGLU, "gate"
    # elif model_type in ("deepseek_v2",):
    #     # Softmax→topk with group_limited_greedy. Different attr names: num_group
    #     # (not n_group), gate is nn.Linear (not a router class).
    #     return ..., ActivationType.SWIGLU, "gate"
    # elif model_type in ("hunyuan_v1_moe",):
    #     # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
    #     # top_k on block not gate, creates scatter routing matrix.
    #     return ..., ActivationType.SWIGLU, "gate"
    # Fused topk -> softmax path (routing_fn=None):
    # elif model_type in ("gpt_oss",):
    #     # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
    #     # ignores (it only takes router_w, not bias). Also has transposed
    #     # weight layout [E, H, 2*I] and custom GLU activation.
    #     return None, ActivationType.SWIGLU, "router"
    else:
        raise ValueError(f"SonicMoE: unsupported model type '{model_type}'")


def softmax_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = gate.top_k

    # Compute router logits and softmax over all experts
    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts per token
    top_values, top_indices = torch.topk(router_probs, K, dim=-1)  # [T, K] each

    # Renormalize if configured (default True for models without the attribute,
    # e.g. Mixtral/MiniMax which always normalize)
    if getattr(gate, "norm_topk_prob", True):
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # no-op: matches transformers which casts to softmax output dtype (float32).
    # top_values = top_values.to(router_probs.dtype)

    # Flatten for moe_general_routing_inputs.
    # Token indices are naturally sorted ascending from the [T, K] layout:
    # [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
    # Expert sorting is handled internally by general_routing_router_metadata.
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
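
# Worked example of the flat layout (illustrative comment, hypothetical values,
# T=2 tokens and K=2 experts): flat_token_idx flattens to [0, 0, 1, 1], so each
# token's K entries are contiguous and ascending as SonicMoE requires, while
# flat_expert_idx might be e.g. [3, 7, 1, 4] with matching flat_scores.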


def softmax_group_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    scores_for_choice = router_probs

    # Group selection: pick top groups, mask the rest
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, E // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
    topk_weights = router_probs.gather(1, topk_indices)

    # Renormalization + scaling
    norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
    if norm_topk_prob:
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits


def sigmoid_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sigmoid-based routing: sigmoid -> optional group selection -> topk.

    Supports two variants:
    - **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,
      bias on gate, group-based masking before topk.
    - **No group selection** (minimax_m2): n_group == 1 (or absent),
      bias on moe_block, straight topk from all experts.

    Final routing weights come from the original sigmoid scores (not
    bias-corrected), with optional renormalization and scaling.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.* and
            optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,
            .routed_scaling_factor, .n_routed_experts)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    # Compute router logits and sigmoid probabilities
    router_logits = F.linear(hidden_states.float(), gate.weight.float())  # [T, E]
    router_probs = router_logits.sigmoid()  # [T, E]

    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
    e_score_correction_bias = getattr(gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        raise AttributeError(
            f"sigmoid_topk_routing requires e_score_correction_bias on "
            f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
        )
    scores_for_choice = router_probs + e_score_correction_bias

    # Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, E // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    # Final topk from (possibly masked) scores
    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]

    # Gather weights from original sigmoid scores (not bias-corrected)
    topk_weights = router_probs.gather(1, topk_indices)

    # Optional renormalization + scaling
    norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
    if norm_topk_prob:
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs.
    # Token indices are naturally sorted ascending from the [T, K] layout.
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
@@ -129,41 +129,15 @@ class InterleavedToConcatenated(ConversionOps):
        return ConcatenatedToInterleaved(self.dim)


def _make_same_key_interleave_converter():
    """Create a WeightConverter that interleaves an already-fused gate_up_proj."""
    from transformers.core_model_loading import WeightConverter

    return WeightConverter(
        source_patterns="mlp.experts.gate_up_proj",
        target_patterns="mlp.experts.gate_up_proj",
        operations=[ConcatenatedToInterleaved(dim=1)],
    )


def _has_same_key_interleave(mapping) -> bool:
    """Check whether the mapping already has a same-key gate_up_proj interleave converter."""
    for conv in mapping:
        if (
            hasattr(conv, "source_patterns")
            and conv.source_patterns == ["mlp.experts.gate_up_proj"]
            and conv.target_patterns == ["mlp.experts.gate_up_proj"]
            and hasattr(conv, "operations")
            and any(isinstance(op, ConcatenatedToInterleaved) for op in conv.operations)
        ):
            return True
    return False


def register_sonicmoe_weight_converter(model_type: str):
    """Register weight converters to interleave gate_up_proj for SonicMoE.
    """Override the conversion mapping to add interleave step for gate_up_proj.

    Handles two checkpoint formats:
    1. Separate per-expert weights (e.g. qwen3_moe): appends interleave to the
       existing merge chain (MergeModulelist -> Concatenate -> Interleave).
    2. Already-fused gate_up_proj (e.g. qwen3_5_moe_text): adds a same-key
       converter (gate_up_proj -> gate_up_proj with Interleave).
    Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
    converter chain. For example, qwen3_moe's chain becomes:
    MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)

    The loader matches whichever source pattern exists in the checkpoint.
    The reverse is auto-generated for saving:
    InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
    """
    from transformers.conversion_mapping import (
        get_checkpoint_conversion_mapping,
@@ -171,32 +145,37 @@ def register_sonicmoe_weight_converter(model_type: str):
    )

    existing = get_checkpoint_conversion_mapping(model_type)

    if existing is None:
        # No mapping at all — create one with just the same-key converter
        mapping = [_make_same_key_interleave_converter()]
        register_checkpoint_conversion_mapping(model_type, mapping)
        LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
        LOG.warning(
            f"No conversion mapping found for model type '{model_type}'. "
            "SonicMoE weight interleaving will not be applied during checkpoint loading."
        )
        return

    # Append interleave to any existing many-to-one merge chain
    # Find the gate_up_proj converter and append ConcatenatedToInterleaved
    patched = False
    for converter in existing:
        if hasattr(converter, "operations") and any(
            "gate_up_proj" in pat for pat in converter.target_patterns
        ):
            has_separate_sources = any(
                "gate_proj" in pat or "up_proj" in pat
                for pat in converter.source_patterns
            )
            if has_separate_sources and not any(
            # Guard against double registration (e.g. plugin reloaded)
            if any(
                isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
            ):
                converter.operations.append(ConcatenatedToInterleaved(dim=1))
                LOG.info(
                    f"SonicMoE weight converter already registered for '{model_type}'"
                )
                return
            converter.operations.append(ConcatenatedToInterleaved(dim=1))
            patched = True
            break

    # Also add a same-key converter for already-fused checkpoints
    if not _has_same_key_interleave(existing):
        existing.append(_make_same_key_interleave_converter())
    if not patched:
        LOG.warning(
            f"Could not find gate_up_proj converter for model type '{model_type}'. "
            "SonicMoE weight interleaving will not be applied during checkpoint loading."
        )
        return

    register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
    LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
@@ -1,412 +0,0 @@
# NeMo Gym Integration for Axolotl

Train LLMs with reinforcement learning using [NVIDIA NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) environments as reward sources. NeMo Gym provides 50+ verified RL environments spanning math, coding, tool-use, reasoning, and safety — each with deterministic reward signals.

## Validated Training Paths

| Path | Speed | Multi-turn | Architecture |
|------|-------|------------|--------------|
| **Async GRPO + Data Producer** | Fastest (3x) | Yes | `NemoGymDataProducer` replaces vLLM generation |
| Standard GRPO + Data Producer | Baseline | Yes | Same producer, no async prefetch |
| Standard GRPO + /verify | Simplest | No | Reward function calls /verify directly |
| FSDP2 + /verify (2 GPU) | Distributed | No | `fsdp_version: 2` |

Multi-turn uses `nemo_gym_multi_turn: true`, which auto-enables the async trainer's
data producer protocol. The plugin's `NemoGymDataProducer` calls NeMo Gym agent `/run`
endpoints and returns `RolloutDataset` with proper IS correction, env_mask, and rewards.

All paths tested end-to-end with Qwen3-0.6B + LoRA, logged to wandb project `nemo-gym-rl`.

## Quick Start

### Prerequisites

- [uv](https://github.com/astral-sh/uv) package manager (for NeMo Gym's venv)
- Two GPUs recommended (one for vLLM server, one for training)

### 1. Set Up NeMo Gym

```bash
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync

# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation

# Pre-build resource server venvs
for dir in resources_servers/reasoning_gym resources_servers/example_single_tool_call responses_api_models/vllm_model responses_api_agents/simple_agent; do
  uv venv --seed --allow-existing --python 3.12 $dir/.venv
  CFLAGS="" uv pip install --python $dir/.venv/bin/python pycosat --no-build-isolation 2>/dev/null
  uv pip install --python $dir/.venv/bin/python -e . "ray[default]==2.52.1"
done

# Install extra deps for reasoning_gym
uv pip install --python resources_servers/reasoning_gym/.venv/bin/python \
  reasoning-gym matplotlib pillow cycler contourpy kiwisolver
```

### 2. Multi-Turn with Async GRPO (Recommended — Fastest Path)

This is the fully validated, highest-performance path. NeMo Gym's agent server handles
multi-turn tool execution while axolotl's async GRPO prefetches data in background threads.

**Step 1: Create the NeMo Gym agent config**

Create `~/Gym/configs/axolotl_tool_calling.yaml`:
```yaml
# Resource server (tools + verify)
example_single_tool_call:
  resources_servers:
    example_single_tool_call:
      entrypoint: app.py
      domain: agent
      verified: false

# Model server proxy (forwards to your vLLM)
policy_model:
  responses_api_models:
    vllm_model:
      entrypoint: app.py
      base_url: http://localhost:8000/v1
      api_key: dummy_key
      model: Qwen/Qwen3-0.6B  # Must match your training model
      return_token_id_information: true
      uses_reasoning_parser: false

# Agent server (orchestrates multi-turn via /run)
example_single_tool_call_simple_agent:
  responses_api_agents:
    simple_agent:
      entrypoint: app.py
      resources_server:
        type: resources_servers
        name: example_single_tool_call
      model_server:
        type: responses_api_models
        name: policy_model
      datasets:
        - name: weather
          type: example
          jsonl_fpath: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
```

**Step 2: Start three services**

```bash
# Terminal 1: vLLM OpenAI server on GPU 0
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen3-0.6B --max-model-len 2048 --gpu-memory-utilization 0.85

# Terminal 2: NeMo Gym (resource server + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
  "+config_paths=[configs/axolotl_tool_calling.yaml]" "+skip_venv_if_present=true"

# Terminal 3: Training on GPU 1
cd experiments && CUDA_VISIBLE_DEVICES=1 CUDA_HOME=$HOME/env-claude-cu130/cuda_shim \
  axolotl train nemo_gym_async_agent.yaml
```

**Step 3: Training config** (`nemo_gym_async_agent.yaml`):
```yaml
base_model: Qwen/Qwen3-0.6B
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
sequence_len: 2048

rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: true
  vllm_mode: server
  vllm_server_host: localhost
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 5
  # Async GRPO — 3x faster than standard
  use_data_producer: true
  async_prefetch: true
  num_generations: 4
  max_completion_length: 512
  temperature: 0.8
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_env

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
  - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    server_name: example_single_tool_call

datasets:
  - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

vllm:
  gpu_memory_utilization: 0.85
  max_model_len: 2048
  tensor_parallel_size: 1

learning_rate: 5e-6
micro_batch_size: 1
gradient_accumulation_steps: 4
max_steps: 30
gradient_checkpointing: true
bf16: true
output_dir: ./outputs/nemo_gym_async

use_wandb: true
wandb_project: nemo-gym-rl
```

### 3. Single-Turn Training (Simplest — No Agent Server Needed)

For environments that only need single-turn verify (math, coding challenges), you don't need
an agent server. The plugin's reward function calls `/verify` directly.

```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct
rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: true
  vllm_mode: colocate
  vllm_enable_sleep_mode: false
  num_generations: 8
  max_completion_length: 128
  temperature: 0.9
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    server_name: reasoning_gym

datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

vllm:
  gpu_memory_utilization: 0.3
  max_model_len: 512
  tensor_parallel_size: 1

learning_rate: 1e-5
micro_batch_size: 4
gradient_accumulation_steps: 2
max_steps: 50
output_dir: ./outputs/nemo_gym_arithmetic
```

This path only needs `ng_run` with resource servers (no agent config):
```bash
cd ~/Gym && ng_run "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" "+skip_venv_if_present=true"
```

## How It Works

### Single-Turn
```text
axolotl train → GRPO Trainer generates completions
             → NeMo Gym plugin reward_fn calls POST /verify on resource server
             → reward flows back to GRPO for advantage computation
```
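
For orientation, here is a minimal sketch of that single-turn hookup in Python. This is illustrative only: the endpoint URL and payload field names are assumptions for the example, not the plugin's actual code.

```python
import requests

def reward_verify_sketch(completions, **kwargs):
    """Score each completion by POSTing it to a NeMo Gym /verify endpoint."""
    rewards = []
    for completion in completions:
        resp = requests.post(
            "http://localhost:11000/example_single_tool_call/verify",  # hypothetical URL
            json={"response": completion},  # hypothetical payload shape
            timeout=30,
        )
        # Fall back to 0.0 if the server returned no reward field
        rewards.append(float(resp.json().get("reward", 0.0)))
    return rewards
```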

### Multi-Turn (Agent /run)
```text
┌─────────────┐      ┌──────────────┐      ┌──────────────────┐
│   axolotl   │      │   NeMo Gym   │────▶│   vLLM OpenAI    │
│   train     │────▶│  Agent /run  │◀────│  Server (GPU 0)  │
│   (GPU 1)   │      │              │      │  /v1/completions │
└─────────────┘      └──────┬───────┘      └──────────────────┘
                            │
                            ▼
                     ┌──────────────┐
                     │   Resource   │
                     │   Server     │
                     │   (tools +   │
                     │    verify)   │
                     └──────────────┘
```

The agent server orchestrates the entire multi-turn loop (a sketch of the client side follows the list):
1. Calls our vLLM server for model generation
2. Parses tool calls from model output
3. Executes tools against resource servers
4. Feeds tool results back to the model
5. Repeats until done, then calls /verify for reward
6. Returns token IDs + logprobs + reward to our rollout_func
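
On the axolotl side, the producer's agent call boils down to a batch of async HTTP POSTs. A minimal sketch, assuming a `/run` path and a `responses_create_params` payload field (illustrative; the shipped code lives in the plugin's `multi_turn.py`):

```python
import asyncio
import aiohttp

async def call_agent_run(session, base_url, item, timeout):
    """POST one dataset row to a NeMo Gym agent /run endpoint (hypothetical payload)."""
    async with session.post(
        f"{base_url}/run",
        json={"responses_create_params": item["responses_create_params"]},
        timeout=aiohttp.ClientTimeout(total=timeout),
    ) as resp:
        return await resp.json()

async def call_agents(items, base_url, timeout=300):
    """Fan out one /run call per rollout and gather all agent responses."""
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(
            *(call_agent_run(session, base_url, item, timeout) for item in items)
        )
```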

### Data Producer Architecture (Multi-Turn)

When `nemo_gym_multi_turn: true`, the plugin automatically forces `use_data_producer: true`,
which selects the `AxolotlAsyncGRPOTrainer`. The plugin then swaps the trainer's data
producer with `NemoGymDataProducer`, which:

1. Gets a prompt batch from the dataset iterator
2. Expands by `num_generations` (one agent call per rollout)
3. Calls NeMo Gym agents via async HTTP (`aiohttp` + `asyncio.gather`)
4. Parses responses into padded tensors (`RolloutDataset`)
5. Returns with `_pending_policy_logps=True` for deferred scoring

The main thread then runs `_compute_deferred_scores()` which:
- Computes **policy logprobs** on the training model (GPU forward pass)
- Computes **IS correction** using the agent's sampling logprobs vs training model logprobs
- Computes advantages with group-level normalization
- All downstream features work: replay buffer, re-roll, streaming, zero-adv skip

With `async_prefetch: true`, the data producer runs in a background thread — giving ~3x
speedup as generation and training overlap. With `async_prefetch: false`, it runs
synchronously on the main thread (still uses the data producer protocol).
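
As a rough sketch of the deferred-scoring math, assuming `[B, L]` logprob tensors and a reward vector whose length is a multiple of `num_generations` (this is not the trainer's actual implementation, just the idea):

```python
import torch

def is_correction_sketch(policy_logps, sampling_logps, completion_mask, clip=2.0):
    """Per-token importance weights between the training policy and the sampler
    (the vLLM instance behind the agent) that generated the rollout."""
    ratio = torch.exp(policy_logps - sampling_logps)  # [B, L]
    return ratio.clamp(max=clip) * completion_mask    # zero out padding/tool tokens

def group_advantages_sketch(rewards, num_generations):
    """Group-level normalization: compare each rollout against its own group."""
    grouped = rewards.view(-1, num_generations)        # [B/G, G]
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + 1e-4)).view(-1)
```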

### Weight Sync (LoRA Mode)

With `vllm_lora_sync: true`, the plugin (or async trainer) replaces NCCL-based weight
sync with filesystem + HTTP (sketched below):

1. `accelerator.get_state_dict()` gathers LoRA weights from all ranks
2. Rank 0 saves adapter to `/tmp/lora_sync_*/vN/`
3. Rank 0 POSTs to `/set_lora_adapter/` on vLLM server
4. vLLM loads adapter natively via Punica kernels
5. Only ~40MB transferred (vs multiple GBs for full model weights)
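
A minimal sketch of the rank-0 side of that sync. The directory layout follows the steps above; the endpoint and payload fields match vLLM's runtime LoRA-loading API as described in the troubleshooting section, but treat the helper as an assumption, not the plugin's code:

```python
import requests

def sync_lora_adapter_sketch(peft_model, save_dir, vllm_url, version):
    """Save the current LoRA adapter to disk, then ask vLLM to load it."""
    adapter_dir = f"{save_dir}/v{version}"
    peft_model.save_pretrained(adapter_dir)  # writes adapter weights + config

    resp = requests.post(
        # Requires the server to run with VLLM_ALLOW_RUNTIME_LORA_UPDATING=1
        f"{vllm_url}/v1/load_lora_adapter",
        json={"lora_name": f"adapter_v{version}", "lora_path": adapter_dir},
        timeout=60,
    )
    resp.raise_for_status()
```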

### Multi-Environment Support

Datasets support per-row environment routing via `agent_ref`:
```jsonl
{"agent_ref": {"name": "reasoning_gym"}, "responses_create_params": {...}}
{"agent_ref": {"name": "instruction_following"}, "responses_create_params": {...}}
```

Or use the simpler per-dataset routing:
```yaml
nemo_gym_datasets:
  - path: reasoning_data.jsonl
    server_name: reasoning_gym
  - path: tool_data.jsonl
    server_name: example_single_tool_call
```
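
The precedence rule (per-row `agent_ref.name` wins over the dataset-level `server_name`) can be sketched with a hypothetical helper:

```python
def resolve_server_sketch(row: dict, dataset_server_name: str | None) -> str:
    """Pick the target server for one dataset row: agent_ref takes precedence."""
    agent_ref = row.get("agent_ref") or {}
    return agent_ref.get("name") or dataset_server_name or "default"
```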

## Configuration Reference

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | `null` | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo |
| `nemo_gym_auto_clone` | bool | `true` | Auto-clone NeMo Gym repo if missing |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_config_paths` | list[str] | — | Server config YAMLs (relative to gym_dir) |
| `nemo_gym_datasets` | list[dict] | required | Dataset configs with `path` and optional `server_name` |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_server_timeout` | int | `360` | Server startup timeout (seconds) |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent /run |

### Dataset JSONL Format

Each line must have `responses_create_params` with `input` messages:
```json
{
  "responses_create_params": {
    "input": [{"role": "user", "content": "What's the weather in SF?"}],
    "tools": [{"name": "get_weather", "type": "function", "strict": true, "parameters": {...}}]
  }
}
```

For multi-turn agent routing, include `agent_ref`:
```json
{"agent_ref": {"name": "my_agent"}, "responses_create_params": {...}}
```

Note: Tool definitions MUST include `"strict": true` and `"additionalProperties": false` for NeMo Gym agent compatibility.

### Reward Functions

The plugin provides two built-in reward functions — no user code needed:

```yaml
trl:
  reward_funcs:
    # Multi-turn (nemo_gym_multi_turn: true):
    # Passthrough — agent /run already computed the reward
    - axolotl.integrations.nemo_gym.rewards.reward_env

    # Single-turn (nemo_gym_multi_turn: false):
    # Calls /verify endpoints on NeMo Gym resource servers
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
```

Both are also importable from Python:

```python
from axolotl.integrations.nemo_gym import reward_env, reward_nemo_gym_verify
```
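
Conceptually, the multi-turn passthrough is tiny, since the producer already attached the agent-computed reward to each row. A sketch, grounded only in the producer's comment that the passthrough reads `env_reward` from kwargs (not the shipped implementation):

```python
def reward_env_sketch(completions, **kwargs):
    """Passthrough reward: the agent /run rollout already scored each completion."""
    env_rewards = kwargs.get("env_reward", [0.0] * len(completions))
    return [float(r) for r in env_rewards]
```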

## Known Issues / Troubleshooting

### NeMo Gym Server Setup
- **pycosat build failure**: `CFLAGS="" uv pip install pycosat --no-build-isolation`
- **Ray version mismatch**: Pin `ray[default]==2.52.1` in all server venvs
- **Pre-build venvs**: `ng_run` creates per-server venvs via Ray. Pre-build them and use `+skip_venv_if_present=true`
- **Tool `strict` field required**: The agent server validates tool definitions and requires `strict: true`

### vLLM / Weight Sync
- **Start vLLM with LoRA + tool calling + runtime loading**:
  ```bash
  VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 \
  CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-4B-Instruct-2507 \
    --max-model-len 4096 \
    --gpu-memory-utilization 0.7 \
    --enable-lora --max-lora-rank 64 \
    --enable-auto-tool-choice --tool-call-parser hermes
  ```
- **`VLLM_ALLOW_RUNTIME_LORA_UPDATING=1`**: Required for `vllm_lora_sync: true`. Without it, vLLM won't expose the `/v1/load_lora_adapter` endpoint and weight sync will fail silently. The plugin warns if this endpoint is missing.
- **`--enable-lora`**: Enables LoRA adapter support in vLLM
- **`--enable-auto-tool-choice --tool-call-parser hermes`**: Required for Qwen3 tool calling
- **`max_model_len` must be > `max_completion_length`**: Leave room for prompt tokens (~200). If equal, the NeMo Gym model proxy gets a 400 error and returns empty completions.
- **`CUDA_HOME` required**: DeepSpeed import needs it for the nvcc shim
- **NCCL weight sync broken with vLLM 0.17**: Use `vllm_lora_sync: true` (filesystem + HTTP via `/v1/load_lora_adapter`)

### Multi-Turn
- **Agent server required**: Multi-turn delegates to NeMo Gym's agent server `/run` endpoint. Without an agent, the plugin falls back to single-turn `/verify`
- **Model server proxy**: NeMo Gym needs a `responses_api_models` server that proxies to your vLLM. See the agent config example above

### FSDP2
- Validated on 2 GPUs with single-turn + LoRA
- Async field filtering: The builder automatically filters async-only config fields when using the standard GRPO trainer

## Comparison with Other Integrations

| Feature | Axolotl + NeMo Gym | Unsloth + NeMo Gym | NeMo RL (native) |
|---------|-------------------|-------------------|-------------------|
| Server management | Automatic | Manual (notebook) | Built-in |
| Multi-environment | Per-row routing | Manual code | YAML config |
| Multi-turn / tool use | Agent /run delegation | No | Agent /run (Ray) |
| Async GRPO (3x speedup) | Yes | No | Yes |
| LoRA sync | Filesystem + HTTP | N/A | NCCL |
| Multi-GPU (FSDP2) | Yes | No | Yes (Ray) |
| Config-driven | Yes | No (code) | Yes |
@@ -1,25 +0,0 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
Plugin for NVIDIA NeMo Gym integration with Axolotl.

NeMo Gym provides RL training environments for LLMs with verification-based
reward signals. This plugin manages the NeMo Gym server lifecycle, loads
datasets in the NeMo Gym JSONL format, and creates reward functions that
call the NeMo Gym /verify endpoints.
"""

from .args import NemoGymArgs
from .plugin import NemoGymPlugin
from .rewards import reward_env, reward_nemo_gym_verify

__all__ = [
    "NemoGymArgs",
    "NemoGymPlugin",
    "reward_env",
    "reward_nemo_gym_verify",
]
@@ -1,146 +0,0 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
Input arguments for the NeMo Gym integration plugin.
"""

from pydantic import BaseModel, Field, model_validator


class NemoGymArgs(BaseModel):
    """Configuration args for the NeMo Gym integration."""

    nemo_gym_enabled: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "Enable NeMo Gym integration for environment-based RL rewards."
        },
    )
    nemo_gym_dir: str | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Path to the NeMo Gym repository clone. "
                "If not set and nemo_gym_auto_clone is True, clones to ~/Gym."
            )
        },
    )
    nemo_gym_auto_clone: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Automatically clone the NeMo Gym repository if not present. "
                "Defaults to True when nemo_gym_enabled is set."
            )
        },
    )
    nemo_gym_config_paths: list[str] | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "List of NeMo Gym resource server config YAML paths, relative to nemo_gym_dir. "
                "Example: ['resources_servers/reasoning_gym/configs/resources_only.yaml']"
            )
        },
    )
    nemo_gym_head_port: int | None = Field(
        default=11000,
        json_schema_extra={
            "description": "Port for the NeMo Gym head server. Defaults to 11000."
        },
    )
    nemo_gym_server_timeout: int | None = Field(
        default=360,
        json_schema_extra={
            "description": "Timeout in seconds waiting for NeMo Gym servers to start. Defaults to 360."
        },
    )
    nemo_gym_verify_timeout: int | None = Field(
        default=30,
        json_schema_extra={
            "description": "Timeout in seconds for individual /verify requests. Defaults to 30."
        },
    )
    nemo_gym_run_timeout: int | None = Field(
        default=300,
        json_schema_extra={
            "description": (
                "Timeout in seconds for each agent /run request (one multi-turn rollout). "
                "Prevents stuck generations (e.g. model looping on <think> tags) from "
                "blocking training indefinitely. Defaults to 300 (5 minutes)."
            )
        },
    )
    nemo_gym_datasets: list[dict] | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "List of NeMo Gym dataset configs. Each entry has 'path' (JSONL file path "
                "relative to nemo_gym_dir) and optionally 'server_name' (default resource server). "
                "If the JSONL rows have agent_ref.name, that takes precedence per row, "
                "enabling multi-environment training from a single dataset file. "
                "Optional 'max_samples' to limit per dataset."
            )
        },
    )
    nemo_gym_auto_start: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": (
                "Automatically start NeMo Gym resource servers. Defaults to True. "
                "Set to False if servers are already running externally."
            )
        },
    )
    nemo_gym_model_name: str | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Model name to report in verify requests. "
                "Defaults to the base_model from the main config."
            )
        },
    )
    nemo_gym_multi_turn: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Enable multi-turn rollouts via NeMo Gym. When True, uses TRL's "
                "rollout_func to run multi-step interactions with tool execution. "
                "Requires use_vllm=True in TRL config. The model generates responses, "
                "tool calls are executed against resource servers, and results are "
                "fed back for the next turn. Final reward comes from /verify."
            )
        },
    )
    nemo_gym_max_turns: int | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Maximum number of turns per multi-turn rollout. Defaults to 10. "
                "Each turn consists of a model generation + optional tool execution."
            )
        },
    )

    @model_validator(mode="before")
    @classmethod
    def check_nemo_gym_config(cls, data):
        if data.get("nemo_gym_enabled"):
            if not data.get("nemo_gym_config_paths") and data.get(
                "nemo_gym_auto_start", True
            ):
                raise ValueError(
                    "nemo_gym_config_paths is required when nemo_gym_enabled=True "
                    "and nemo_gym_auto_start is not False."
                )
            if not data.get("nemo_gym_datasets"):
                raise ValueError(
                    "nemo_gym_datasets is required when nemo_gym_enabled=True. "
                    "Provide at least one dataset with 'path' and 'server_name'."
                )
        return data
@@ -1,226 +0,0 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.

"""
NeMo Gym Data Producer for async GRPO training.

Replaces GRPODataProducer to generate rollouts via NeMo Gym agent /run endpoints
instead of vLLM. The agent handles generation, tool execution, and reward computation.
Returns RolloutDataset in the same format as the standard producer, so all downstream
components (deferred scoring, IS correction, streaming, replay, re-roll) work unchanged.
"""

from __future__ import annotations

import asyncio
from typing import Any

import torch
from trl.trainer.utils import pad

from axolotl.core.trainers.grpo.async_trainer import GRPODataProducer, RolloutDataset
from axolotl.utils.logging import get_logger

from .multi_turn import _call_agents, _parse_agent_response

LOG = get_logger(__name__)


class NemoGymDataProducer(GRPODataProducer):
    """Produces GRPO rollouts by calling NeMo Gym agent /run endpoints.

    Drop-in replacement for GRPODataProducer. Instead of calling vLLM for generation,
    sends prompts to NeMo Gym agents which handle generation + tool execution + reward.
    Returns the same RolloutDataset format so deferred scoring, IS correction,
    replay buffer, and re-roll all work unchanged.
    """

    def __init__(
        self,
        *args,
        agent_servers: dict[str, str],
        dataset_lookup: dict,
        request_timeout: float = 10800,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._agent_servers = agent_servers
        self._dataset_lookup = dataset_lookup
        self._request_timeout = request_timeout

    def produce(
        self,
        model: Any,
        global_step: int,
        *,
        skip_policy_logps: bool = False,
        processing_class: Any = None,
        accelerator: Any = None,
        args: Any = None,
        _rank0_only: bool = False,
        **kwargs,
    ) -> RolloutDataset | None:
        """Generate rollouts via NeMo Gym agents.

        Calls agent /run endpoints, parses responses into padded tensors,
        and returns a RolloutDataset for deferred scoring on the main thread.
        """
        trainer = self._trainer
        is_main = trainer.accelerator.is_main_process
        device = trainer.accelerator.device

        if _rank0_only and not is_main:
            return None

        # Get prompt batch from iterator
        try:
            inputs = next(self._prompt_iter)
        except StopIteration:
            self._prompt_iter = iter(self._prompt_dl)
            inputs = next(self._prompt_iter)

        # Extract dataset items for agent calls
        dataset_items = []
        for inp in inputs:
            prompt_text = ""
            prompt = inp.get("prompt", [])
            if isinstance(prompt, list) and prompt:
                prompt_text = (
                    prompt[-1].get("content", "")
                    if isinstance(prompt[-1], dict)
                    else str(prompt[-1])
                )
            elif isinstance(prompt, str):
                prompt_text = prompt

            # Find the full dataset item, preserving agent_ref for routing
            full_item = self._dataset_lookup.get(prompt_text, {})
            item = full_item.get("verify_extra", {})
            if not item:
                item = {
                    "responses_create_params": {
                        "input": [{"role": "user", "content": prompt_text}]
                    }
                }
            # Preserve agent_ref from the dataset row for _call_agents routing
            if "agent_ref" in full_item and "agent_ref" not in item:
                item["agent_ref"] = full_item["agent_ref"]
            dataset_items.append(item)

        # Expand by num_generations (agent produces one rollout per call)
        expanded_items = []
        for item in dataset_items:
            for _ in range(self._num_generations):
                expanded_items.append(item)

        # Call NeMo Gym agents
        loop = asyncio.new_event_loop()
        try:
            responses = loop.run_until_complete(
                _call_agents(
                    dataset_items=expanded_items,
                    agent_servers=self._agent_servers,
                    timeout=self._request_timeout,
                    max_completion_length=trainer.max_completion_length,
                    temperature=trainer.temperature,
                    top_p=getattr(trainer, "top_p", None) or 0.999,
                )
            )
        finally:
            loop.close()

        # Parse responses
        eos_token_id = trainer.processing_class.eos_token_id
        prompt_ids_list = []
        completion_ids_list = []
        env_mask_list = []
        logprobs_list = []
        rewards_list = []

        for resp in responses:
            parsed = _parse_agent_response(resp, eos_token_id)
            prompt_ids_list.append(parsed["prompt_ids"])
            completion_ids_list.append(parsed["completion_ids"])
            env_mask_list.append(parsed["env_mask"])
            logprobs_list.append(parsed["logprobs"])
            rewards_list.append(parsed["reward"])

        # Pad to tensors
        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
        prompt_ids = pad(
            prompt_ids, padding_value=trainer.pad_token_id, padding_side="left"
        )
        prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")

        completion_ids = [
            torch.tensor(ids, device=device) for ids in completion_ids_list
        ]
        completion_mask = [
            torch.ones_like(ids, dtype=torch.long) for ids in completion_ids
        ]
        completion_ids = pad(
            completion_ids, padding_value=trainer.pad_token_id, padding_side="right"
        )
        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")

        # Sampling logprobs from agent (used for IS correction)
        sampling_logps = [
            torch.tensor(lp, dtype=torch.float32, device=device) for lp in logprobs_list
        ]
        sampling_per_token_logps = pad(
            sampling_logps, padding_value=0.0, padding_side="right"
        )

        # env_mask as tool_mask (1=model tokens, 0=tool tokens)
        tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
        tool_mask = pad(tool_mask, padding_value=1, padding_side="right")

        # Inject rewards into inputs so _compute_deferred_scores can use them
        # The deferred scoring path calls _calculate_rewards which reads reward_funcs.
        # Our passthrough reward_fn reads "env_reward" from kwargs.
        for i, inp in enumerate(inputs):
            # Each input gets rewards for its num_generations rollouts
            start = i * self._num_generations
            end = start + self._num_generations
            inp["env_reward"] = rewards_list[start:end]

        # Expand inputs to match expanded rollouts (num_generations copies)
        expanded_inputs = []
        for inp in inputs:
            for g in range(self._num_generations):
                expanded_inp = dict(inp)
                expanded_inp["env_reward"] = inp["env_reward"][g]
                expanded_inputs.append(expanded_inp)

        # Decode completions for reward functions
        completions = trainer.processing_class.batch_decode(
            completion_ids, skip_special_tokens=True
        )

        # Build total token count
        num_items_in_batch = completion_mask.sum()

        # Build output dict (same shape as _generate_only)
        output = {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "num_items_in_batch": num_items_in_batch,
            "advantages": torch.zeros(completion_ids.size(0), device=device),
            "sampling_per_token_logps": sampling_per_token_logps,
            "tool_mask": tool_mask,
            # Deferred scoring markers
            "_pending_policy_logps": True,
            "_deferred_inputs": expanded_inputs,
            "_deferred_prompts": [inp.get("prompt", "") for inp in expanded_inputs],
            "_deferred_completions": completions,
            "_deferred_completion_ids_list": completion_ids_list,
            "_rank0_only": _rank0_only,
        }

        return RolloutDataset(output)
Some files were not shown because too many files have changed in this diff