Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
8495c79fb1 properly handles kernels repo type 2026-04-23 14:56:16 -04:00
Wing Lian
9a0d3016df first pass at build and deploy scattermoe-lora kernel 2026-04-22 01:10:01 -04:00
73 changed files with 744 additions and 4359 deletions

View File

@@ -31,11 +31,10 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
# Install axolotl + dev and test dependencies
# Install axolotl + dev and test dependencies from lockfile
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
pre-commit install
# test

View File

@@ -30,6 +30,14 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -160,6 +168,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -18,6 +18,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -174,6 +180,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

View File

@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
| Method | Config Key | When to Use |
|--------|-----------|-------------|
| SFT | *(default)* | Input-output pairs, instruction tuning |
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
| KTO | `rl: kto` | Unpaired binary preference labels |
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |

View File

@@ -1 +1 @@
0.16.2.dev0
0.16.0.dev0

View File

@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="optimizers,ray"; \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -58,3 +58,19 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="optimizers,ray"; \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -38,3 +38,20 @@ RUN uv pip install packaging setuptools wheel psutil \
RUN if [ "$TARGETARCH" = "amd64" ]; then \
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
LINUX_TAG="manylinux_" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
arm64) ARCH_TAG="2_34_aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
1. Paired preference data (chosen + rejected)?
- Default → `rl: dpo`
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
- Overfitting → `rl: ipo`
- VRAM-limited → `rl: orpo` (no ref model)
- Length-sensitive → `rl: simpo` (no ref model)
2. Only binary labels (good/bad)? → `rl: kto`

View File

@@ -77,9 +77,8 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
```bash
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
#### Remote Hosts
@@ -219,9 +218,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
You will now be in the container. Next, install Axolotl with dev dependencies:
```bash
uv venv --no-project --relocatable
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
### Attach To Container

View File

@@ -10,16 +10,13 @@ This section describes the different Docker images that are released by AxolotlA
[Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
### Switch to the `-uv` images
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
:::
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
same format as their non-uv counterparts.
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
the uv image.
::: {.callout-tip}
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
:::
## Base
@@ -88,7 +85,7 @@ Tags examples:
- `main-py3.12-cu130-2.10.0`
- `main-latest`
- `main-20260315-py3.11-cu128-2.9.1`
- `0.16.1`
- `0.12.0`
## Cloud

View File

@@ -57,7 +57,7 @@ description: Frequently asked questions
**Q: vLLM is not working with Axolotl**
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.9.1
- PyTorch ≥2.9.0
## Installation {#sec-installation}
@@ -36,9 +36,9 @@ source $HOME/.local/bin/env
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
```{.bash}
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation axolotl[deepspeed]
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
```
### Edge/Development Build {#sec-edge-build}
@@ -49,11 +49,12 @@ For the latest features between releases:
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv
uv sync --extra flash-attn --extra deepspeed
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]'
```
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
### Docker {#sec-docker}
```{.bash}
@@ -131,11 +132,11 @@ source $HOME/.local/bin/env
# Create a fresh venv (recommended for a clean start)
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv
uv venv --no-project --relocatable
source .venv/bin/activate
# Reinstall axolotl
uv pip install --no-build-isolation axolotl[deepspeed]
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
```
## Using pip (Alternative) {#sec-pip}
@@ -150,13 +151,13 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[deepspeed]
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
For editable/development installs:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[deepspeed]'
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -320,10 +320,8 @@ The input format is a simple JSON input with customizable fields based on the ab
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
```yaml
rl: dpo
dpo_loss_type: ["ipo"]
rl: ipo
```
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
### ORPO

View File

@@ -15,7 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl>=0.16.1'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.

View File

@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,11 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -36,7 +36,12 @@
"id": "msOCO4NRmRLa"
},
"outputs": [],
"source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
"source": [
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
},
{
"cell_type": "markdown",

View File

@@ -15,8 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -9,8 +9,8 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:

View File

@@ -13,8 +13,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))

View File

@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -9,11 +9,11 @@ Tencent released a family of opensource models called HunYuan with varying param
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,8 +13,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl>=0.16.1'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,7 +13,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl>=0.16.1'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:

View File

@@ -11,8 +11,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Please install the below.

View File

@@ -12,7 +12,7 @@ requires-python = ">=3.10"
dependencies = [
# Core ML stack
"torch>=2.9.1",
"torch>=2.6.0",
"packaging==26.0",
"huggingface_hub>=1.1.7",
"peft>=0.19.1,<0.20.0",
@@ -79,7 +79,7 @@ dependencies = [
# Platform-specific (Linux only)
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
"triton>=3.4.0 ; sys_platform != 'darwin'",
"xformers>=0.0.33.post2 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",

View File

@@ -0,0 +1,479 @@
#!/usr/bin/env python3
"""Build a disposable Hugging Face Kernel Hub package for ScatterMoE LoRA.
This script does not move or edit the in-tree Axolotl kernel sources. It copies
``src/axolotl/integrations/kernels/libs/scattermoe_lora`` into an ignored
build directory and emits a universal HF kernels project that can be pushed to
the Hub.
"""
from __future__ import annotations
import argparse
import fnmatch
import hashlib
import json
import os
import shutil
import subprocess
import sys
from importlib import metadata
from pathlib import Path
PACKAGE_NAME = "scattermoe_lora"
BUILD_VARIANT = "torch-universal"
DEFAULT_REPO_ID = "kernels-community/scattermoe-lora"
HF_REPO_TYPE = "kernel"
HF_KERNEL_URL_PREFIX = "https://hf.co/kernels"
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_SOURCE_DIR = (
REPO_ROOT / "src" / "axolotl" / "integrations" / "kernels" / "libs" / PACKAGE_NAME
)
DEFAULT_OUTPUT_DIR = REPO_ROOT / "build" / "hf-kernels" / PACKAGE_NAME
EXCLUDED_DIRS = {
"__pycache__",
".mypy_cache",
".pytest_cache",
".ruff_cache",
}
EXCLUDED_FILE_PATTERNS = {
"*.pyc",
"*.pyo",
"*.so",
".DS_Store",
}
TEXT_REPLACEMENTS = {
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import": (
"from .selective_dequant import"
),
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import": (
"from .selective_dequant_kernel import"
),
"from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import": (
"from .ops import"
),
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Copy Axolotl's ScatterMoE LoRA Triton kernels into a disposable "
"HF Kernel Hub universal package."
)
)
parser.add_argument(
"--source-dir",
type=Path,
default=DEFAULT_SOURCE_DIR,
help=f"ScatterMoE LoRA source package to copy. Default: {DEFAULT_SOURCE_DIR}",
)
parser.add_argument(
"--output-dir",
type=Path,
default=DEFAULT_OUTPUT_DIR,
help=f"Destination build/dist directory. Default: {DEFAULT_OUTPUT_DIR}",
)
parser.add_argument(
"--repo-id",
default=DEFAULT_REPO_ID,
help=f"HF Hub repo id to write into build.toml. Default: {DEFAULT_REPO_ID}",
)
parser.add_argument(
"--version",
type=int,
default=1,
help="Kernel major version written to build.toml and metadata.json.",
)
parser.add_argument(
"--force",
action="store_true",
help="Delete the output directory first if it already exists.",
)
parser.add_argument(
"--no-source-layout",
action="store_true",
help="Only write the shippable build/ tree, not torch-ext/ sources.",
)
parser.add_argument(
"--upload",
action="store_true",
help=(
"Upload the generated universal kernel package with huggingface_hub. "
"This bypasses kernel-builder and is intended for pure Python/Triton "
"universal kernels."
),
)
parser.add_argument(
"--private",
action="store_true",
help="Create the HF Hub repo as private when used with --upload.",
)
parser.add_argument(
"--skip-version-branch",
action="store_true",
help="With --upload, only upload main and skip the v<version> branch.",
)
return parser.parse_args()
def should_skip_file(path: Path) -> bool:
return any(
fnmatch.fnmatch(path.name, pattern) for pattern in EXCLUDED_FILE_PATTERNS
)
def iter_source_files(source_dir: Path) -> list[Path]:
files: list[Path] = []
for root, dirs, filenames in os.walk(source_dir):
dirs[:] = sorted(d for d in dirs if d not in EXCLUDED_DIRS)
for filename in sorted(filenames):
path = Path(root) / filename
if not should_skip_file(path):
files.append(path)
return files
def content_hash(source_dir: Path) -> str:
digest = hashlib.sha1()
for path in iter_source_files(source_dir):
rel = path.relative_to(source_dir).as_posix()
digest.update(rel.encode("utf-8"))
digest.update(b"\0")
digest.update(path.read_bytes())
digest.update(b"\0")
return digest.hexdigest()[:10]
def git_revision() -> str:
try:
result = subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
cwd=REPO_ROOT,
check=True,
capture_output=True,
text=True,
)
except (OSError, subprocess.CalledProcessError):
return "unknown"
return result.stdout.strip() or "unknown"
def transform_python_source(text: str, rel_path: Path, op_namespace: str) -> str:
for old, new in TEXT_REPLACEMENTS.items():
text = text.replace(old, new)
if rel_path.as_posix() == "gemma4_experts.py":
text = text.replace(
" from axolotl.integrations.kernels.constants import resolve_experts_class",
(
" raise RuntimeError(\n"
' "patch_gemma4_scattermoe is only available from the in-tree Axolotl "\n'
' "integration. Use register_scattermoe_experts() with the standalone "\n'
' "HF kernel package."\n'
" )"
),
)
return text.replace("scattermoe::", f"{op_namespace}::")
def copy_package(source_dir: Path, package_dir: Path, op_namespace: str) -> None:
for source in iter_source_files(source_dir):
rel_path = source.relative_to(source_dir)
destination = package_dir / rel_path
destination.parent.mkdir(parents=True, exist_ok=True)
if source.suffix == ".py":
text = source.read_text(encoding="utf-8")
text = transform_python_source(text, rel_path, op_namespace)
destination.write_text(text, encoding="utf-8")
else:
shutil.copy2(source, destination)
write_ops_module(package_dir / "_ops.py", op_namespace)
def write_ops_module(path: Path, op_namespace: str) -> None:
path.write_text(
"\n".join(
[
"import torch",
"",
f"ops = torch.ops.{op_namespace}",
"",
"",
"def add_op_namespace_prefix(op_name: str) -> str:",
f' return f"{op_namespace}::{{op_name}}"',
"",
]
),
encoding="utf-8",
)
def write_build_toml(path: Path, repo_id: str, version: int) -> None:
lines = [
"[general]",
f'name = "{PACKAGE_NAME}"',
"universal = true",
f"version = {version}",
"",
]
if repo_id:
lines.extend(
[
"[general.hub]",
f'repo-id = "{repo_id}"',
"",
]
)
path.write_text("\n".join(lines), encoding="utf-8")
def write_flake(path: Path) -> None:
path.write_text(
"""{
description = "Flake for scattermoe_lora kernel";
inputs = {
builder.url = "github:huggingface/kernels";
};
outputs =
{
self,
builder,
}:
builder.lib.genKernelFlakeOutputs {
inherit self;
path = ./.;
};
}
""",
encoding="utf-8",
)
def write_readme(path: Path, repo_id: str, source_hash: str, op_namespace: str) -> None:
repo_display = repo_id or "<your-org>/scattermoe-lora"
path.write_text(
f"""---
library_name: kernels
license: apache-2.0
tags:
- kernel
- kernels
---
# ScatterMoE LoRA
Standalone Hugging Face Kernel Hub package for Axolotl's ScatterMoE LoRA Triton kernels.
This package is generated from Axolotl's in-tree `scattermoe_lora` sources and is exported as a universal kernel because the implementation is Python/Triton rather than a precompiled C++/CUDA extension.
```python
from kernels import get_kernel
scattermoe_lora = get_kernel("{repo_display}")
```
Export metadata:
- source package: `src/axolotl/integrations/kernels/libs/scattermoe_lora`
- source revision: `{git_revision()}`
- source content hash: `{source_hash}`
- torch custom op namespace: `{op_namespace}`
The generated `build/torch-universal/{PACKAGE_NAME}` directory is the shippable Hub artifact. `torch-ext/{PACKAGE_NAME}` is included so `kernel-builder build-and-copy` can regenerate the universal build tree if desired.
""",
encoding="utf-8",
)
def write_metadata(path: Path, version: int) -> None:
path.write_text(
json.dumps({"version": version}, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
def prepare_output_dir(output_dir: Path, force: bool) -> None:
if output_dir.exists():
if not force:
raise FileExistsError(
f"{output_dir} already exists. Re-run with --force to replace it."
)
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True)
def build_package(args: argparse.Namespace) -> Path:
source_dir = args.source_dir.resolve()
output_dir = args.output_dir.resolve()
if not source_dir.is_dir():
raise FileNotFoundError(f"source package does not exist: {source_dir}")
if not (source_dir / "__init__.py").is_file():
raise FileNotFoundError(f"source package is missing __init__.py: {source_dir}")
source_hash = content_hash(source_dir)
op_namespace = f"_{PACKAGE_NAME}_{source_hash}"
prepare_output_dir(output_dir, args.force)
write_build_toml(output_dir / "build.toml", args.repo_id, args.version)
write_flake(output_dir / "flake.nix")
write_readme(output_dir / "README.md", args.repo_id, source_hash, op_namespace)
if not args.no_source_layout:
copy_package(source_dir, output_dir / "torch-ext" / PACKAGE_NAME, op_namespace)
build_package_dir = output_dir / "build" / BUILD_VARIANT / PACKAGE_NAME
copy_package(source_dir, build_package_dir, op_namespace)
write_metadata(build_package_dir.parent / "metadata.json", args.version)
return output_dir
def upload_package(args: argparse.Namespace, output_dir: Path) -> None:
if not args.repo_id:
raise ValueError("--repo-id is required when using --upload")
try:
from huggingface_hub import HfApi, constants as hf_constants
except ImportError as exc:
raise RuntimeError(
"--upload requires huggingface_hub. Install it or run the upload "
"manually with the Hugging Face CLI."
) from exc
try:
hub_version = metadata.version("huggingface_hub")
except metadata.PackageNotFoundError:
hub_version = "unknown"
accepted_repo_types = getattr(
hf_constants,
"REPO_TYPES_WITH_KERNEL",
getattr(hf_constants, "REPO_TYPES", ()),
)
if HF_REPO_TYPE not in accepted_repo_types:
raise RuntimeError(
"Your huggingface_hub installation does not support "
f"repo_type={HF_REPO_TYPE!r} (found huggingface_hub {hub_version}). "
f"Upgrade this interpreter with: {sys.executable} -m pip install --upgrade "
"'huggingface_hub>=1.10.0'"
)
# huggingface_hub 1.11.0 has partial kernel support: create_repo accepts
# "kernel", but upload_folder/create_commit still validate against the
# older REPO_TYPES list. Extend it in-process so those helpers use the
# /api/kernels/... endpoints until upstream broadens that check.
if HF_REPO_TYPE not in hf_constants.REPO_TYPES:
hf_constants.REPO_TYPES.append(HF_REPO_TYPE)
api = HfApi()
try:
repo_id = api.create_repo(
repo_id=args.repo_id,
repo_type=HF_REPO_TYPE,
private=args.private,
exist_ok=True,
).repo_id
except ValueError as exc:
if "Invalid repo type" in str(exc):
raise RuntimeError(
"huggingface_hub rejected repo_type='kernel'. "
f"This usually means the command is running with an older Hub "
f"client than expected (found huggingface_hub {hub_version} at "
f"{sys.executable}). Upgrade with: {sys.executable} -m pip "
"install --upgrade 'huggingface_hub>=1.10.0'"
) from exc
raise
delete_patterns = [
"build/**",
"torch-ext/**",
"build.toml",
"flake.nix",
"README.md",
]
api.upload_folder(
repo_id=repo_id,
repo_type=HF_REPO_TYPE,
folder_path=output_dir,
revision="main",
delete_patterns=delete_patterns,
commit_message="Upload ScatterMoE LoRA universal kernel",
)
print(f"Uploaded main branch: {HF_KERNEL_URL_PREFIX}/{repo_id}")
if args.skip_version_branch:
return
version_branch = f"v{args.version}"
api.create_branch(
repo_id=repo_id,
repo_type=HF_REPO_TYPE,
branch=version_branch,
revision="main",
exist_ok=True,
)
api.upload_folder(
repo_id=repo_id,
repo_type=HF_REPO_TYPE,
folder_path=output_dir,
revision=version_branch,
delete_patterns=delete_patterns,
commit_message=f"Upload ScatterMoE LoRA universal kernel {version_branch}",
)
print(
f"Uploaded version branch: "
f"{HF_KERNEL_URL_PREFIX}/{repo_id}/tree/{version_branch}"
)
def main() -> int:
args = parse_args()
try:
output_dir = build_package(args)
if args.upload:
upload_package(args, output_dir)
except Exception as exc:
print(f"error: {exc}", file=sys.stderr)
return 1
print(f"Wrote ScatterMoE LoRA HF kernel package to: {output_dir}")
print(f"Shippable artifact: {output_dir / 'build' / BUILD_VARIANT / PACKAGE_NAME}")
if args.upload:
print(f'Load it with: get_kernel("{args.repo_id}", version={args.version})')
print(f"Uploaded as Hugging Face repo_type={HF_REPO_TYPE!r}.")
return 0
print("Next step:")
print(" upload this universal Python/Triton kernel directly:")
print(
f" python3 {Path(__file__).as_posix()} "
f"--repo-id {args.repo_id} --force --upload"
)
if shutil.which("kernel-builder") is None:
print(" optional: install kernel-builder for full Nix-based builds:")
print(
" curl -fsSL "
"https://raw.githubusercontent.com/huggingface/kernels/main/install.sh "
"| bash"
)
else:
print(" optional: upload with kernel-builder:")
print(f" cd {output_dir}")
print(" kernel-builder build-and-upload")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -370,7 +370,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
multiple = getattr(self.cfg, "pad_to_multiple_of", None) or 64
multiple = 64
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple

View File

@@ -228,47 +228,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return training_args, trainer_kwargs
def build_collator(self, **kwargs):
"""Build a data collator for preference-tuning trainers.
Returns None for RL types that provide their own collator (e.g. GRPO,
KTO), letting the trainer construct its default. For DPO/IPO/ORPO/SIMPO
returns an ``AxolotlDPODataCollatorWithPadding`` when
``pad_to_multiple_of`` is set, otherwise None (so the trainer
falls back to the TRL default).
"""
if self.cfg.rl not in (
RLType.DPO,
RLType.IPO,
RLType.ORPO,
RLType.SIMPO,
):
return None
pad_to_multiple_of = getattr(self.cfg, "pad_to_multiple_of", None)
if not pad_to_multiple_of:
return None
from axolotl.utils.collators.dpo import AxolotlDPODataCollatorWithPadding
LOG.info(
f"Using AxolotlDPODataCollatorWithPadding with pad_to_multiple_of="
f"{pad_to_multiple_of}"
)
is_enc_dec = getattr(self.model.config, "is_encoder_decoder", False)
return AxolotlDPODataCollatorWithPadding(
pad_token_id=self.tokenizer.pad_token_id,
is_encoder_decoder=is_enc_dec,
pad_to_multiple_of=pad_to_multiple_of,
**kwargs,
)
def build(self, total_num_steps):
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
if (data_collator := self.build_collator()) is not None:
trainer_kwargs["data_collator"] = data_collator
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if (

View File

@@ -20,16 +20,8 @@ class DPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl is RLType.DPO:
if cfg.dpo_loss_type is not None:
training_args_kwargs["loss_type"] = cfg.dpo_loss_type
if cfg.dpo_loss_weights is not None:
training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = ["ipo"]
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing

View File

@@ -242,85 +242,6 @@ class ProducerConfig:
)
class _GroupShardedSampler:
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
``RepeatSampler`` yields ``num_generations`` consecutive copies of
each prompt, forming a GRPO group. For distributed training each
rank must see a disjoint slice of prompts (otherwise every rank
dogpiles on the first 1/world_size of the batch) while keeping each
group intact on a single rank so advantage normalization sees all
peer generations.
``accelerator.prepare(DataLoader)`` does not handle this correctly
for custom samplers with ``split_batches=False`` (the default): it
leaves the sampler alone and every rank replays identical indices.
This wrapper fixes that by consuming the inner sampler's full
output, chunking it into ``num_generations``-sized groups, and
round-robining whole groups across ranks.
Intended to be used ONLY when distributed training is active
(``num_replicas > 1``); for single-rank it is a no-op but still
correct.
"""
def __init__(
self,
inner: Any,
num_generations: int,
rank: int,
num_replicas: int,
):
if num_generations < 1:
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
if num_replicas < 1:
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
if not (0 <= rank < num_replicas):
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
self.inner = inner
self.num_generations = num_generations
self.rank = rank
self.num_replicas = num_replicas
def __iter__(self):
all_indices = list(self.inner)
if len(all_indices) % self.num_generations != 0:
raise ValueError(
f"inner sampler yielded {len(all_indices)} indices, "
f"not a multiple of num_generations={self.num_generations}"
)
# Chunk the flat index sequence into groups of num_generations
# consecutive indices. ``RepeatSampler`` guarantees that each
# group contains num_generations copies of the same prompt id.
groups = [
all_indices[i : i + self.num_generations]
for i in range(0, len(all_indices), self.num_generations)
]
# Round-robin whole groups across ranks. Round-robin (vs.
# contiguous chunking) preserves approximate shuffled order on
# each rank even when the group count is small relative to the
# world size.
for group in groups[self.rank :: self.num_replicas]:
yield from group
def __len__(self):
try:
inner_len = len(self.inner)
except TypeError:
# Non-sized inner sampler — we can't know the per-rank
# length without materializing. Return 0 as a hint that the
# DataLoader should fall back to iteration.
return 0
total_groups = inner_len // self.num_generations
# Ceiling division for the trailing groups that don't divide
# evenly — extra groups go to the first ``total_groups %
# num_replicas`` ranks, matching the round-robin above.
my_groups = (
total_groups + self.num_replicas - self.rank - 1
) // self.num_replicas
return my_groups * self.num_generations
class DataProducer(ABC):
"""Abstract base class for online data producers.
@@ -635,34 +556,6 @@ class GRPODataProducer(BaseDataProducer):
seed=self._seed,
)
# Shard the sampler across distributed ranks so each rank sees
# a disjoint slice of prompts. ``RepeatSampler`` groups each
# prompt with ``num_generations`` consecutive copies — our
# wrapper round-robins WHOLE groups across ranks so all
# generations of a given prompt stay on the same rank (needed
# for GRPO advantage normalization within a group).
#
# Without this, ``accelerator.prepare(dl)`` with the default
# ``split_batches=False`` leaves the custom sampler alone, so
# every rank iterates the identical index sequence and the
# cluster dogpiles on the first 1/world_size of the prompts.
num_replicas = max(1, trainer.accelerator.num_processes)
if num_replicas > 1:
sampler = _GroupShardedSampler(
inner=sampler,
num_generations=self._num_generations,
rank=trainer.accelerator.process_index,
num_replicas=num_replicas,
)
logger.info(
"[RANK:%d] _GroupShardedSampler active "
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
trainer.accelerator.process_index,
num_replicas,
self._num_generations,
self._generation_batch_size,
)
# Use identity collator (same as stock GRPOTrainer)
def _identity(x):
return x
@@ -681,11 +574,12 @@ class GRPODataProducer(BaseDataProducer):
rank=trainer.args.process_index,
),
)
# Skip accelerator.prepare — we're handling per-rank sharding
# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
# otherwise try to wrap the DataLoader with its own sharding
# logic which does not understand our group structure.
self._prompt_dl = dl
self._prompt_dl = trainer.accelerator.prepare(dl)
# Don't let accelerator track this dataloader
acc_dls = trainer.accelerator._dataloaders
if self._prompt_dl in acc_dls:
acc_dls.remove(self._prompt_dl)
self._prompt_iter = iter(self._prompt_dl)
@@ -1209,22 +1103,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
This is the canonical sync trigger and runs in BOTH async and
synchronous modes from ``_prepare_inputs_with_data_producer`` /
``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
patch is a parallel backup for non-data-producer paths (vanilla
GRPO without NeMo Gym), where the data producer is bypassed
entirely and TRL's stock generate-then-sync flow is used instead.
"""
if not self.use_vllm:
if not (self.use_vllm and self.args.async_prefetch):
return
step = self.state.global_step
# Default to syncing every step when no interval is configured —
# otherwise ``step % None`` would TypeError, and the previous
# behavior of crashing on the first sync was strictly worse than
# the standard "sync every optimizer step".
interval = self.args.vllm_sync_interval or 1
interval = self.args.vllm_sync_interval
if step != self._last_synced_step and step % interval == 0:
if step == 0:
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
@@ -1319,42 +1202,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Permanently replace vllm_generation.sync_weights with our custom
# sync to avoid merge_adapter (fails on FP8 / races with training).
#
# The design has two modes that have to be threaded carefully:
#
# - Async prefetch ON: BG generation thread can't safely call
# sync_weights mid-rollout (it races with the trainer's optimizer
# step and can corrupt weights). We no-op the stock sync hook and
# drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
# optimizer step on the main thread.
#
# - Async prefetch OFF (synchronous mode): TRL's stock
# ``_generate_single_turn`` calls ``sync_weights`` once per step
# boundary. There's no BG thread to race with, and
# ``_maybe_sync_vllm_weights`` short-circuits with
# ``if not async_prefetch: return``, so we MUST wire the stock
# hook directly to our LoRA sync helper — otherwise nothing ever
# pushes weights to vLLM and the trainer becomes a no-op (vLLM
# keeps serving the base model, every rollout in every group
# produces identical outputs, advantages are zero, optimizer
# step gets skipped, repeat).
# For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
# handles the sync with proper interval tracking.
if not getattr(self, "_patched_sync_weights", False):
if self.use_vllm and hasattr(self, "vllm_generation"):
if getattr(self.args, "vllm_lora_sync", False):
if getattr(self.args, "async_prefetch", False):
# Async: drive sync from main thread via
# _maybe_sync_vllm_weights instead.
self.vllm_generation.sync_weights = lambda: None
else:
# Sync mode: TRL's _generate_single_turn already
# calls sync_weights once per step boundary. Wire
# it directly to our LoRA filesystem sync helper.
sync_helper = self._sync_lora_adapter
def _lora_filesystem_sync():
sync_helper()
self.vllm_generation.sync_weights = _lora_filesystem_sync
# No-op: LoRA sync is driven by _maybe_sync_vllm_weights
self.vllm_generation.sync_weights = lambda: None
self._patched_sync_weights = True
else:
from accelerate.utils import is_peft_model

View File

@@ -1,27 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Hatchery/Tinker remote training integration for Axolotl.
Routes axolotl's preprocessed data to a remote training API (Tinker or
Hatchery) instead of running forward/backward locally. The remote
service handles model weights, LoRA adapters, and gradient updates.
"""
from .args import HatcheryArgs, HatcheryConfig
from .plugin import HatcheryPlugin
__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"]
# Usage:
# plugins:
# - axolotl.integrations.hatchery.HatcheryPlugin
#
# hatchery:
# backend: tinker # or "hatchery"
# lora_rank: 32
# loss_fn: cross_entropy # SFT
# # loss_fn: ppo # RL (auto-selects HatcheryRLTrainer)
#
# learning_rate: 1e-4 # top-level, not under hatchery:

View File

@@ -1,62 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Pydantic config schema for the Hatchery integration."""
from __future__ import annotations
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field
class HatcheryConfig(BaseModel):
"""Nested config under `hatchery:` in the axolotl YAML.
Only contains hatchery-specific settings. Standard training params
(learning_rate, weight_decay, adam_beta1/2, max_grad_norm,
gradient_accumulation_steps) are read from axolotl's top-level config.
"""
# Backend & connection
backend: Literal["tinker", "hatchery"] = "tinker"
base_url: Optional[str] = None
api_key: Optional[str] = None
project_id: Optional[str] = None
# LoRA config sent to remote
lora_rank: int = Field(32, ge=1, le=256)
train_attn: bool = True
train_mlp: bool = True
train_unembed: bool = True
# Loss function
loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = (
"cross_entropy"
)
loss_fn_config: Optional[dict[str, Any]] = None
# Pipelining: submit next batch before awaiting previous result
pipeline: bool = True
# Sampling params (for RL flows)
max_sample_tokens: int = 256
sample_temperature: float = 1.0
num_samples: int = 4
# Reward functions (for RL) — list of fully qualified names
reward_funcs: Optional[list[str]] = None
# Checkpointing
save_steps: Optional[int] = None
save_name_prefix: str = "checkpoint"
# Timeout per future (seconds)
future_timeout: float = 600.0
class HatcheryArgs(BaseModel):
"""Top-level mixin that adds the nested `hatchery:` field."""
hatchery: Optional[HatcheryConfig] = None

View File

@@ -1,160 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Convert axolotl batch tensors to Tinker/Hatchery Datum format.
Both Tinker and Hatchery expect the client to apply the causal LM shift:
Original tokens: [t0, t1, t2, ..., t_{L-1}]
model_input: [t0, t1, ..., t_{L-2}] (last token dropped)
target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped)
weights: [w1, w2, ..., w_{L-1}] (aligned to targets)
At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}.
"""
from __future__ import annotations
from typing import Any
import torch
def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]:
"""Serialize a tensor to the TensorData wire dict."""
flat = t.detach().cpu().flatten()
dtype_map = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
}
return {
"dtype": dtype_map.get(flat.dtype, "float32"),
"shape": list(t.shape),
"data": flat.tolist(),
}
def _make_datum(
tokens: list[int],
loss_fn_inputs: dict[str, torch.Tensor],
) -> dict[str, Any]:
"""Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery)."""
return {
"model_input": {
"chunks": [{"type": "encoded_text", "tokens": tokens}],
},
"loss_fn_inputs": {
key: _tensor_to_wire(tensor) for key, tensor in loss_fn_inputs.items()
},
}
def datums_to_tinker(datums: list[dict[str, Any]]):
"""Wrap plain-dict datums into tinker.types.Datum objects.
Both the Tinker SDK and updated Hatchery client accept these.
"""
import tinker.types as tt
result = []
for d in datums:
tokens = d["model_input"]["chunks"][0]["tokens"]
tinker_inputs = {}
for key, wire in d["loss_fn_inputs"].items():
tinker_inputs[key] = tt.TensorData(
data=wire["data"],
dtype=wire["dtype"],
shape=wire["shape"],
)
result.append(
tt.Datum(
model_input=tt.ModelInput.from_ints(tokens),
loss_fn_inputs=tinker_inputs,
)
)
return result
def batch_to_datums_sft(
input_ids: torch.Tensor,
labels: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
"""Convert an axolotl SFT batch to Datum dicts with causal shift."""
batch_size = input_ids.size(0)
datums = []
for i in range(batch_size):
ids = input_ids[i]
lbl = labels[i]
if attention_mask is not None:
seq_len = int(attention_mask[i].sum().item())
ids = ids[:seq_len]
lbl = lbl[:seq_len]
model_tokens = ids[:-1].tolist()
shifted_labels = lbl[1:]
target_tokens = shifted_labels.clone()
weights = (shifted_labels != -100).float()
target_tokens[target_tokens == -100] = 0
datums.append(
_make_datum(
model_tokens,
{
"target_tokens": target_tokens,
"weights": weights,
},
)
)
return datums
def batch_to_datums_rl(
input_ids: torch.Tensor,
labels: torch.Tensor,
logprobs: torch.Tensor,
advantages: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
"""Convert an RL batch to importance_sampling/ppo Datum dicts with causal shift."""
batch_size = input_ids.size(0)
datums = []
for i in range(batch_size):
ids = input_ids[i]
lbl = labels[i]
if attention_mask is not None:
seq_len = int(attention_mask[i].sum().item())
else:
seq_len = ids.size(0)
ids = ids[:seq_len]
lbl = lbl[:seq_len]
lp = logprobs[i, :seq_len]
adv = advantages[i, :seq_len]
model_tokens = ids[:-1].tolist()
target_tokens = lbl[1:].clone()
target_tokens[target_tokens == -100] = 0
datums.append(
_make_datum(
model_tokens,
{
"target_tokens": target_tokens,
"logprobs": lp[1:],
"advantages": adv[1:],
},
)
)
return datums

View File

@@ -1,87 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Prepare hendrycks_math for RL training with Hatchery/Tinker.
Creates a dataset with chat-formatted prompts that include
a hidden gold answer tag for the reward function.
Run:
python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
"""
import os
import re
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
def extract_boxed(text: str) -> str:
match = re.search(r"\\boxed\{", text)
if not match:
return ""
start = match.end()
depth = 1
i = start
while i < len(text) and depth > 0:
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
i += 1
return text[start : i - 1] if depth == 0 else ""
def main():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test")
level = os.environ.get("MATH_LEVEL", "Level 1")
filtered_rows = [x for x in ds if x["level"] == level]
print(f"{level} algebra: {len(filtered_rows)} problems")
rows = []
for prob in filtered_rows:
gold = extract_boxed(prob["solution"])
if not gold:
continue
# Format as chat prompt with hidden gold tag
prompt = (
f"Solve the following math problem. "
f"Show your work and put your final answer in \\boxed{{}}.\n\n"
f"{prob['problem']}"
f"<|gold|>{gold}<|/gold|>"
)
# Tokenize the prompt
text = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
)
prompt_ids = tokenizer.encode(text, add_special_tokens=False)
rows.append(
{
"input_ids": prompt_ids,
"labels": [-100] * len(prompt_ids),
"attention_mask": [1] * len(prompt_ids),
}
)
out = Dataset.from_list(rows)
out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}"
out.save_to_disk(out_dir)
print(f"Saved {len(out)} examples to {out_dir}")
if rows:
print(
f"Prompt length range: {min(len(r['input_ids']) for r in rows)}"
f"-{max(len(r['input_ids']) for r in rows)}"
)
if __name__ == "__main__":
main()

View File

@@ -1,47 +0,0 @@
# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B
#
# Prep:
# python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
#
# Run:
# export TINKER_API_KEY="your-key"
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker
lora_rank: 16
loss_fn: importance_sampling
max_sample_tokens: 2048
sample_temperature: 0.7
num_samples: 4
pipeline: true
save_steps: 5
reward_funcs:
- axolotl.integrations.hatchery.rewards.math_reward.math_reward
datasets:
- path: ./data/math_rl_level1
ds_type: arrow
type: completion
sequence_len: 2048
learning_rate: 5.0e-5
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0
max_steps: 10
num_epochs: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
logging_steps: 1
output_dir: ./outputs/tinker-rl-math

View File

@@ -1,42 +0,0 @@
# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B
#
# Usage:
# export TINKER_API_KEY="your-key"
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker
lora_rank: 16
loss_fn: cross_entropy
pipeline: true
save_steps: 10
datasets:
- path: TeichAI/kimi-k2-thinking-1000x
split: train[:50]
type: chat_template
chat_template: qwen3
split_thinking: true
chat_template: qwen3
sequence_len: 2048
learning_rate: 3.0e-4
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0
num_epochs: 1
max_steps: 20
micro_batch_size: 2
gradient_accumulation_steps: 1
logging_steps: 1
output_dir: ./outputs/tinker-sft

View File

@@ -1,147 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Axolotl plugin that routes training to a remote Hatchery/Tinker API."""
from __future__ import annotations
import torch
from peft import PeftModel
from transformers import AutoConfig, PreTrainedModel, Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class HatcheryPlugin(BasePlugin):
"""Plugin that replaces local training with remote API calls.
Activated by adding to the axolotl YAML:
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker # or "hatchery"
lora_rank: 32
loss_fn: cross_entropy
# ... see HatcheryConfig for full options
"""
def get_input_args(self) -> str:
return "axolotl.integrations.hatchery.args.HatcheryArgs"
def register(self, cfg: dict):
"""Auto-set config values needed for remote training."""
if cfg.get("remove_unused_columns") is None:
cfg["remove_unused_columns"] = False
def pre_model_load(self, cfg: DictDefault):
"""Replace model loading with a tiny stub."""
hcfg = cfg.hatchery or {}
backend = (
hcfg.get("backend", "tinker")
if isinstance(hcfg, dict)
else getattr(hcfg, "backend", "tinker")
)
LOG.info(
f"Hatchery plugin active: training dispatched to remote "
f"{backend} API. Skipping local model weight loading."
)
from axolotl.loaders import ModelLoader
def _stub_build_model(loader_self) -> bool:
base_model = loader_self.cfg.base_model
LOG.info(f"Skipping model weight loading for: {base_model}")
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=loader_self.cfg.get("trust_remote_code", False),
)
class _Stub(PreTrainedModel):
config_class = type(config)
_no_split_modules: list[str] = []
supports_gradient_checkpointing = False
def __init__(self, cfg):
super().__init__(cfg)
vocab_size = getattr(cfg, "vocab_size", 32000)
self.embed_tokens = torch.nn.Embedding(vocab_size, 1)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
pass
def get_output_embeddings(self):
return None
loader_self.model = _Stub(config)
return True
ModelLoader._build_model = _stub_build_model # type: ignore[method-assign,assignment]
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Return the appropriate remote trainer class."""
hcfg = cfg.hatchery
loss_fn = getattr(hcfg, "loss_fn", "cross_entropy") if hcfg else "cross_entropy"
if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"):
from .rl_trainer import HatcheryRLTrainer
return HatcheryRLTrainer
from .trainer import HatcheryTrainer
return HatcheryTrainer
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
model._hatchery_remote = True
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
LOG.info(
"Hatchery: skipping local model save (weights are on remote API). "
"Use `tinker checkpoint download` or hatchery CLI to retrieve."
)
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Inject hatchery config + axolotl training params into the trainer."""
from .args import HatcheryConfig
from .rl_trainer import HatcheryRLTrainer
from .trainer import HatcheryTrainer
if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)):
return
hcfg = cfg.hatchery
if isinstance(hcfg, dict):
hatchery_config = HatcheryConfig(**hcfg)
elif hcfg is None:
hatchery_config = HatcheryConfig()
else:
hatchery_config = hcfg
trainer.hatchery_args = hatchery_config
trainer._base_model_name = cfg.base_model
# Pull standard training params from axolotl config so they
# don't need to be duplicated under hatchery:
trainer._optim_params = {
"learning_rate": cfg.learning_rate
if cfg.learning_rate is not None
else 1e-4,
"beta1": cfg.adam_beta1 if cfg.adam_beta1 is not None else 0.9,
"beta2": cfg.adam_beta2 if cfg.adam_beta2 is not None else 0.95,
"eps": cfg.adam_epsilon if cfg.adam_epsilon is not None else 1e-12,
"weight_decay": cfg.weight_decay if cfg.weight_decay is not None else 0.0,
"grad_clip_norm": cfg.max_grad_norm
if cfg.max_grad_norm is not None
else 0.0,
}

View File

@@ -1,3 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

View File

@@ -1,78 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Math reward function for hendrycks_math GRPO training.
Uses math_verify for robust answer comparison. Falls back to
exact string match of \\boxed{} content only when math_verify
is unavailable.
"""
from __future__ import annotations
import logging
import re
LOG = logging.getLogger(__name__)
def extract_boxed(text: str) -> str | None:
"""Extract \\boxed{...} answer handling nested braces."""
match = re.search(r"\\boxed\{", text)
if not match:
return None
start = match.end()
depth = 1
i = start
while i < len(text) and depth > 0:
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
i += 1
return text[start : i - 1] if depth == 0 else None
def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
"""Score completions by checking if \\boxed{} answer matches the gold answer.
The gold answer is extracted from the prompt (appended as a hidden
tag by the dataset preprocessing). Format:
... <|gold|>ANSWER<|/gold|>
"""
rewards = []
for prompt, completion in zip(prompts, completions, strict=True):
gold_match = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt)
if not gold_match:
rewards.append(0.0)
continue
gold_answer = gold_match.group(1).strip()
pred_answer = extract_boxed(completion)
if pred_answer is None:
rewards.append(0.0)
continue
verified = None
try:
from math_verify import parse, verify
gold_parsed = parse(gold_answer)
pred_parsed = parse(pred_answer)
verified = verify(gold_parsed, pred_parsed)
except Exception:
LOG.debug(
"math_verify unavailable or failed, using string fallback",
exc_info=True,
)
if verified is not None:
rewards.append(1.0 if verified else 0.0)
elif pred_answer.strip() == gold_answer.strip():
rewards.append(1.0)
else:
rewards.append(0.0)
return rewards

View File

@@ -1,409 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API.
Full RL loop per step:
1. Extract prompts from dataset batch
2. Sample N completions per prompt via remote SamplingClient
3. Score completions with local reward functions
4. Compute GRPO-style advantages (per-group normalization)
5. Send (prompt+completion, logprobs, advantages) as forward_backward
6. Optimizer step
"""
from __future__ import annotations
import importlib
import inspect
import re
import time
from typing import Any, Callable, Optional
import torch
from transformers.trainer_utils import TrainOutput
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.logging import get_logger
from .args import HatcheryConfig
from .data import batch_to_datums_rl, datums_to_tinker
from .trainer import _create_training_client
LOG = get_logger(__name__)
def _load_reward_func(fqn: str) -> Callable:
"""Load a reward function from a fully qualified name like 'module.func'."""
module_path = ".".join(fqn.split(".")[:-1])
func_name = fqn.split(".")[-1]
mod = importlib.import_module(module_path)
func = getattr(mod, func_name)
if len(inspect.signature(func).parameters) < 2:
raise ValueError(f"Reward function {fqn} must accept (prompts, completions)")
return func
class HatcheryRLTrainer(AxolotlTrainer):
"""Remote RL trainer using Tinker/Hatchery for sampling and training."""
hatchery_args: Optional[HatcheryConfig]
_base_model_name: Optional[str]
_training_client: Any
_reward_functions: list[Callable]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hatchery_args = None
self._base_model_name = None
self._training_client = None
self._reward_functions = []
def _ensure_reward_functions(self):
if self._reward_functions:
return
args = self.hatchery_args
if not args or not args.reward_funcs:
raise ValueError(
"No reward functions configured. Set hatchery.reward_funcs "
"in YAML, e.g. reward_funcs: ['my_module.my_reward']"
)
for fqn in args.reward_funcs:
self._reward_functions.append(_load_reward_func(fqn))
LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)")
def _get_training_client(self):
if self._training_client is not None:
return self._training_client
self._training_client = _create_training_client(
self.hatchery_args, self._base_model_name
)
LOG.info(
f"Remote RL session created: backend={self.hatchery_args.backend}, "
f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}"
)
return self._training_client
def _sample_completions(self, prompt_ids_list: list[list[int]]):
"""Sample completions for prompts via remote API."""
import tinker.types as tt
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
results = []
sc = tc.save_weights_and_get_sampling_client()
for prompt_ids in prompt_ids_list:
if hasattr(sc, "sampling_session_id"):
sample_result = sc.sample(
prompt_ids,
max_tokens=args.max_sample_tokens,
temperature=args.sample_temperature,
n=args.num_samples,
).result(timeout=args.future_timeout)
else:
mi = tt.ModelInput.from_ints(prompt_ids)
sp = tt.SamplingParams(
max_tokens=args.max_sample_tokens,
temperature=args.sample_temperature,
top_p=0.95,
top_k=-1,
)
sample_result = sc.sample(
prompt=mi,
num_samples=args.num_samples,
sampling_params=sp,
).result(timeout=args.future_timeout)
sequences = (
sample_result.sequences
if hasattr(sample_result, "sequences")
else sample_result.get("sequences", [])
)
for seq in sequences:
tokens = (
list(seq.tokens)
if hasattr(seq, "tokens")
else seq.get("tokens", [])
)
logprobs = (
list(seq.logprobs)
if hasattr(seq, "logprobs") and seq.logprobs
else seq.get("logprobs", [])
)
results.append(
{
"tokens": list(prompt_ids) + tokens,
"completion_tokens": tokens,
"logprobs": logprobs,
"prompt_len": len(prompt_ids),
}
)
return results
def _compute_rewards(
self, prompts: list[str], completions: list[str]
) -> list[float]:
total_rewards = [0.0] * len(completions)
for reward_fn in self._reward_functions:
rewards = reward_fn(prompts, completions)
for i, r in enumerate(rewards):
total_rewards[i] += r
return total_rewards
@staticmethod
def _compute_advantages(rewards: list[float], group_size: int) -> list[float]:
advantages = []
for i in range(0, len(rewards), group_size):
group = rewards[i : i + group_size]
mean = sum(group) / len(group)
var = sum((r - mean) ** 2 for r in group) / max(len(group), 1)
std = var**0.5 if var > 1e-8 else 1.0
advantages.extend([(r - mean) / std for r in group])
return advantages
def _do_optim_step(self):
import tinker.types as tt
tc = self._get_training_client()
return tc.optim_step(tt.AdamParams(**self._optim_params))
def train(
self,
resume_from_checkpoint: Optional[str] = None,
trial: Any = None,
ignore_keys_for_eval: Optional[list[str]] = None,
**kwargs,
) -> TrainOutput:
args = self.hatchery_args
if args is None:
raise RuntimeError("hatchery_args not configured")
self._ensure_reward_functions()
train_dataloader = self.get_train_dataloader()
num_train_epochs = int(self.args.num_train_epochs)
max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000
LOG.info(
f"Remote RL training: max_steps={max_steps}, "
f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}"
)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = True
self.state.is_world_process_zero = True
self.control = self.callback_handler.on_train_begin(
self.args,
self.state,
self.control, # type: ignore[has-type]
)
tokenizer = self.processing_class
global_step = 0
total_loss = 0.0
total_reward = 0.0
start_time = time.time()
for _epoch in range(num_train_epochs):
if global_step >= max_steps:
break
for batch in train_dataloader:
if global_step >= max_steps:
break
self.control = self.callback_handler.on_step_begin(
self.args, self.state, self.control
)
prompt_ids_batch = batch["input_ids"]
# Full prompt text (with gold tag) for reward scoring
prompt_texts = tokenizer.batch_decode(
prompt_ids_batch, skip_special_tokens=False
)
# Strip <|gold|>...<|/gold|> from token ids before
# sending to the model for sampling — the gold answer
# must only be visible to the local reward function.
sampling_prompts = []
for prompt_text in prompt_texts:
clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text)
clean_ids = tokenizer.encode(clean, add_special_tokens=False)
sampling_prompts.append(clean_ids)
# 1. Sample completions (without gold answer)
t0 = time.time()
samples = self._sample_completions(sampling_prompts)
t_sample = time.time() - t0
if not samples:
LOG.warning("No samples generated, skipping step")
continue
LOG.info(
f"Sampled {len(samples)} completions, "
f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok"
)
# 2. Decode and score
completion_texts = [
tokenizer.decode(s["completion_tokens"], skip_special_tokens=False)
for s in samples
]
sample_prompts = []
for prompt_text in prompt_texts:
sample_prompts.extend([prompt_text] * args.num_samples)
rewards = self._compute_rewards(sample_prompts, completion_texts)
# 3. GRPO advantages
advantages_list = self._compute_advantages(
rewards, group_size=args.num_samples
)
# 4. Build training data
all_datums = []
for i, sample in enumerate(samples):
full_tokens = sample["tokens"]
prompt_len = sample["prompt_len"]
seq_len = len(full_tokens)
input_ids = torch.tensor([full_tokens], dtype=torch.long)
labels = torch.full((1, seq_len), -100, dtype=torch.long)
labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:])
logprobs_t = torch.zeros(1, seq_len)
if sample["logprobs"]:
lp = sample["logprobs"][: seq_len - prompt_len]
logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor(
lp
)
adv_t = torch.zeros(1, seq_len)
adv_t[0, prompt_len:] = advantages_list[i]
all_datums.extend(
batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t)
)
# 5. Forward backward (one datum at a time for memory) + optim
t0 = time.time()
tc = self._get_training_client()
step_loss = 0.0
for datum in all_datums:
fb_future = tc.forward_backward(
datums_to_tinker([datum]),
loss_fn=args.loss_fn,
loss_fn_config=args.loss_fn_config,
)
fb_result = fb_future.result(timeout=args.future_timeout)
if hasattr(fb_result, "metrics"):
step_loss += float(
(fb_result.metrics or {}).get("loss:sum", 0.0)
)
elif isinstance(fb_result, dict):
step_loss += float(
fb_result.get("metrics", {}).get("loss:sum", 0.0)
)
optim_future = self._do_optim_step()
if not args.pipeline:
optim_future.result(timeout=args.future_timeout)
t_train = time.time() - t0
mean_reward = sum(rewards) / len(rewards)
accuracy = sum(1 for r in rewards if r > 0) / len(rewards)
mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list)
global_step += 1
total_loss += step_loss
total_reward += mean_reward
self.state.global_step = global_step
log_interval = self.args.logging_steps or 1
if global_step % log_interval == 0:
elapsed = time.time() - start_time
LOG.info(
f"[step {global_step}/{max_steps}] "
f"acc={accuracy:.2f} reward={mean_reward:.3f} "
f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} "
f"sample={t_sample:.1f}s train={t_train:.1f}s "
f"{elapsed / global_step:.1f}s/step"
)
self.log(
{
"loss": step_loss,
"reward": mean_reward,
"accuracy": accuracy,
"mean_abs_advantage": mean_adv,
"learning_rate": self._optim_params["learning_rate"],
}
)
if args.save_steps and global_step % args.save_steps == 0:
self._save_remote_checkpoint(global_step)
self.control = self.callback_handler.on_step_end(
self.args, self.state, self.control
)
if self.control.should_training_stop:
break
if self.control.should_training_stop:
break
if global_step > 0:
self._save_remote_checkpoint(global_step, name="final")
elapsed = time.time() - start_time
avg_loss = total_loss / max(global_step, 1)
avg_reward = total_reward / max(global_step, 1)
LOG.info(
f"RL training complete: {global_step} steps, {elapsed:.1f}s, "
f"avg_reward={avg_reward:.4f}"
)
self.control = self.callback_handler.on_train_end(
self.args, self.state, self.control
)
return TrainOutput(
global_step=global_step,
training_loss=avg_loss,
metrics={
"train_loss": avg_loss,
"train_reward": avg_reward,
"train_runtime": elapsed,
},
)
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
try:
future = tc.save_state(ckpt_name)
future.result(timeout=args.future_timeout)
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
except Exception:
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
if name == "final":
raise
def save_model(self, output_dir=None, _internal_call=False):
self._save_remote_checkpoint(
step=self.state.global_step,
name=output_dir or "hf-save",
)
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
raise NotImplementedError(
"HatcheryRLTrainer uses remote API; compute_loss not called locally."
)

View File

@@ -1,327 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Remote trainer that dispatches to Tinker or Hatchery API."""
from __future__ import annotations
import os
import time
from typing import Any, Optional
import torch
from transformers.trainer_utils import TrainOutput
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.logging import get_logger
from .args import HatcheryConfig
from .data import batch_to_datums_sft, datums_to_tinker
LOG = get_logger(__name__)
def _extract_loss(result) -> float:
"""Extract loss:sum from a forward_backward result.
Tinker's cross_entropy (and other losses) return the SUM of per-token
losses, not the mean. This is by design — it lets users control
normalization via the weights tensor. The trainer logs this raw sum;
users who want per-token loss should divide by number of active tokens.
"""
if hasattr(result, "metrics"):
metrics = result.metrics or {}
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
if isinstance(result, dict):
metrics = result.get("metrics", {})
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
return 0.0
def _create_training_client(args: HatcheryConfig, base_model: str):
"""Create a training client for either Tinker or Hatchery backend."""
if args.backend == "tinker":
import tinker
api_key = args.api_key or os.environ.get("TINKER_API_KEY")
if not api_key:
raise ValueError(
"Tinker API key required. Set `hatchery.api_key` in config "
"or TINKER_API_KEY env var."
)
os.environ["TINKER_API_KEY"] = api_key
service = tinker.ServiceClient(project_id=args.project_id)
return service.create_lora_training_client(
base_model=base_model,
rank=args.lora_rank,
train_mlp=args.train_mlp,
train_attn=args.train_attn,
train_unembed=args.train_unembed,
)
from hatchery.core.client import HatcheryClient
base_url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420")
token = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev")
client = HatcheryClient(base_url=base_url, token=token, timeout=args.future_timeout)
return client.create_lora_training_client(
base_model=base_model,
rank=args.lora_rank,
train_attn=args.train_attn,
train_mlp=args.train_mlp,
train_unembed=args.train_unembed,
)
class HatcheryTrainer(AxolotlTrainer):
"""Trainer that sends preprocessed batches to a remote training API.
Replaces local forward/backward with remote API calls to Tinker or
Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization,
chat templates, packing, etc.) but offloads compute to remote GPUs.
"""
hatchery_args: Optional[HatcheryConfig]
_base_model_name: Optional[str]
_training_client: Any
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hatchery_args = None
self._base_model_name = None
self._training_client = None
def _get_training_client(self):
"""Lazily create the remote training session."""
if self._training_client is not None:
return self._training_client
args = self.hatchery_args
if args is None:
raise RuntimeError(
"HatcheryTrainer.hatchery_args not set. "
"Ensure the HatcheryPlugin is registered."
)
base_model = self._base_model_name
if not base_model:
raise RuntimeError("HatcheryTrainer._base_model_name not set.")
self._training_client = _create_training_client(args, base_model)
LOG.info(
f"Remote training session created: backend={args.backend}, "
f"model={base_model}, rank={args.lora_rank}"
)
return self._training_client
def _send_batch(self, batch: dict[str, torch.Tensor]):
"""Convert batch to datums and send forward_backward to remote.
Returns (future, n_active_tokens) where n_active_tokens counts
the completion tokens in this batch (for loss normalization).
"""
input_ids = batch["input_ids"]
labels = batch["labels"]
attention_mask = batch.get("attention_mask")
n_active = int((labels[:, 1:] != -100).sum().item())
datums = batch_to_datums_sft(input_ids, labels, attention_mask)
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
send_datums = datums_to_tinker(datums)
future = tc.forward_backward(
send_datums,
loss_fn=args.loss_fn,
loss_fn_config=args.loss_fn_config,
)
return future, n_active
def _do_optim_step(self):
"""Send optimizer step to remote using axolotl's training params."""
import tinker.types as tt
tc = self._get_training_client()
return tc.optim_step(tt.AdamParams(**self._optim_params))
def train(
self,
resume_from_checkpoint: Optional[str] = None,
trial: Any = None,
ignore_keys_for_eval: Optional[list[str]] = None,
**kwargs,
) -> TrainOutput:
"""Main training loop — sends batches to remote API."""
args = self.hatchery_args
if args is None:
raise RuntimeError("hatchery_args not configured")
train_dataloader = self.get_train_dataloader()
num_batches = len(train_dataloader)
grad_accum = self.args.gradient_accumulation_steps
num_train_epochs = int(self.args.num_train_epochs)
steps_per_epoch = max(num_batches // grad_accum, 1)
max_steps = (
self.args.max_steps
if self.args.max_steps > 0
else steps_per_epoch * num_train_epochs
)
LOG.info(
f"Remote training: {num_batches} batches/epoch, "
f"{grad_accum} grad_accum, {max_steps} max steps, "
f"{num_train_epochs} epochs"
)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = True
self.state.is_world_process_zero = True
self.control = self.callback_handler.on_train_begin(
self.args,
self.state,
self.control, # type: ignore[has-type]
)
global_step = 0
total_loss = 0.0
start_time = time.time()
for _epoch in range(num_train_epochs):
if global_step >= max_steps:
break
self.control = self.callback_handler.on_epoch_begin(
self.args, self.state, self.control
)
pending_fb_futures = []
accum_count = 0
for batch_idx, batch in enumerate(train_dataloader):
if global_step >= max_steps:
break
self.control = self.callback_handler.on_step_begin(
self.args, self.state, self.control
)
fb_future, n_active = self._send_batch(batch)
pending_fb_futures.append((fb_future, n_active))
accum_count += 1
if accum_count >= grad_accum:
step_loss_sum = 0.0
step_active = 0
for fut, n_act in pending_fb_futures:
result = fut.result(timeout=args.future_timeout)
step_loss_sum += _extract_loss(result)
step_active += n_act
optim_future = self._do_optim_step()
if not args.pipeline:
optim_future.result(timeout=args.future_timeout)
step_loss = (
step_loss_sum / step_active
if step_active > 0
else step_loss_sum
)
global_step += 1
total_loss += step_loss
self.state.global_step = global_step
self.state.epoch = _epoch + (batch_idx + 1) / num_batches
log_interval = self.args.logging_steps or 1
if global_step % log_interval == 0:
elapsed = time.time() - start_time
avg_loss = total_loss / global_step
LOG.info(
f"[step {global_step}/{max_steps}] "
f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} "
f"active={step_active} "
f"{elapsed / global_step:.2f}s/step"
)
self.log(
{
"loss": step_loss,
"learning_rate": self._optim_params["learning_rate"],
"epoch": self.state.epoch,
}
)
if args.save_steps and global_step % args.save_steps == 0:
self._save_remote_checkpoint(global_step)
self.control = self.callback_handler.on_step_end(
self.args, self.state, self.control
)
pending_fb_futures = []
accum_count = 0
if self.control.should_training_stop:
break
self.control = self.callback_handler.on_epoch_end(
self.args, self.state, self.control
)
if self.control.should_training_stop:
break
if global_step > 0:
self._save_remote_checkpoint(global_step, name="final")
elapsed = time.time() - start_time
avg_loss = total_loss / max(global_step, 1)
LOG.info(
f"Training complete: {global_step} steps, {elapsed:.1f}s total, "
f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}"
)
self.control = self.callback_handler.on_train_end(
self.args, self.state, self.control
)
return TrainOutput(
global_step=global_step,
training_loss=avg_loss,
metrics={"train_loss": avg_loss, "train_runtime": elapsed},
)
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
"""Save a checkpoint on the remote service."""
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
try:
future = tc.save_state(ckpt_name)
future.result(timeout=args.future_timeout)
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
except Exception:
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
if name == "final":
raise
def save_model(self, output_dir=None, _internal_call=False):
"""Delegate to remote checkpoint save so HF callbacks create checkpoints."""
self._save_remote_checkpoint(
step=self.state.global_step,
name=output_dir or "hf-save",
)
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
raise NotImplementedError(
"HatcheryTrainer uses remote API; compute_loss should not be called."
)

View File

@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0
torch_compile: True # recommended to reduce vram
torch_compile: True # torch>=2.6.0, recommended to reduce vram
datasets:
- path: ...

View File

@@ -110,36 +110,11 @@ class NemoGymDataProducer(GRPODataProducer):
item["agent_ref"] = full_item["agent_ref"]
dataset_items.append(item)
# NOTE: do NOT re-expand by num_generations here.
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
# yields ``num_generations`` consecutive copies of each unique
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
# num_generations)`` items — one entry per rollout. Expanding
# again here would fire ``num_generations^2`` rollouts per
# prompt per rank and make every step dogpile on a handful of
# tasks.
expanded_items = dataset_items
# Diagnostic: log what this rank is about to fire.
try:
import collections
iid_counts: collections.Counter[str | None] = collections.Counter()
for it in dataset_items:
iid_counts[
(it.get("responses_create_params", {}).get("metadata") or {}).get(
"instance_id"
)
] += 1
LOG.info(
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
trainer.accelerator.process_index,
len(dataset_items),
len(iid_counts),
list(iid_counts.most_common(5)),
)
except Exception:
pass
# Expand by num_generations (agent produces one rollout per call)
expanded_items = []
for item in dataset_items:
for _ in range(self._num_generations):
expanded_items.append(item)
# Call NeMo Gym agents
loop = asyncio.new_event_loop()
@@ -165,7 +140,6 @@ class NemoGymDataProducer(GRPODataProducer):
logprobs_list = []
rewards_list = []
num_turns_list: list[int] = []
for resp in responses:
parsed = _parse_agent_response(resp, eos_token_id)
prompt_ids_list.append(parsed["prompt_ids"])
@@ -173,7 +147,6 @@ class NemoGymDataProducer(GRPODataProducer):
env_mask_list.append(parsed["env_mask"])
logprobs_list.append(parsed["logprobs"])
rewards_list.append(parsed["reward"])
num_turns_list.append(parsed.get("num_turns", 0))
# Pad to tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -206,48 +179,22 @@ class NemoGymDataProducer(GRPODataProducer):
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
# Inject per-rollout reward + num_turns into each input. Since
# ``RepeatSampler`` already yields ``num_generations`` copies of
# each prompt, ``inputs`` has ONE entry per rollout (matching
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
# GRPO advantage normalization is the trainer's job downstream.
assert len(inputs) == len(rewards_list), (
f"rewards/inputs length mismatch: "
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
)
# Inject rewards into inputs so _compute_deferred_scores can use them
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
# Our passthrough reward_fn reads "env_reward" from kwargs.
for i, inp in enumerate(inputs):
inp["env_reward"] = rewards_list[i]
inp["num_turns"] = num_turns_list[i]
# Each input gets rewards for its num_generations rollouts
start = i * self._num_generations
end = start + self._num_generations
inp["env_reward"] = rewards_list[start:end]
# One expanded_input per rollout (already correct count because
# inputs has num_generations copies baked in by the sampler).
expanded_inputs = [dict(inp) for inp in inputs]
# Log rollout-level stats to wandb from rank 0. These are the
# true agent-side metrics (not the tokenized TRL view) — so
# num_turns reflects how many /run iterations each rollout
# actually took before finishing or hitting max_turns.
if is_main and num_turns_list:
try:
import wandb
if wandb.run is not None:
import statistics as _stats
nonzero = sum(1 for r in rewards_list if r > 0)
log_payload = {
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
"rollout/num_turns/min": float(min(num_turns_list)),
"rollout/num_turns/max": float(max(num_turns_list)),
"rollout/reward/mean": float(_stats.mean(rewards_list)),
"rollout/reward/nonzero_frac": (
nonzero / len(rewards_list) if rewards_list else 0.0
),
"rollout/n_samples": float(len(rewards_list)),
}
wandb.log(log_payload, commit=False)
except Exception as exc: # never let metric logging break training
LOG.warning("rollout wandb log failed: %s", exc)
# Expand inputs to match expanded rollouts (num_generations copies)
expanded_inputs = []
for inp in inputs:
for g in range(self._num_generations):
expanded_inp = dict(inp)
expanded_inp["env_reward"] = inp["env_reward"][g]
expanded_inputs.append(expanded_inp)
# Decode completions for reward functions
completions = trainer.processing_class.batch_decode(

View File

@@ -19,7 +19,6 @@ Supports two modes:
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Union
from axolotl.integrations.base import BasePlugin
@@ -31,107 +30,6 @@ if TYPE_CHECKING:
LOG = get_logger(__name__)
# ---- vLLM weight-sync transport probe ------------------------------------
@dataclass
class VLLMWeightSyncCapabilities:
"""What weight-sync routes a vLLM server actually exposes.
Discovered once at ``pre_model_load`` time by fetching the server's
``/openapi.json``. Drives the transport-selection table below.
"""
nccl: bool = False # /init_communicator/ + /update_named_param/
lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native)
lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension)
http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension)
probed: bool = False
probe_error: str | None = None
routes: list[str] = field(default_factory=list)
@property
def any_full_param_sync(self) -> bool:
"""True if at least one transport can push full-model weights."""
return self.nccl or self.http_full
@property
def any_lora_sync(self) -> bool:
"""True if at least one transport can push LoRA adapters."""
return self.lora_filesystem or self.lora_axolotl or self.nccl
def probe_vllm_weight_sync(
base_url: str, timeout: float = 5.0
) -> VLLMWeightSyncCapabilities:
"""Detect which weight-sync routes the configured vLLM server exposes.
Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
we care about is mounted as a POST route there. Falls back to all-False
on any error so the caller can still decide what to do (typically: raise
a clear error rather than silently no-op).
"""
import requests
caps = VLLMWeightSyncCapabilities()
try:
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
r.raise_for_status()
spec = r.json()
routes = sorted((spec.get("paths") or {}).keys())
caps.routes = routes
caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
caps.lora_filesystem = "/v1/load_lora_adapter" in routes
caps.lora_axolotl = "/set_lora_adapter/" in routes
caps.http_full = "/http_update_weights/" in routes
caps.probed = True
except Exception as exc:
caps.probe_error = f"{type(exc).__name__}: {exc}"
LOG.warning(
"NeMo Gym: failed to probe vLLM /openapi.json at %s%s. "
"Will fall back to LoRA-only behavior.",
base_url,
caps.probe_error,
)
return caps
def select_weight_sync_transport(
caps: VLLMWeightSyncCapabilities,
*,
has_lora: bool,
vllm_lora_sync_pref: bool,
) -> str:
"""Pick the right transport for a (server caps, model type) combo.
Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
``"none"``. The caller decides what to do with ``"none"`` (typically:
raise an error explaining the misconfiguration).
Selection table:
LoRA model + lora endpoint + lora-sync pref → lora_filesystem
LoRA model + lora endpoint → lora_filesystem
LoRA model + nccl endpoint → nccl (broadcast merged adapter)
Full model + nccl endpoint → nccl
Full model + http endpoint → http_full
anything else → none
"""
if has_lora:
if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
return "lora_filesystem"
if caps.lora_filesystem or caps.lora_axolotl:
return "lora_filesystem"
if caps.nccl:
return "nccl"
return "none"
# Full-parameter model
if caps.nccl:
return "nccl"
if caps.http_full:
return "http_full"
return "none"
class NemoGymPlugin(BasePlugin):
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
@@ -152,69 +50,37 @@ class NemoGymPlugin(BasePlugin):
self._reward_fn = None
self._dataset_lookup = None
self._agent_servers = {}
self._vllm_caps: VLLMWeightSyncCapabilities | None = None
def get_input_args(self):
return "axolotl.integrations.nemo_gym.NemoGymArgs"
def pre_model_load(self, cfg):
"""Probe vLLM weight-sync routes and conditionally bypass NCCL init.
Replaces the previous unconditional ``init_communicator`` monkey-patch
with a probe of the configured vLLM server's ``/openapi.json``. We only
bypass NCCL init when the server we're talking to actually lacks the
``/init_communicator/`` route (i.e. stock ``vllm serve``); against
TRL/axolotl serve modules that DO expose NCCL routes, we leave the
standard TRL flow alone so full-finetune training can sync weights.
"""
"""Apply monkeypatches before trainer creation."""
if not cfg.nemo_gym_enabled:
return
# Always skip NCCL communicator init in NeMo Gym mode.
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
trl_cfg = getattr(cfg, "trl", None)
if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
return
host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
port = getattr(trl_cfg, "vllm_server_port", None) or 8000
base_url = f"http://{host}:{port}"
self._vllm_caps = probe_vllm_weight_sync(base_url)
if self._vllm_caps.probed:
LOG.info(
"NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
"lora_axolotl=%s http_full=%s",
base_url,
self._vllm_caps.nccl,
self._vllm_caps.lora_filesystem,
self._vllm_caps.lora_axolotl,
self._vllm_caps.http_full,
)
# Only bypass NCCL init when the server doesn't speak it. If NCCL is
# available we leave VLLMClient.init_communicator alone so the
# standard TRL sync flow can run for full-parameter training.
if not self._vllm_caps.nccl:
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
self._patch_skip_nccl_init()
def _patch_skip_nccl_init(self):
"""Monkeypatch VLLMClient.init_communicator to no-op.
Only called when the configured vLLM server doesn't expose
``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
TRL's standard ``init_communicator`` would 404 inside trainer
construction; we no-op it so the LoRA filesystem path can install
its own sync in ``post_trainer_create``.
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
serve script). The NCCL communicator is not needed and fails with both
vLLM V1 engine and standard OpenAI server mode.
"""
try:
from trl.generation.vllm_client import VLLMClient
VLLMClient._original_init_communicator = VLLMClient.init_communicator
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
"Skipping NCCL init_communicator (server has no /init_communicator/)"
)
LOG.info(
"Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
"Skipping NCCL init_communicator (LoRA sync mode)"
)
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
except Exception as exc:
LOG.warning(f"Failed to patch VLLMClient: {exc}")
@@ -368,80 +234,30 @@ class NemoGymPlugin(BasePlugin):
verify_timeout = cfg.nemo_gym_verify_timeout or 30
multi_turn = cfg.nemo_gym_multi_turn or False
# Pick a weight-sync transport based on what the configured vLLM
# server actually exposes (see ``pre_model_load`` probe) and what
# kind of model we're training. The selection table is documented
# in ``select_weight_sync_transport``.
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
# - Install LoRA sync (when vllm_lora_sync=True)
# - Or no-op sync_weights (when using standard vLLM server)
trl_cfg = getattr(cfg, "trl", None)
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
vllm_gen = trainer.vllm_generation
adapter = getattr(cfg, "adapter", None)
has_lora = adapter in ("lora", "qlora")
vllm_lora_sync_pref = bool(
trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
)
caps = self._vllm_caps or VLLMWeightSyncCapabilities()
transport = select_weight_sync_transport(
caps,
has_lora=has_lora,
vllm_lora_sync_pref=vllm_lora_sync_pref,
)
if transport == "lora_filesystem":
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
self._setup_lora_sync(trainer)
# Verify the vLLM server supports runtime LoRA loading
self._check_lora_endpoint(vllm_gen)
LOG.info("NeMo Gym weight sync: LoRA filesystem")
elif transport == "nccl":
# Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
# alone (pre_model_load only patched it when the probe found no
# NCCL route) so the trainer's normal weight-sync flow runs.
LOG.info(
"NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
else:
# No NCCL, no LoRA sync — skip all weight sync paths
vllm_gen.sync_weights = lambda: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
)
elif transport == "http_full":
# Full-parameter HTTP sync — implementation lands in step 3.
# For now, fail loudly so users know the path is detected but
# not yet wired up, instead of silently no-oping like before.
raise NotImplementedError(
"NeMo Gym + full fine-tune + HTTP weight sync is detected "
"but the client-side sync helper is not yet implemented "
"(planned). Use `adapter: lora|qlora` for now, or use a "
"vLLM serve module that exposes /init_communicator/ for "
"NCCL sync."
type(vllm_gen).sync_weights = lambda self: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
)
else: # transport == "none"
# No viable sync path. Build a precise error so the user knows
# exactly what's missing and how to fix it.
if not caps.probed:
msg = (
"could not probe the vLLM server's "
f"/openapi.json: {caps.probe_error}. "
"Verify that vLLM is reachable at "
f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
# Also patch the async trainer's internal sync method
if hasattr(trainer, "_maybe_sync_vllm_weights"):
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
"Async weight sync skipped (NeMo Gym mode)"
)
elif has_lora:
msg = (
"the vLLM server has neither NCCL routes "
"(/init_communicator/) nor a LoRA-loading route "
"(/v1/load_lora_adapter or /set_lora_adapter/). "
"Restart vLLM with `--enable-lora --max-lora-rank N "
"VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
"server, or use `axolotl vllm-serve` for the "
"NCCL-capable serve module."
)
else:
msg = (
"the vLLM server exposes no full-parameter sync route "
"(/init_communicator/ for NCCL or /http_update_weights/ "
"for HTTP). Use `axolotl vllm-serve` (which has both) "
"or set `adapter: lora|qlora`."
)
raise ValueError(
f"NeMo Gym: no usable weight-sync transport — {msg} Without "
"weight sync the trainer's gradient updates never reach the "
"rollout policy (functionally a no-op trainer)."
)
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
if multi_turn:
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)

View File

@@ -130,41 +130,21 @@ def start_servers(
)
def get_server_configs(head_port: int = 11000, timeout: float = 30.0) -> dict:
def get_server_configs(head_port: int = 11000) -> dict:
"""Fetch the global config from the NeMo Gym head server.
Retries up to 3 times with exponential backoff. The default per-attempt
timeout is 30s (raised from the original 5s) because head servers can
be slow to respond when they're concurrently serving rollouts from a
prior training run. A 5s timeout was empirically too tight to survive
a kill-and-relaunch cycle.
Returns:
Dict mapping server_name -> server config.
"""
url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"
last_exc: Exception | None = None
for attempt in (1, 2, 3):
try:
response = requests.get(url, timeout=timeout)
response.raise_for_status()
result = yaml.safe_load(response.text)
# NeMo Gym head server double-encodes: YAML string inside a YAML string
if isinstance(result, str):
result = yaml.safe_load(result)
return result
except (requests.exceptions.RequestException, OSError) as exc:
last_exc = exc
LOG.warning(
"NeMo Gym head probe attempt %d/3 failed: %s. Retrying...",
attempt,
type(exc).__name__,
)
if attempt < 3:
time.sleep(2.0 * attempt)
raise RuntimeError(
f"NeMo Gym head server at {url} did not respond after 3 attempts: {last_exc}"
response = requests.get(
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
)
response.raise_for_status()
result = yaml.safe_load(response.text)
# NeMo Gym head server double-encodes: YAML string inside a YAML string
if isinstance(result, str):
result = yaml.safe_load(result)
return result
def get_agent_servers(

View File

@@ -53,7 +53,6 @@ def _rms_norm_rope_forward_kernel(
RSTD_ptr,
RSTD_row_stride,
n_cols,
n_rot,
n_heads,
eps,
HAS_WEIGHT: tl.constexpr,
@@ -61,35 +60,28 @@ def _rms_norm_rope_forward_kernel(
):
"""
Fused forward:
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
y[..., :n_rot] = rope(x_norm[..., :n_rot])
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
x_norm = x / rms(x) [* weight] (RMSNorm)
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
rotate_half swaps first/second halves and negates the first, restricted
to the rotary span [0, n_rot):
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
For the partial-rotary pass-through region we load cos with default 1.0
and sin with default 0.0 outside [0, n_rot), so the same formula
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
rotate_half swaps first/second halves and negates the first:
rotate_half([a, b]) = [-b, a]
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
"""
row_idx = tl.program_id(0).to(tl.int64)
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
cs_row_idx = row_idx // n_heads
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
rot_mask_col = col_offsets < n_rot
half_rot = n_rot // 2
half_dim = n_cols // 2
# Load input row
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
# RMSNorm: compute 1/rms
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
@@ -102,38 +94,33 @@ def _rms_norm_rope_forward_kernel(
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
X_norm = X_norm * W_row
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
# cos=1, sin=0 so the formula leaves X_norm untouched.
# RoPE: load cos/sin (broadcast across heads)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
mask=rot_mask_col,
other=1.0,
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
sin_row = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
mask=rot_mask_col,
other=0.0,
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# rotate_half within [0, n_rot):
# for col < half_rot: take -X_norm[col + half_rot]
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
# for col >= half_dim, take X_norm[col - half_dim]
rot_offsets = tl.where(
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
rot_mask = rot_offsets < n_cols
X_rot = tl.load(
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
).to(tl.float32)
# Re-normalize the rotated values
X_rot_norm = X_rot * rstd
if HAS_WEIGHT:
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
tl.float32
)
X_rot_norm = X_rot_norm * W_rot
# Negate the first half (rotate_half negates x2, which becomes the first half)
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
X_rot_norm = X_rot_norm * sign
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
@@ -166,21 +153,13 @@ def _rms_norm_rope_backward_kernel(
dW_row_stride,
n_rows,
n_cols,
n_rot,
n_heads,
rows_per_program,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
(`n_rot <= n_cols`).
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
output is just the normalized row, so dN[col] = dY[col] (achieved by
loading cos with default 1.0 and forcing the rotate-half contribution
to zero outside the rotary span).
Backward for Y = RoPE(RMSNorm(X, W))
cos/sin indexed by row_idx // n_heads for per-head broadcast.
"""
row_block_id = tl.program_id(0).to(tl.int64)
@@ -188,8 +167,7 @@ def _rms_norm_rope_backward_kernel(
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
rot_mask_col = col_offsets < n_rot
half_rot = n_rot // 2
half_dim = n_cols // 2
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
@@ -208,37 +186,33 @@ def _rms_norm_rope_backward_kernel(
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
mask=rot_mask_col,
other=1.0,
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
# dN = dY * cos + rotate_half^T(dY * sin)
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
#
# For col >= n_rot the formula must collapse to dN = dY (since the
# forward is just a pass-through). cos defaults to 1.0 above; the
# rotate-half contribution is masked to zero below.
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
# This is equivalent to rotating (dY * sin) because the rotation
# just permutes which elements are multiplied.
rot_offsets = tl.where(
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
rot_mask = rot_offsets < n_cols
dY_rot = tl.load(
dY_ptr + row_idx * dY_row_stride + rot_offsets,
mask=rot_load_mask,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
sin_rot = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
mask=rot_load_mask,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
rotate_term = dY_rot * sin_rot * adj_sign
# Zero out rotate-half contribution outside the rotary span.
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
dN = dY_row * cos_row + rotate_term
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
# Pre-weight normalized: n = rstd * x
n = X_row * rstd
@@ -267,17 +241,15 @@ def _rms_norm_rope_backward_kernel(
)
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
"""
Args:
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
W: (head_dim,) or None — RMSNorm weight
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
eps: float
n_heads: int — number of attention heads (for cos/sin indexing)
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
partial rotary). Must be even and ``<= head_dim``.
Returns:
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
"""
@@ -301,7 +273,6 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
RSTD,
RSTD.stride(0),
n_cols,
n_rot,
n_heads,
eps,
HAS_WEIGHT=has_weight,
@@ -311,9 +282,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
return Y, X, RSTD, BLOCK_SIZE, num_warps
def rms_norm_rope_backward(
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
):
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
n_rows, n_cols = dY.shape
has_weight = W is not None
@@ -346,7 +315,6 @@ def rms_norm_rope_backward(
_dW.stride(0),
n_rows,
n_cols,
n_rot,
n_heads,
rows_per_program,
HAS_WEIGHT=has_weight,
@@ -361,14 +329,13 @@ def rms_norm_rope_backward(
class FusedRMSNormRoPEFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
def forward(ctx, X, W, cos, sin, eps, n_heads):
"""
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, n_rot) — broadcast across heads
sin: (B*S, n_rot) — broadcast across heads
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, head_dim) — broadcast across heads
sin: (B*S, head_dim) — broadcast across heads
n_heads: int
n_rot: int — rotary dim (<= head_dim)
"""
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
X,
@@ -377,13 +344,11 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
sin,
eps,
n_heads,
n_rot,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_heads = n_heads
ctx.n_rot = n_rot
ctx.has_weight = W is not None
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
return Y
@@ -400,26 +365,21 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
sin,
RSTD,
ctx.n_heads,
ctx.n_rot,
ctx.BLOCK_SIZE,
ctx.num_warps,
)
return dX, dW, None, None, None, None, None
return dX, dW, None, None, None, None
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
"""
Apply fused RMSNorm + (partial) RoPE.
Apply fused RMSNorm + RoPE.
Args:
x: (batch, seq_len, num_heads, head_dim) — after projection + view
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
must be even and ``<= head_dim``. When ``n_rot < head_dim``
the trailing ``head_dim - n_rot`` columns are RMSNorm-only
(partial-rotary pass-through), matching stock Gemma 4 with
``partial_rotary_factor < 1.0``.
sin: (batch, seq_len, n_rot) — same shape as ``cos``
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
eps: float — RMSNorm epsilon
Returns:
@@ -427,38 +387,14 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
"""
shape = x.shape # (B, S, H, D)
B, S, H, D = shape
n_rot = cos.shape[-1]
if sin.shape[-1] != n_rot:
raise ValueError(
f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
f"sin={sin.shape[-1]}"
)
if n_rot > D:
raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
if n_rot % 2 != 0:
raise ValueError(f"rotary dim must be even, got {n_rot}")
# Flatten to 2D: (B*S*H, D)
x_flat = x.reshape(-1, D).contiguous()
# cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
# all sequences share the same rotary positions). The kernel needs a
# dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
# onto a single (b, s) pair, so expand-then-contiguous to materialize
# the per-batch broadcast. Expand is a no-op when B == cos.shape[0].
if cos.shape[0] != B:
if cos.shape[0] != 1:
raise ValueError(
f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
f"to x batch dim ({B})"
)
cos = cos.expand(B, S, n_rot)
sin = sin.expand(B, S, n_rot)
cos_flat = cos.reshape(B * S, n_rot).contiguous()
sin_flat = sin.reshape(B * S, n_rot).contiguous()
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
# by dividing the row_idx by H to get the cos/sin row
cos_flat = cos.reshape(B * S, D).contiguous()
sin_flat = sin.reshape(B * S, D).contiguous()
y_flat = FusedRMSNormRoPEFunction.apply(
x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
)
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
return y_flat.view(shape)

View File

@@ -156,14 +156,6 @@ class PatchManager:
# which would clobber any earlier fix.
self._fix_nemotron_h_conversion_mapping()
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
# override needs to walk the raw model tree, which is broken by
# the post-load PEFT wrapping. The accompanying
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and
# installation-time-independent, so both halves of the fix live
# cleanly in the same call even though one is instance-scoped
# and the other is module-scoped.
self._apply_gemma_hybrid_attention(model)
self._finalize_moe_expert_quantization(model)
@@ -180,23 +172,12 @@ class PatchManager:
which exceeds flash attention's supported size. This patch loads the model
with flash_attention_2 for the sliding window layers (head_dim=256), then
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
which fixes the corresponding mask construction inside
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
override is not enough — the forward still builds a 2D FA2-format mask
at the model level and the SDPA layers crash at long context lengths
with ``RuntimeError: The expanded size of the tensor ... must match``.
"""
if not self.cfg.gemma4_hybrid_attn_impl:
return
import copy
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
# Navigate to the module that has 'layers' - varies by model structure:
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
@@ -410,14 +391,6 @@ class PatchManager:
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
# The fused attn path is now compatible with
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
# rotary (cos.shape[-1] < head_dim) and the fused forward
# mirrors the current ``Gemma4TextAttention.forward`` API
# for shared kv (read from / write to
# ``past_key_values.shared_layers``). See
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
# for the history.
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)

View File

@@ -23,8 +23,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.processor_kwargs:
processor_kwargs.update(cfg.processor_kwargs)
if cfg.tokenizer_use_mistral_common:

View File

@@ -1,115 +0,0 @@
"""Hybrid attention mask fix for Gemma 4.
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
``self_attn.config``, leaving sliding-window layers on FA2.
The per-layer config override alone is insufficient, however:
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
using the **model-level** config and passes the mapped mask to each
decoder layer. With FA2 still set at the model level, the ``full_attention``
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
The global layers then fail with::
RuntimeError: The expanded size of the tensor (S) must match the existing
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
sizes: [B, S]
...when the sequence length grows past roughly 7k tokens.
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
the original in ``masking_utils``. The wrapper forces
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
alone, so sliding-window layers continue to receive FA2-format masks.
The patch is idempotent. Install once per process, before any Gemma 4
forward pass runs.
"""
from __future__ import annotations
import copy
from typing import Any
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
_PATCH_APPLIED = False
def patch_gemma4_hybrid_mask() -> bool:
"""Install the Gemma 4 hybrid-attention mask fix.
Returns ``True`` if the patch was installed (or was already installed),
``False`` if the target module could not be imported (e.g. transformers
version predates Gemma 4) — in which case nothing is done and the
caller can continue unaffected.
"""
global _PATCH_APPLIED
if _PATCH_APPLIED:
return True
try:
from transformers.models.gemma4 import modeling_gemma4
except ImportError:
LOG.debug(
"gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
"skipping. This is fine for non-Gemma4 training."
)
return False
if not hasattr(modeling_gemma4, "create_causal_mask"):
LOG.warning(
"gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
"binding, skipping. Transformers API may have changed."
)
return False
original = modeling_gemma4.create_causal_mask
def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
"""Wrapper that forces SDPA format for the full-attention mask.
The global layers were patched to SDPA by
``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
original ``create_causal_mask`` dispatches on
``config._attn_implementation``; we shadow that with a local
override.
"""
sdpa_config = copy.copy(config)
sdpa_config._attn_implementation = "sdpa"
return original(sdpa_config, *args, **kwargs)
# Preserve the original reference on the wrapper for tests / teardown.
hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined]
modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
_PATCH_APPLIED = True
LOG.info(
"gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
"force SDPA-format masks for full-attention layers"
)
return True
def unpatch_gemma4_hybrid_mask() -> None:
"""Restore the original ``create_causal_mask``. Useful for tests."""
global _PATCH_APPLIED
if not _PATCH_APPLIED:
return
try:
from transformers.models.gemma4 import modeling_gemma4
except ImportError:
_PATCH_APPLIED = False
return
current = modeling_gemma4.create_causal_mask
original = getattr(current, "_axolotl_original", None)
if original is not None:
modeling_gemma4.create_causal_mask = original
_PATCH_APPLIED = False

View File

@@ -24,15 +24,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
# language-side module is separated from the vision tower. Try
# both names before giving up.
mlp_cls = getattr(
module,
f"{model_cls_prefix}MLP",
None,
) or getattr(module, f"{model_cls_prefix}TextMLP")
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
if use_original_mlp:
mlp_forward = mlp_cls.forward

View File

@@ -407,10 +407,7 @@ def selective_log_softmax(logits, index) -> torch.Tensor:
K = index.shape[-1]
original_index_shape = index.shape
try:
flat_logits = logits.view(-1, V)
except RuntimeError:
flat_logits = logits.reshape(-1, V).contiguous()
flat_logits = logits.reshape(-1, V).contiguous()
flat_index = index.reshape(-1, K).contiguous()
BLOCK_V = 4096

View File

@@ -394,8 +394,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
try:
return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
return all(isinstance(v, list) for v in prompt.values()) and all(
isinstance(v, list) for v in prompt[self.prompter.field_messages]
)
except KeyError:
return False
@@ -1004,13 +1004,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if tools is None:
return None
# Some datasets have tools set to str
if isinstance(tools, str):
try:
tools = json.loads(tools)
except json.JSONDecodeError as e:
LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
raise
if isinstance(tools, list):
# Process each tool to handle JSON string parameters
for tool in tools:
@@ -1041,22 +1034,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if messages is None:
raise ValueError("Messages is null. Please check `field_messages`.")
if isinstance(messages, str):
try:
messages = json.loads(messages)
except json.JSONDecodeError as e:
LOG.error(f"Error parsing messages as JSON. Error: {e}")
raise
assert isinstance(messages, list), (
f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
)
# Extra check here to make sure decoded json is a list of dicts.
for i, message in enumerate(messages):
assert isinstance(message, dict), (
f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
)
if isinstance(messages, list):
return messages

View File

@@ -320,15 +320,6 @@ def main(script_args: ScriptArguments):
# --- Active LoRA state (shared across endpoints via closure) ---
active_lora: dict = {"request": None}
# Serializes access to the worker pipe. The underlying
# multiprocessing.Connection is a single full-duplex stream shared
# across all HTTP handlers; concurrent requests interleave bytes on
# the wire and corrupt the pickle framing (seen as
# ``UnpicklingError: pickle data was truncated``). Any endpoint that
# does ``conn.send(...); conn.recv()`` MUST hold this lock across
# the round-trip so only one inflight call at a time per pipe.
worker_pipe_lock = asyncio.Lock()
# ------------------------------------------------------------------
# LoRA-specific endpoints
# ------------------------------------------------------------------
@@ -640,150 +631,6 @@ def main(script_args: ScriptArguments):
},
}
@app.post("/v1/completions")
async def openai_completions(request_body: dict):
"""OpenAI-compatible text-completions endpoint.
Accepts either a string ``prompt`` or a list-of-int
``prompt_token_ids`` (as the text-completions spec allows). Routes
to the internal vLLM generate method with the active LoRA adapter
and returns an OpenAI /v1/completions-shaped response including
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
``generation_log_probs`` for NeMo Gym agents that need raw
tokens + logprobs.
"""
import uuid
prompt_raw = request_body.get("prompt")
temperature = request_body.get("temperature", 1.0)
max_tokens = request_body.get("max_tokens", 512)
top_p = request_body.get("top_p", 1.0)
n = request_body.get("n", 1)
logprobs = request_body.get("logprobs") or 0
stop_token_ids = request_body.get("stop_token_ids") or None
# Accept either a string or a list[int] token id prompt. Lists
# must contain ints only (raise on lists of strings so callers get
# a clear error). Also accept [[int, int, ...]] nesting for the
# rare case callers pass a single-prompt batch.
if (
isinstance(prompt_raw, list)
and prompt_raw
and isinstance(prompt_raw[0], list)
):
prompt_raw = prompt_raw[0]
prompt_dict: dict[str, Any] = {}
if isinstance(prompt_raw, list):
prompt_dict = {"prompt_token_ids": prompt_raw}
elif isinstance(prompt_raw, str):
prompt_dict = {"prompt": prompt_raw}
else:
return {
"error": {
"message": ("prompt must be a string or a list of token ids"),
"type": "invalid_request",
}
}
generation_kwargs: dict[str, Any] = {
"n": n,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"logprobs": logprobs,
}
if stop_token_ids:
generation_kwargs["stop_token_ids"] = stop_token_ids
sampling_params = SamplingParams(
**{k: v for k, v in generation_kwargs.items() if v is not None}
)
chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
# Hold the pipe lock across send+recv — concurrent requests would
# otherwise interleave pickle frames on the worker connection.
async with worker_pipe_lock:
for conn, chunk in zip(connections, chunked, strict=True):
if not chunk:
chunk = [{"prompt": "<placeholder>"}]
kwargs = {
"prompts": chunk,
"sampling_params": sampling_params,
"lora_request": active_lora["request"],
}
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
loop = asyncio.get_running_loop()
all_outputs = await asyncio.gather(
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
)
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
for o in all_outputs:
if isinstance(o, dict) and "error" in o:
raise RuntimeError(f"vLLM worker error: {o['error']}")
all_outputs = list(chain.from_iterable(all_outputs))
if not all_outputs:
return {"choices": [], "model": script_args.model}
choices = []
for i, output in enumerate(all_outputs):
for j, out in enumerate(output.outputs):
text = out.text
# OpenAI-style `logprobs` block for text-completions:
# { "tokens": [...], "token_logprobs": [...] }
lp_block = None
if out.logprobs:
tokens_str: list[str] = []
token_lps: list[float] = []
for step in out.logprobs:
chosen = next(iter(step.values()))
tokens_str.append(getattr(chosen, "decoded_token", "") or "")
token_lps.append(float(chosen.logprob))
lp_block = {
"tokens": tokens_str,
"token_logprobs": token_lps,
}
choice = {
"index": i * n + j,
"text": text,
"finish_reason": "stop"
if out.finish_reason == "stop"
else "length",
"logprobs": lp_block,
# NeMo-Gym / retrace agent extras — preserved on the
# choice so callers with raw-token pipelines don't
# have to re-tokenize.
"prompt_token_ids": output.prompt_token_ids,
"generation_token_ids": list(out.token_ids),
"generation_log_probs": (
[float(next(iter(lp.values())).logprob) for lp in out.logprobs]
if out.logprobs
else []
),
}
choices.append(choice)
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
completion_tokens = sum(
len(out.token_ids) for o in all_outputs for out in o.outputs
)
return {
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
"object": "text_completion",
"model": script_args.model,
"choices": choices,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
@app.post("/init_communicator/")

View File

@@ -6,7 +6,6 @@ from .batching import (
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .dpo import AxolotlDPODataCollatorWithPadding
from .mamba import MambaDataCollator
__all__ = [
@@ -14,6 +13,5 @@ __all__ = [
"BatchSamplerDataCollatorForSeq2Seq",
"V2BatchSamplerDataCollatorForSeq2Seq",
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
"AxolotlDPODataCollatorWithPadding",
"MambaDataCollator",
]

View File

@@ -1,128 +0,0 @@
"""DPO/ORPO/IPO/KTO data collator with pad_to_multiple_of support.
Extends TRL's DPODataCollatorWithPadding to round padded sequence lengths
up to a fixed multiple. This stabilizes Triton autotune caches for kernels
that key on sequence length (e.g. fla's linear attention kernels used by
Qwen3.5), which otherwise re-autotune on every distinct batch length.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from torch.nn.utils.rnn import pad_sequence
from trl.experimental.utils import DPODataCollatorWithPadding
from trl.trainer.utils import pad
def _round_up(length: int, multiple: int) -> int:
return ((length + multiple - 1) // multiple) * multiple
@dataclass
class AxolotlDPODataCollatorWithPadding(DPODataCollatorWithPadding):
"""DPO data collator that pads to a multiple of ``pad_to_multiple_of``.
Args:
pad_token_id: Tokenizer pad token id (inherited).
is_encoder_decoder: Whether the model is encoder-decoder (inherited).
pad_to_multiple_of: If set, padded lengths are rounded up to this
multiple. Helps stabilize Triton autotune caches.
"""
pad_to_multiple_of: int | None = None
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
pad_to_mult = self.pad_to_multiple_of
padded_batch: dict[str, Any] = {}
for k in features[0].keys():
if k.endswith(
("_input_ids", "_attention_mask", "_labels", "_pixel_values")
):
if self.is_encoder_decoder:
if k.endswith("_pixel_values"):
to_pad = [
torch.tensor(ex[k], dtype=torch.float32) for ex in features
]
else:
to_pad = [torch.LongTensor(ex[k]) for ex in features]
if k.startswith("prompt") and k.endswith("input_ids"):
if self.pad_token_id is None:
raise ValueError(
"Padding is enabled, but the tokenizer is not configured with a padding token."
)
padding_value = self.pad_token_id
elif k.endswith("_attention_mask"):
padding_value = 0
elif k.endswith("_pixel_values"):
padding_value = 0
elif (
k.startswith(("chosen", "rejected", "completion"))
or "decoder" in k
):
padding_value = -100
else:
raise ValueError(f"Unexpected key in batch '{k}'")
padded = pad_sequence(
to_pad, batch_first=True, padding_value=padding_value
)
if pad_to_mult:
cur = padded.shape[1]
target = _round_up(cur, pad_to_mult)
if target > cur:
extra = target - cur
pad_shape = list(padded.shape)
pad_shape[1] = extra
filler = torch.full(
pad_shape,
padding_value,
dtype=padded.dtype,
device=padded.device,
)
padded = torch.cat([padded, filler], dim=1)
padded_batch[k] = padded
else:
if k.endswith("_input_ids"):
if self.pad_token_id is None:
raise ValueError(
"Padding is enabled, but the tokenizer is not configured with a padding token."
)
padding_value = self.pad_token_id
elif k.endswith("_labels"):
padding_value = -100
elif k.endswith("_attention_mask"):
padding_value = 0
elif k.endswith("_pixel_values"):
padding_value = 0
else:
raise ValueError(f"Unexpected key in batch '{k}'")
padding_side = (
"left"
if k in ("prompt_input_ids", "prompt_attention_mask")
else "right"
)
dtype = (
torch.float32 if k.endswith("_pixel_values") else torch.int64
)
to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
# trl.pad() natively supports pad_to_multiple_of
padded_batch[k] = pad(
to_pad,
padding_value=padding_value,
padding_side=padding_side,
pad_to_multiple_of=pad_to_mult,
)
elif k.endswith("_logps"):
padded_batch[k] = torch.tensor([ex[k] for ex in features])
else:
padded_batch[k] = [ex[k] for ex in features]
return padded_batch

View File

@@ -309,16 +309,6 @@ class AxolotlInputConfig(
dpo_padding_free: bool | None = None
dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "List of DPO losses to use."},
)
dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "Weights for each DPO loss."},
)
datasets: (
Annotated[
list[
@@ -673,12 +663,6 @@ class AxolotlInputConfig(
"description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
},
)
pad_to_multiple_of: int | None = Field(
default=None,
json_schema_extra={
"description": ("Pad each batch to a multiple of this value.")
},
)
curriculum_sampling: bool | None = Field(
default=None,
json_schema_extra={
@@ -1016,7 +1000,7 @@ class AxolotlInputConfig(
torch_compile: Literal["auto"] | bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use torch.compile and which backend to use."
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
},
)
torch_compile_backend: str | None = Field(

View File

@@ -64,12 +64,6 @@ class ModelInputConfig(BaseModel):
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
processor_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
},
)
tokenizer_save_jinja_files: bool | None = Field(
default=True, # match the default behavior from transformers
json_schema_extra={
@@ -113,22 +107,6 @@ class ModelInputConfig(BaseModel):
)
return trust_remote_code
@field_validator("processor_kwargs")
@classmethod
def reject_reserved_processor_kwargs(cls, processor_kwargs):
if not processor_kwargs:
return processor_kwargs
reserved = {"revision", "trust_remote_code"}
conflicts = reserved.intersection(processor_kwargs)
if conflicts:
raise ValueError(
"Do not set reserved keys "
f"{sorted(conflicts)} inside `processor_kwargs`; "
"use the top-level `revision_of_model` / `trust_remote_code` "
"config keys instead."
)
return processor_kwargs
class ModelOutputConfig(BaseModel):
"""model save configuration subset"""

View File

@@ -578,11 +578,6 @@ class TrainingValidationMixin:
"Setting chat_template is not supported with mistral-common tokenizer"
)
if data.get("processor_kwargs"):
raise ValueError(
"processor_kwargs is not supported with mistral-common tokenizer"
)
return data
@model_validator(mode="before")
@@ -765,122 +760,6 @@ class RLValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_dpo(cls, data):
dpo_loss_type = data.get("dpo_loss_type")
dpo_loss_weights = data.get("dpo_loss_weights")
rl = data.get("rl")
if rl == "ipo":
LOG.warning(
"rl: ipo will soon be deprecated. Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead."
)
if rl == "dpo":
if dpo_loss_weights is not None and dpo_loss_type is None:
raise ValueError(
"`dpo_loss_weights` requires `dpo_loss_type` to be set"
)
if (
dpo_loss_type is not None
and dpo_loss_weights is not None
and len(dpo_loss_type) != len(dpo_loss_weights)
):
raise ValueError(
f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
)
elif dpo_loss_type is not None or dpo_loss_weights is not None:
raise ValueError(
f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only,"
f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}"
)
return data
@model_validator(mode="before")
@classmethod
def check_grpo_batch_size_divisibility(cls, data):
"""Surface GRPO batch-shape mismatches at config-parse time.
TRL's GRPOTrainer requires that the per-step generation batch size be
evenly divisible by ``num_generations`` so that every prompt can be
replicated exactly ``num_generations`` times. The runtime check inside
``GRPOTrainer.__init__`` only fires after the model has been loaded —
too late and too cryptic for the user. We replicate the check here so
the failure is immediate and actionable.
Also enforces:
- ``num_generations >= 2`` (group-relative advantage needs variance)
- ``effective_gbs >= num_generations * world_size`` when capabilities
indicate multiple ranks (each rank needs at least one full group)
"""
if data.get("rl") != "grpo":
return data
trl_cfg = data.get("trl") or {}
num_gen = trl_cfg.get("num_generations")
if num_gen is None:
# TRL's own default is 8 — but if the user didn't set it, we
# don't have enough info to validate anything. Let TRL's own
# init handle the default-vs-batch interaction.
return data
if num_gen < 2:
raise ValueError(
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
"With num_generations=1, every group has zero advantage and "
"the policy never updates."
)
explicit_gbs = trl_cfg.get("generation_batch_size")
if explicit_gbs is not None:
effective_gbs = int(explicit_gbs)
gbs_source = "trl.generation_batch_size"
else:
mb = data.get("micro_batch_size") or 1
ga = data.get("gradient_accumulation_steps") or 1
effective_gbs = int(mb) * int(ga)
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
if effective_gbs % num_gen != 0:
# Suggest the smallest GA bump that fixes it for the common case
# where the user hasn't set generation_batch_size explicitly.
hint = ""
if explicit_gbs is None:
from math import gcd
mb_val = int(data.get("micro_batch_size") or 1)
# smallest GA such that mb*GA is a multiple of num_gen
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
suggested_ga = lcm // mb_val
hint = (
f" Smallest fix: set `gradient_accumulation_steps: "
f"{suggested_ga}` (so micro_batch_size * GA = "
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
)
raise ValueError(
f"GRPO: generation batch size must be divisible by "
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
)
# Multi-rank check: each rank must receive at least one full group
# per step. Without `capabilities` populated yet (mode='before'), we
# fall back to user-set distributed fields.
world_size = (
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
)
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
raise ValueError(
f"GRPO with world_size={world_size} requires effective_gbs "
f">= num_generations * world_size = {num_gen * world_size}, "
f"got {effective_gbs}. Increase gradient_accumulation_steps "
f"or micro_batch_size."
)
return data
class OptimizationValidationMixin:
"""Validation methods related to optimization and performance."""

View File

@@ -216,197 +216,5 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
self.assertIs(_trainer_module.validate_quantization_for_training, original)
class TestVllmLoraSyncPatch(unittest.TestCase):
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
These tests exercise the patch-installation branch in isolation. They build
a stub trainer with just enough attributes to look like
``AsyncGRPOTrainer`` for the duration of the relevant code path.
Background — there are two correct behaviors and we historically had a bug
where both modes used the same one:
- Async prefetch ON: the BG generation thread can't safely call
sync_weights mid-rollout. We no-op the stock hook and drive sync from
the main thread via ``_maybe_sync_vllm_weights``.
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
calls ``sync_weights`` once per step boundary on the main thread. We
wire that hook directly to ``_sync_lora_adapter`` because
``_maybe_sync_vllm_weights`` short-circuits when async is off.
Before the fix, both modes installed ``lambda: None``, so sync mode never
pushed any LoRA adapter to vLLM and the trainer was a no-op.
"""
@staticmethod
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
class FakeArgs:
pass
args = FakeArgs()
args.vllm_lora_sync = vllm_lora_sync
args.async_prefetch = async_prefetch
class FakeVllmGen:
sync_weights = staticmethod(lambda: None)
model = MagicMock()
# Use object.__new__ so we don't run __init__ (which needs a real
# model, dataset, etc.). We only need the `_generate_single_turn`
# method's patch branch to run, so we set up the minimum state.
trainer = object.__new__(AsyncGRPOTrainer)
trainer.args = args
trainer.use_vllm = True
trainer.vllm_generation = FakeVllmGen()
trainer._patched_sync_weights = False
# Spy on _sync_lora_adapter so we can assert it's the function the
# hook delegates to in sync mode.
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
trainer._sync_peft_weights_no_merge = MagicMock(
name="_sync_peft_weights_no_merge_spy"
)
return trainer
@staticmethod
def _run_patch_branch(trainer):
"""Execute just the sync_weights-patching branch in isolation.
We can't easily call the real ``_generate_single_turn`` because it
does a full vLLM generate. Instead we copy the exact branch out of
the source so the test verifies the same logic the trainer runs.
"""
if not getattr(trainer, "_patched_sync_weights", False):
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
if getattr(trainer.args, "vllm_lora_sync", False):
if getattr(trainer.args, "async_prefetch", False):
trainer.vllm_generation.sync_weights = lambda: None
else:
sync_helper = trainer._sync_lora_adapter
def _lora_filesystem_sync():
sync_helper()
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
trainer._patched_sync_weights = True
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
assert trainer._patched_sync_weights is True
# Trigger the patched hook — it must call _sync_lora_adapter.
trainer.vllm_generation.sync_weights()
trainer._sync_lora_adapter.assert_called_once()
def test_async_mode_with_lora_sync_installs_noop_hook(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
self._run_patch_branch(trainer)
assert trainer._patched_sync_weights is True
# Hook must be a no-op so BG-thread generation doesn't fight the
# main-thread optimizer step over the model weights.
trainer.vllm_generation.sync_weights()
trainer._sync_lora_adapter.assert_not_called()
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
"""Installing the patch should not pre-emptively sync."""
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
# _sync_lora_adapter should only be called when the patched hook
# itself is invoked (e.g., from TRL's _generate_single_turn).
trainer._sync_lora_adapter.assert_not_called()
def test_patch_is_idempotent(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
first_hook = trainer.vllm_generation.sync_weights
# Second call must not re-patch (otherwise we'd lose the original).
self._run_patch_branch(trainer)
assert trainer.vllm_generation.sync_weights is first_hook
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
on the very first call when ``vllm_sync_interval`` was ``None`` (which
is the default for any config that doesn't explicitly set it). We now
fall back to interval=1 so unset means "sync every step", matching the
behavior of TRL's own ``_generate_single_turn``.
"""
@staticmethod
def _make_stub_trainer(interval, async_prefetch):
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
class FakeArgs:
pass
args = FakeArgs()
args.async_prefetch = async_prefetch
args.vllm_sync_interval = interval
args.vllm_lora_sync = True
class FakeState:
global_step = 1
trainer = object.__new__(AsyncGRPOTrainer)
trainer.args = args
trainer.use_vllm = True
trainer.state = FakeState()
trainer._last_synced_step = 0
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
return trainer
def test_interval_none_in_async_mode_does_not_crash(self):
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
# Should not raise TypeError — defaults to every-step sync
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
def test_sync_mode_drives_sync(self):
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
The previous behavior (early return when ``not async_prefetch``)
assumed TRL's stock ``_generate_single_turn`` would handle sync.
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
where the data producer bypasses ``_generate_single_turn``
entirely. Without this trigger no sync ever happens and the
trainer becomes a no-op.
"""
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
def test_async_mode_with_explicit_interval_respects_modulo(self):
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_not_called()
# global_step=4 → 4 % 4 == 0 → sync
trainer.state.global_step = 4
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -96,8 +96,6 @@ def fixture_dpo_cfg(base_cfg):
"dpo_use_weighting": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
"dpo_loss_type": ["sigmoid", "sft"],
"dpo_loss_weights": [1.0, 0.5],
}
)
return cfg
@@ -166,8 +164,7 @@ def fixture_ipo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.DPO,
"dpo_loss_type": ["ipo"],
"rl": RLType.IPO,
"dpo_label_smoothing": 0,
"beta": 0.1,
}
@@ -303,8 +300,6 @@ class TestHFRLTrainerBuilder:
assert training_arguments.use_weighting is True
assert training_arguments.label_smoothing == 0.1
assert training_arguments.precompute_ref_log_probs is True
assert training_arguments.loss_type == ["sigmoid", "sft"]
assert training_arguments.loss_weights == [1.0, 0.5]
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)

View File

@@ -116,58 +116,6 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
def test_rpo(self, temp_dir):
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
# replaces loss_type="rpo", rpo_alpha=alpha.
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"dpo_loss_type": ["sigmoid", "sft"],
"dpo_loss_weights": [1.0, 1.0],
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
def test_kto_pair_lora(self, temp_dir):
@@ -233,8 +181,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"dpo_loss_type": ["ipo"],
"rl": "ipo",
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",

View File

@@ -361,329 +361,6 @@ class TestPluginDefaults(unittest.TestCase):
assert cfg.dataloader_num_workers == 0
class TestSelectWeightSyncTransport(unittest.TestCase):
"""Pure-logic table tests for ``select_weight_sync_transport``."""
def _caps(self, **kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_native_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_filesystem=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "lora_filesystem"
)
def test_lora_with_axolotl_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_axolotl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "lora_filesystem"
)
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "nccl"
)
def test_full_param_prefers_nccl(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True, http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "nccl"
)
def test_full_param_falls_back_to_http(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "http_full"
)
def test_full_param_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps() # all False
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "none"
)
def test_lora_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps()
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "none"
)
class TestProbeVllmWeightSync(unittest.TestCase):
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
def test_stock_vllm_with_lora_enabled(self):
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/v1/models": {"get": {}},
"/v1/load_lora_adapter": {"post": {}},
"/v1/unload_lora_adapter": {"post": {}},
"/v1/completions": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.lora_filesystem is True
assert caps.lora_axolotl is False
assert caps.nccl is False
assert caps.http_full is False
def test_axolotl_serve_lora_full_capabilities(self):
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/set_lora_adapter/": {"post": {}},
"/clear_lora_adapter/": {"post": {}},
"/http_update_weights/": {"post": {}},
"/v1/load_lora_adapter": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_axolotl is True
assert caps.lora_filesystem is True
assert caps.http_full is True
def test_trl_vllm_serve_nccl_only(self):
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/close_communicator/": {"post": {}},
"/generate/": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_filesystem is False
assert caps.lora_axolotl is False
assert caps.http_full is False
def test_unreachable_server_records_error(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
with patch("requests.get") as mock_get:
mock_get.side_effect = ConnectionError("Connection refused")
caps = probe_vllm_weight_sync("http://localhost:9999")
assert caps.probed is False
assert caps.probe_error is not None
assert "ConnectionError" in caps.probe_error
assert caps.nccl is False
assert caps.lora_filesystem is False
class TestPluginWeightSyncEnforcement(unittest.TestCase):
"""End-to-end test of post_trainer_create's transport-selection branch.
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
leaving the trainer learning in isolation while vLLM kept serving the
unmodified base model. After the fix:
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
- LoRA + only NCCL endpoint → uses NCCL broadcast
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
- No usable transport → raises ValueError with a precise diagnosis
"""
@staticmethod
def _fake_cfg(adapter, vllm_lora_sync):
class FakeTRL:
pass
class FakeCfg:
pass
trl = FakeTRL()
trl.vllm_lora_sync = vllm_lora_sync
trl.vllm_server_host = "127.0.0.1"
trl.vllm_server_port = 8000
cfg = FakeCfg()
cfg.nemo_gym_enabled = True
cfg.nemo_gym_model_name = None
cfg.base_model = "test/model"
cfg.nemo_gym_verify_timeout = 30
cfg.nemo_gym_multi_turn = True
cfg.adapter = adapter
cfg.trl = trl
return cfg
@staticmethod
def _fake_trainer():
class FakeVLLMGen:
sync_weights = staticmethod(lambda: None)
class FakeTrainer:
vllm_generation = FakeVLLMGen()
return FakeTrainer()
@staticmethod
def _caps(**kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(lora_filesystem=True)
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with (
patch.object(plugin, "_setup_lora_sync") as setup,
patch.object(plugin, "_check_lora_endpoint") as check,
patch.object(plugin, "_wire_multi_turn") as wire,
):
plugin.post_trainer_create(cfg, trainer)
setup.assert_called_once()
check.assert_called_once()
wire.assert_called_once()
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps() # all False, but probed
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "load_lora_adapter" in msg
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(nccl=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with patch.object(plugin, "_wire_multi_turn") as wire:
plugin.post_trainer_create(cfg, trainer)
wire.assert_called_once()
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(http_full=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(NotImplementedError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "HTTP weight sync" in str(ctx.exception)
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps()
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "init_communicator" in msg
assert "http_update_weights" in msg
def test_unprobed_caps_raises_with_probe_failure_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
# Plugin._vllm_caps left as default-None: the post_trainer_create
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
# probed=False, so the error path should mention probing.
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "could not probe" in str(ctx.exception)
class TestNemoGymE2E(unittest.TestCase):
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
@@ -775,15 +452,19 @@ class TestNemoGymE2E(unittest.TestCase):
trainer = self._make_mock_trainer()
producer._trainer = trainer
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
# copies of each unique prompt — one entry per rollout.
_prompt_batch = [
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
# Mock the prompt iterator (returns a batch of 1 input)
producer._prompt_iter = iter(
[
[
{
"prompt": [{"role": "user", "content": "Play Wordle!"}],
}
]
]
)
producer._prompt_dl = [
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
# Call produce
result = producer.produce(model=MagicMock(), global_step=1)
@@ -849,13 +530,10 @@ class TestNemoGymE2E(unittest.TestCase):
producer._request_timeout = 30
producer._num_generations = 2
producer._trainer = self._make_mock_trainer()
# RepeatSampler pre-expands by num_generations=2.
_prompt_batch = [
{"prompt": [{"role": "user", "content": "Play!"}]},
{"prompt": [{"role": "user", "content": "Play!"}]},
]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
producer._prompt_iter = iter(
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
)
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
result = producer.produce(model=MagicMock(), global_step=1)

View File

@@ -38,30 +38,6 @@ def _reference_norm_noscale(x, eps):
return norm(x)
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
"""Reference: Gemma4RMSNorm over the full head_dim, then stock
``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with
the trailing columns passed through unchanged. Mirrors how Llama-style
partial rotary is layered on top of the stock RMSNorm + RoPE primitives.
"""
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
D = x.shape[-1]
n_rot = cos.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
norm.weight.data.copy_(weight)
normed = norm(x)
if n_rot == D:
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
x_rot = normed[..., :n_rot]
x_pass = normed[..., n_rot:]
rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2)
return torch.cat([rotated, x_pass], dim=-1)
@pytest.fixture(
params=[
(2, 64, 32, 256), # sliding window layer shape
@@ -218,172 +194,6 @@ class TestFusedRMSNormRoPEBackward:
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
class TestFusedRMSNormRoPEPartialRotary:
"""Partial-rotary: cos/sin last dim is smaller than head_dim.
Compares against the original primitives (`Gemma4RMSNorm` +
`apply_rotary_pos_emb`) applied to the rotated slice with the trailing
columns passed through. Without the kernel fix this used to crash with
`RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
"""
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[
(2, 16, 4, 64, 32), # half rotary (Llama-style 0.5)
(2, 16, 4, 64, 16), # quarter rotary
(2, 32, 8, 128, 64), # half rotary, larger heads
(1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial
(1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path
],
ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
)
def test_forward_matches_reference(self, B, S, H, D, n_rot):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
assert y_fused.shape == y_ref.shape == (B, S, H, D)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, (
f"partial rotary forward cosine_sim={cos_sim:.6f} "
f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
)
# The pass-through tail must equal the reference RMSNorm output bit-
# for-bit (any deviation would mean the kernel is touching it with a
# spurious rotation, which is the original bug class).
torch.testing.assert_close(
y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
)
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
ids=["half_64", "quarter_256"],
)
def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
# Reference backward via the original primitives
x_ref = x_data.clone().requires_grad_(True)
w_ref = weight_init.clone()
y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps)
y_ref.sum().backward()
# Fused backward
x_fused = x_data.clone().requires_grad_(True)
w_fused = weight_init.clone().requires_grad_(True)
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
y_fused.sum().backward()
cos_sim_x = torch.nn.functional.cosine_similarity(
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
)
assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
ids=["half_64", "quarter_256"],
)
def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
# Reference: Gemma4RMSNorm whose .weight collects grads, then partial
# rotary applied to the rotated slice.
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
normed = norm_ref(x_data)
from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb
rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1)
y_ref.sum().backward()
w_fused = weight_init.clone().requires_grad_(True)
fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()
cos_sim_w = torch.nn.functional.cosine_similarity(
w_fused.grad.flatten().float(),
norm_ref.weight.grad.flatten().float(),
dim=0,
)
assert cos_sim_w > 0.995, (
f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
)
def test_full_rotary_unchanged_when_n_rot_equals_d(self):
"""Regression: passing cos/sin with shape == head_dim must still
match the full-rotary reference (the partial-rotary code path must
not perturb the existing full-rotary output)."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 2, 16, 4, 64
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"
def test_validation_errors(self):
"""Wrapper rejects misshaped inputs cleanly (instead of a cryptic
Triton crash deeper in the kernel)."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 1, 4, 2, 64
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# n_rot > head_dim
cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="cannot exceed head_dim"):
fused_rms_norm_rope(x, w, cos_big, sin_big)
# cos/sin last-dim mismatch
cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="same last dim"):
fused_rms_norm_rope(x, w, cos, sin)
# odd rotary dim
cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="must be even"):
fused_rms_norm_rope(x, w, cos_odd, sin_odd)
class TestFusedRMSNormNoScale:
"""Tests for v_norm (RMSNorm without learnable scale)."""

View File

@@ -1,219 +0,0 @@
"""Tests for the Gemma 4 fused-attention monkey-patch.
These tests exercise the patched ``Gemma4TextAttention.forward`` against
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
exercised end-to-end (this is the bug originally documented in
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
The full-model forward also pins that the fused forward keeps accepting
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
installed transformers version — so any future signature drift on
upstream's side trips a clear failure here instead of a confusing
TypeError deep in a training run.
"""
import pytest
import torch
pytestmark = [
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
]
pytest.importorskip(
"transformers.models.gemma4",
reason="fused_attn patch only matters when Gemma 4 is available",
)
@pytest.fixture
def restore_gemma4_attention():
"""Snapshot ``Gemma4TextAttention.forward`` and restore after the test
so the monkey-patch does not leak across the suite."""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
saved = Gemma4TextAttention.forward
yield Gemma4TextAttention
Gemma4TextAttention.forward = saved
def _build_hybrid_config():
"""Tiny hybrid Gemma 4 config: one sliding layer + one full-attention
layer with proportional rope and partial_rotary_factor=0.25. This is
the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small
enough to fit on any GPU."""
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
global_head_dim=64,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
rope_parameters={
"sliding_attention": {
"rope_type": "default",
"rope_theta": 10000.0,
},
"full_attention": {
"rope_type": "proportional",
"rope_theta": 1000000.0,
"partial_rotary_factor": 0.25,
},
},
)
cfg._attn_implementation = "sdpa"
return cfg
def _build_model(seed=0):
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
torch.manual_seed(seed)
cfg = _build_hybrid_config()
return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval()
class TestFusedAttnSignature:
"""The fused forward must accept the same call shape as
``Gemma4TextDecoderLayer`` produces in the installed transformers
version. Any signature drift surfaces here as a TypeError."""
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
"""Run a model forward that exercises the real
``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
the fused patch installed."""
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model()
ids = torch.randint(0, 128, (2, 16), device="cuda")
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
patch_gemma4_fused_attn()
with torch.no_grad():
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
assert out.shape == (2, 16, 64)
assert torch.isfinite(out).all()
class TestFusedAttnPerLayerCorrectness:
"""Compare the patched attention layer to the stock implementation
on a single forward call. This isolates the fused kernel correctness
from cross-layer numerical drift."""
def _run_attention(self, model, layer_idx, hidden_states, position_ids):
"""Call ``Gemma4TextAttention.forward`` (whatever is currently
installed) for one layer and return the output."""
attn = model.layers[layer_idx].self_attn
layer_type = model.config.layer_types[layer_idx]
cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type)
out, _ = attn(
hidden_states=hidden_states,
position_embeddings=(cos, sin),
attention_mask=None,
shared_kv_states={},
)
return out
@pytest.mark.parametrize(
"layer_idx",
[0, 1],
ids=["sliding_head32", "global_head64_proportional"],
)
def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=1)
hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1)
with torch.no_grad():
ref = self._run_attention(m, layer_idx, hs, pos)
patch_gemma4_fused_attn()
with torch.no_grad():
got = self._run_attention(m, layer_idx, hs, pos)
assert got.shape == ref.shape
assert torch.isfinite(got).all()
cos_sim = torch.nn.functional.cosine_similarity(
ref.flatten().float(), got.flatten().float(), dim=0
)
assert cos_sim > 0.999, (
f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}"
)
# bf16 precision: a few millis of absolute drift per element is
# acceptable for a Q/K/V projection pipeline. Anything larger is
# a real bug.
torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2)
class TestFusedAttnFullModel:
"""End-to-end model forward + backward through both layer types."""
def test_full_forward_matches_stock(self, restore_gemma4_attention):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=2)
ids = torch.randint(0, 128, (2, 32), device="cuda")
mask = torch.ones(2, 32, dtype=torch.long, device="cuda")
with torch.no_grad():
ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
patch_gemma4_fused_attn()
with torch.no_grad():
got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
assert got.shape == ref.shape
assert torch.isfinite(got).all()
cos_sim = torch.nn.functional.cosine_similarity(
ref.flatten().float(), got.flatten().float(), dim=0
)
# End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16
# accumulates a small amount of numerical drift; we just want to
# pin that the two paths are computing the same function.
assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}"
def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention):
"""Gradients must propagate through the fused RMSNorm+RoPE kernels
for both the sliding and proportional-rope layers."""
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=3).train()
patch_gemma4_fused_attn()
ids = torch.randint(0, 128, (2, 16), device="cuda")
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
out.sum().backward()
# Both layers must accumulate gradients on q_norm.weight and
# k_norm.weight — that proves the fused kernel ran the backward.
for i, layer in enumerate(m.layers[:2]):
attn = layer.self_attn
assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad"
assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad"
assert attn.q_norm.weight.grad.isfinite().all()
assert attn.k_norm.weight.grad.isfinite().all()
assert attn.q_norm.weight.grad.abs().sum() > 0
assert attn.k_norm.weight.grad.abs().sum() > 0

View File

@@ -1,343 +0,0 @@
"""Tests for the Gemma 4 hybrid-attention mask fix.
These tests pin the single critical behavior: after installing the patch,
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
the underlying mask builder regardless of what the caller's config says.
This is what keeps full-attention (head_dim=512) global layers from
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
the 2D FA2 mask that would be built from the model-level config.
The tests use a mocked ``create_causal_mask`` so they don't have to load
a real 26B Gemma 4 model or even have access to its weights. What matters
for the bug fix is which config is handed to the mask factory, not the
factory's actual output.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
pytest.importorskip(
"transformers.models.gemma4",
reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
)
@pytest.fixture
def restore_gemma4_module():
"""Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
each test so patch state doesn't leak across the suite."""
from transformers.models.gemma4 import modeling_gemma4
saved = modeling_gemma4.create_causal_mask
yield modeling_gemma4
modeling_gemma4.create_causal_mask = saved
# Reset the module-level flag so the next test can re-install cleanly.
from axolotl.monkeypatch import gemma4_hybrid_mask
gemma4_hybrid_mask._PATCH_APPLIED = False
def test_patch_replaces_create_causal_mask(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
original = modeling_gemma4.create_causal_mask
assert patch_gemma4_hybrid_mask() is True
assert modeling_gemma4.create_causal_mask is not original
assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
"patched wrapper must expose the original reference for teardown"
)
def test_patch_is_idempotent(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
wrapper_first = modeling_gemma4.create_causal_mask
# Second call must not re-wrap the already-wrapped function (which
# would leak the original reference through a chain of wrappers).
patch_gemma4_hybrid_mask()
wrapper_second = modeling_gemma4.create_causal_mask
assert wrapper_first is wrapper_second
def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
"""Core invariant: when the patched wrapper is called with a config
that says ``flash_attention_2``, the underlying mask factory receives
a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.
Without this, the full-attention global layers get a 2D FA2 mask and
crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
"""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
# Swap in a mock BEFORE installing the patch so the wrapper captures
# it as the "original". The mock records every call so we can inspect
# what config got passed through.
mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
modeling_gemma4.create_causal_mask = mock_factory
patch_gemma4_hybrid_mask()
# Caller-supplied config says FA2 (that's the model-level setting).
caller_config = SimpleNamespace(
_attn_implementation="flash_attention_2",
head_dim=512,
some_other_attr="preserved",
)
result = modeling_gemma4.create_causal_mask(
caller_config,
inputs_embeds=None,
attention_mask=None,
past_key_values=None,
position_ids=None,
)
# Wrapper returned whatever the mock returned — no transformation of
# the result itself.
assert result == "mask_4d"
# The mock was called exactly once with a config whose
# ``_attn_implementation`` is sdpa, NOT the caller's fa2.
assert mock_factory.call_count == 1
(passed_config, *_), passed_kwargs = mock_factory.call_args
assert passed_config._attn_implementation == "sdpa"
# The wrapper must NOT mutate the caller's config in place — other
# mask builders (e.g. create_sliding_window_causal_mask) read from
# the same config and must still see fa2.
assert caller_config._attn_implementation == "flash_attention_2"
# Other attributes on the config must be preserved so the underlying
# factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
assert passed_config.head_dim == 512
assert passed_config.some_other_attr == "preserved"
def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
"""The wrapper must forward positional + keyword args to the original
unchanged, so transformers' own call-site in Gemma4TextModel.forward
keeps working across minor transformers-version signature drift."""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
mock_factory = MagicMock(return_value="mask")
modeling_gemma4.create_causal_mask = mock_factory
patch_gemma4_hybrid_mask()
caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
modeling_gemma4.create_causal_mask(
caller_config,
"positional_arg",
inputs_embeds="embeds",
attention_mask="mask_2d",
past_key_values="cache",
position_ids="positions",
or_mask_function="or_fn",
)
args, kwargs = mock_factory.call_args
# First positional (after config override) is preserved.
assert args[1] == "positional_arg"
# All kwargs are forwarded untouched.
assert kwargs["inputs_embeds"] == "embeds"
assert kwargs["attention_mask"] == "mask_2d"
assert kwargs["past_key_values"] == "cache"
assert kwargs["position_ids"] == "positions"
assert kwargs["or_mask_function"] == "or_fn"
def test_unpatch_restores_original(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import (
patch_gemma4_hybrid_mask,
unpatch_gemma4_hybrid_mask,
)
sentinel = MagicMock(name="original")
modeling_gemma4.create_causal_mask = sentinel
patch_gemma4_hybrid_mask()
assert modeling_gemma4.create_causal_mask is not sentinel
unpatch_gemma4_hybrid_mask()
assert modeling_gemma4.create_causal_mask is sentinel
def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask
# Should be a no-op, no exception.
unpatch_gemma4_hybrid_mask()
def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
"""Only ``create_causal_mask`` is overridden — the sliding-window
factory must remain bound to its original to preserve FA2 masks for
the sliding-attention layers. If we accidentally patch both, the
sliding layers get SDPA format and lose the FA2 speedup."""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"):
pytest.skip("transformers version has no create_sliding_window_causal_mask")
sliding_before = modeling_gemma4.create_sliding_window_causal_mask
patch_gemma4_hybrid_mask()
sliding_after = modeling_gemma4.create_sliding_window_causal_mask
assert sliding_after is sliding_before
# ---------------------------------------------------------------------------
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
#
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
# model with 2 layers (one sliding, one full_attention), apply the hybrid
# attention path end-to-end, and run a forward pass with a padded
# attention_mask at a long-ish seq len. The invariant we're pinning is that
# the full_attention layer does not crash with the
# "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
# tokens in the FSDP2 training run.
# ---------------------------------------------------------------------------
def _build_tiny_gemma4_text_model():
"""Return a tiny randomly-initialized Gemma4TextModel with mixed layers."""
import torch
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
)
# Caller-supplied attn impl simulates the pilot config (fa2 at model
# level). The hybrid patch is what makes this survive long context.
cfg._attn_implementation = "sdpa" # start safe; the test toggles fa2 later
torch.manual_seed(42)
model = Gemma4TextModel(cfg).eval()
return model, cfg
def _apply_hybrid_attn_inline(model, cfg):
"""Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
to a model, without needing a full PatchManager / pydantic cfg."""
import copy
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
for layer_idx, layer in enumerate(model.layers):
if cfg.layer_types[layer_idx] != "sliding_attention":
attn = getattr(layer, "self_attn", None)
if attn is not None and hasattr(attn, "config"):
sdpa_cfg = copy.copy(attn.config)
sdpa_cfg._attn_implementation = "sdpa"
attn.config = sdpa_cfg
patch_gemma4_hybrid_mask()
def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
"""End-to-end invariant: with the hybrid attn patch applied, a tiny
Gemma4TextModel runs a forward at long context (1024 tokens) with
real padding in the attention mask, producing the expected output
shape. This exercises the actual code path that crashed the pilot
without needing a real 26B checkpoint or CUDA."""
import torch
model, cfg = _build_tiny_gemma4_text_model()
_apply_hybrid_attn_inline(model, cfg)
B, S = 2, 1024
input_ids = torch.randint(0, cfg.vocab_size, (B, S))
attn_mask = torch.ones(B, S, dtype=torch.long)
# Pad positions in the second row. Without padding, SDPA falls back to
# ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
# mask to exercise the actual bug site.
attn_mask[1, S // 2 :] = 0
with torch.no_grad():
out = model(input_ids=input_ids, attention_mask=attn_mask)
assert out.last_hidden_state.shape == (B, S, cfg.hidden_size)
assert torch.isfinite(out.last_hidden_state).all()
def test_patched_create_causal_mask_returns_4d_for_real_config(
restore_gemma4_module,
):
"""Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
and verify the returned mask is a 4D tensor — which is the shape the
SDPA-patched global layers need. Without the patch and with a
caller-supplied FA2 config this would return a 2D mask and the layer
would crash at long context."""
import torch
from transformers.cache_utils import DynamicCache
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
modeling_gemma4 = restore_gemma4_module
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
)
# Simulate the pilot: caller says flash_attention_2, but global layers
# were switched to SDPA per-layer. Without the patch, create_causal_mask
# would return an FA2 2D mask here and the SDPA layer would crash.
cfg._attn_implementation = "flash_attention_2"
B, S = 2, 1024
inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32)
attention_mask = torch.ones((B, S), dtype=torch.long)
attention_mask[1, S // 2 :] = 0 # force the 4D materialized path
position_ids = torch.arange(S).unsqueeze(0).expand(B, -1)
past_key_values = DynamicCache(config=cfg)
mask = modeling_gemma4.create_causal_mask(
config=cfg,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
assert mask is not None
assert isinstance(mask, torch.Tensor)
assert mask.dim() == 4, (
f"expected a 4D SDPA-format mask, got {mask.dim()}D "
f"shape={tuple(mask.shape)}. The full_attention global layers need "
"this shape or they crash at long context."
)
assert mask.shape[0] == B
assert mask.shape[-1] == S
assert mask.shape[-2] == S
# Caller's config must be untouched — other code paths still read it.
assert cfg._attn_implementation == "flash_attention_2"

View File

@@ -487,70 +487,3 @@ class TestDatasetPreparation:
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_load_dataset_with_str_json_data(self, tokenizer):
"""
Test loading datasets where data is stored as str JSON instead of list of dicts.
see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
import json
str_json_ds = Dataset.from_list(
[
{
"messages": json.dumps(
[
{"role": "user", "content": "Hello how are you?"},
{
"role": "assistant",
"content": "I am doing good thanks",
},
]
)
},
{
"messages": json.dumps(
[
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 4."},
]
)
},
]
)
tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet"
str_json_ds.to_parquet(tmp_ds_path)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 512,
"datasets": [
{
"path": str(tmp_ds_path),
"name": "test_str_json",
"type": "chat_template",
"field_messages": "messages",
"message_field_role": "role",
"message_field_content": "content",
},
],
"dataset_num_proc": 4,
}
)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
assert len(dataset[0]["input_ids"]) > 0

View File

@@ -133,108 +133,3 @@ class TestRevisionParameter:
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
"processor_kwargs": {
"image_seq_length": 1120,
"max_soft_tokens": 1120,
},
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert call_kwargs.kwargs.get("image_seq_length") == 1120
assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_omits_processor_kwargs_when_unset(
self, mock_auto_processor
):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "image_seq_length" not in call_kwargs.kwargs
assert "max_soft_tokens" not in call_kwargs.kwargs
def test_processor_kwargs_schema_rejects_revision(self):
import pytest
from axolotl.utils.schemas.model import ModelInputConfig
with pytest.raises(ValueError, match="revision"):
ModelInputConfig(
base_model="some-model",
processor_kwargs={"revision": "abc123"},
)
def test_processor_kwargs_schema_rejects_trust_remote_code(self):
import pytest
from axolotl.utils.schemas.model import ModelInputConfig
with pytest.raises(ValueError, match="trust_remote_code"):
ModelInputConfig(
base_model="some-model",
processor_kwargs={"trust_remote_code": True},
)
def test_processor_kwargs_schema_accepts_valid_keys(self):
from axolotl.utils.schemas.model import ModelInputConfig
cfg = ModelInputConfig(
base_model="some-model",
processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
)
assert cfg.processor_kwargs == {
"image_seq_length": 1120,
"max_soft_tokens": 1120,
}
def test_processor_kwargs_schema_accepts_none_and_empty(self):
from axolotl.utils.schemas.model import ModelInputConfig
assert ModelInputConfig(base_model="x").processor_kwargs is None
assert (
ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
)
def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
cfg = min_base_cfg | DictDefault(
tokenizer_use_mistral_common=True,
processor_kwargs={"image_seq_length": 1120},
)
with pytest.raises(ValueError, match="processor_kwargs"):
validate_config(cfg)

View File

@@ -5,8 +5,6 @@ Covers:
- save_strategy: 'best' requires metric_for_best_model
- streaming=True with val_set_size > 0 is rejected
- lora_target_modules with invalid regex patterns is rejected
- GRPO: generation batch size must be divisible by num_generations,
num_generations >= 2, and effective_gbs >= num_generations * world_size
"""
import pytest
@@ -119,136 +117,3 @@ class TestLoraTargetModulesRegexValidator:
)
with pytest.raises(ValueError, match="invalid regex pattern"):
validate_config(cfg)
class TestGRPOBatchSizeValidator:
"""GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.
These call the @model_validator(mode="before") classmethod directly on a
plain dict — same input shape it receives during full Pydantic validation,
just without dragging in unrelated fields (datasets / model loading / etc.)
that aren't relevant to what's under test. The validator is registered on
``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
same code path ``axolotl train`` exercises.
"""
@staticmethod
def _check(data):
from axolotl.utils.schemas.validation import RLValidationMixin
return RLValidationMixin.check_grpo_batch_size_divisibility(data)
def test_divisible_passes(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"trl": {"num_generations": 4},
}
# Should return data unchanged (no exception)
out = self._check(data)
assert out["trl"]["num_generations"] == 4
def test_non_divisible_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match="num_generations"):
self._check(data)
def test_non_divisible_error_includes_fix_hint(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
self._check(data)
def test_num_generations_one_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"trl": {"num_generations": 1},
}
with pytest.raises(ValueError, match=r"num_generations >= 2"):
self._check(data)
def test_explicit_generation_batch_size_divisible_passes(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"trl": {"num_generations": 4, "generation_batch_size": 8},
}
out = self._check(data)
assert out["trl"]["generation_batch_size"] == 8
def test_explicit_generation_batch_size_non_divisible_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"trl": {"num_generations": 4, "generation_batch_size": 6},
}
with pytest.raises(ValueError, match="trl.generation_batch_size"):
self._check(data)
def test_non_grpo_skips_check(self):
# Anything other than rl=grpo should pass through untouched, even
# with non-divisible batch sizes — they're irrelevant to other RL
# methods that don't use group-relative advantages.
data = {
"rl": "dpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {"num_generations": 4},
}
assert self._check(data) is data
def test_no_rl_set_skips_check(self):
data = {
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
}
assert self._check(data) is data
def test_grpo_without_num_generations_skips_check(self):
# If num_generations isn't set, TRL uses its own default — we don't
# have enough info to validate, so the validator must short-circuit
# rather than guess.
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {},
}
out = self._check(data)
assert out["rl"] == "grpo"
def test_multi_rank_group_size_check(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4, # gbs=4
"world_size": 2, # need gbs >= 4*2 = 8
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match=r"world_size=2"):
self._check(data)
def test_multi_rank_group_size_satisfied(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 8, # gbs=8 >= 4*2
"world_size": 2,
"trl": {"num_generations": 4},
}
out = self._check(data)
assert out["gradient_accumulation_steps"] == 8