Compare commits
2 Commits
main
...
kernelize-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8495c79fb1 | ||
|
|
9a0d3016df |
5
.github/CONTRIBUTING.md
vendored
5
.github/CONTRIBUTING.md
vendored
@@ -31,11 +31,10 @@ PRs are **greatly welcome**!
|
|||||||
|
|
||||||
Please run below to setup env
|
Please run below to setup env
|
||||||
```bash
|
```bash
|
||||||
# Install axolotl + dev and test dependencies
|
# Install axolotl + dev and test dependencies from lockfile
|
||||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
uv venv --no-project --relocatable
|
uv sync --extra flash-attn --extra deepspeed --group dev --group test
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
|
||||||
pre-commit install
|
pre-commit install
|
||||||
|
|
||||||
# test
|
# test
|
||||||
|
|||||||
16
.github/workflows/base.yml
vendored
16
.github/workflows/base.yml
vendored
@@ -30,6 +30,14 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -160,6 +168,14 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
|
|||||||
12
.github/workflows/main.yml
vendored
12
.github/workflows/main.yml
vendored
@@ -18,6 +18,12 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -174,6 +180,12 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
|
|||||||
| Method | Config Key | When to Use |
|
| Method | Config Key | When to Use |
|
||||||
|--------|-----------|-------------|
|
|--------|-----------|-------------|
|
||||||
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
||||||
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
|
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
|
||||||
| KTO | `rl: kto` | Unpaired binary preference labels |
|
| KTO | `rl: kto` | Unpaired binary preference labels |
|
||||||
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
||||||
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
RUN pip uninstall -y causal_conv1d
|
RUN pip uninstall -y causal_conv1d
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
BASE_EXTRAS="optimizers,ray"; \
|
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
else \
|
else \
|
||||||
BASE_EXTRAS="deepspeed,optimizers,ray"; \
|
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
fi && \
|
fi && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
|
|||||||
@@ -58,3 +58,19 @@ RUN git lfs install --skip-repo && \
|
|||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
|
# Map Python version (e.g., 3.12 -> cp312)
|
||||||
|
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||||
|
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||||
|
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||||
|
# Map architecture
|
||||||
|
case "$TARGETARCH" in \
|
||||||
|
amd64) ARCH_TAG="x86_64" ;; \
|
||||||
|
arm64) ARCH_TAG="aarch64" ;; \
|
||||||
|
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||||
|
esac && \
|
||||||
|
WHL_VERSION="v0.7.16" && \
|
||||||
|
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||||
|
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||||
|
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
||||||
|
rm "${WHL_FILE}"
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
RUN uv pip uninstall causal_conv1d
|
RUN uv pip uninstall causal_conv1d
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
BASE_EXTRAS="optimizers,ray"; \
|
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
else \
|
else \
|
||||||
BASE_EXTRAS="deepspeed,optimizers,ray"; \
|
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
fi && \
|
fi && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
|
|||||||
@@ -38,3 +38,20 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
|
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Map Python version (e.g., 3.12 -> cp312)
|
||||||
|
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||||
|
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||||
|
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||||
|
LINUX_TAG="manylinux_" && \
|
||||||
|
# Map architecture
|
||||||
|
case "$TARGETARCH" in \
|
||||||
|
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
|
||||||
|
arm64) ARCH_TAG="2_34_aarch64" ;; \
|
||||||
|
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||||
|
esac && \
|
||||||
|
WHL_VERSION="v0.7.16" && \
|
||||||
|
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
|
||||||
|
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||||
|
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
||||||
|
rm "${WHL_FILE}"
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
|
|||||||
|
|
||||||
1. Paired preference data (chosen + rejected)?
|
1. Paired preference data (chosen + rejected)?
|
||||||
- Default → `rl: dpo`
|
- Default → `rl: dpo`
|
||||||
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
|
- Overfitting → `rl: ipo`
|
||||||
- VRAM-limited → `rl: orpo` (no ref model)
|
- VRAM-limited → `rl: orpo` (no ref model)
|
||||||
- Length-sensitive → `rl: simpo` (no ref model)
|
- Length-sensitive → `rl: simpo` (no ref model)
|
||||||
2. Only binary labels (good/bad)? → `rl: kto`
|
2. Only binary labels (good/bad)? → `rl: kto`
|
||||||
|
|||||||
@@ -77,9 +77,8 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
uv venv --no-project --relocatable
|
uv sync --extra flash-attn --extra deepspeed --group dev --group test
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Remote Hosts
|
#### Remote Hosts
|
||||||
@@ -219,9 +218,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
|||||||
You will now be in the container. Next, install Axolotl with dev dependencies:
|
You will now be in the container. Next, install Axolotl with dev dependencies:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv venv --no-project --relocatable
|
uv sync --extra flash-attn --extra deepspeed --group dev --group test
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Attach To Container
|
### Attach To Container
|
||||||
|
|||||||
@@ -10,16 +10,13 @@ This section describes the different Docker images that are released by AxolotlA
|
|||||||
[Docker Hub](https://hub.docker.com/u/axolotlai).
|
[Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
### Switch to the `-uv` images
|
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
|
||||||
|
:::
|
||||||
|
|
||||||
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
|
::: {.callout-tip}
|
||||||
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
|
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
|
||||||
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
|
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
|
||||||
same format as their non-uv counterparts.
|
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
|
||||||
|
|
||||||
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
|
|
||||||
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
|
|
||||||
the uv image.
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Base
|
## Base
|
||||||
@@ -88,7 +85,7 @@ Tags examples:
|
|||||||
- `main-py3.12-cu130-2.10.0`
|
- `main-py3.12-cu130-2.10.0`
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20260315-py3.11-cu128-2.9.1`
|
- `main-20260315-py3.11-cu128-2.9.1`
|
||||||
- `0.16.1`
|
- `0.12.0`
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
**Q: vLLM is not working with Axolotl**
|
**Q: vLLM is not working with Axolotl**
|
||||||
|
|
||||||
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
|
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
|
||||||
|
|
||||||
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python ≥3.11
|
- Python ≥3.11
|
||||||
- PyTorch ≥2.9.1
|
- PyTorch ≥2.9.0
|
||||||
|
|
||||||
## Installation {#sec-installation}
|
## Installation {#sec-installation}
|
||||||
|
|
||||||
@@ -36,9 +36,9 @@ source $HOME/.local/bin/env
|
|||||||
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
|
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
|
||||||
```{.bash}
|
```{.bash}
|
||||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
uv venv
|
uv venv --no-project --relocatable
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install --no-build-isolation axolotl[deepspeed]
|
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Edge/Development Build {#sec-edge-build}
|
### Edge/Development Build {#sec-edge-build}
|
||||||
@@ -49,11 +49,12 @@ For the latest features between releases:
|
|||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
uv venv
|
uv sync --extra flash-attn --extra deepspeed
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install --no-build-isolation -e '.[deepspeed]'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
|
||||||
|
|
||||||
### Docker {#sec-docker}
|
### Docker {#sec-docker}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
@@ -131,11 +132,11 @@ source $HOME/.local/bin/env
|
|||||||
|
|
||||||
# Create a fresh venv (recommended for a clean start)
|
# Create a fresh venv (recommended for a clean start)
|
||||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
uv venv
|
uv venv --no-project --relocatable
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
|
|
||||||
# Reinstall axolotl
|
# Reinstall axolotl
|
||||||
uv pip install --no-build-isolation axolotl[deepspeed]
|
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Using pip (Alternative) {#sec-pip}
|
## Using pip (Alternative) {#sec-pip}
|
||||||
@@ -150,13 +151,13 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
|
|||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
```
|
```
|
||||||
|
|
||||||
For editable/development installs:
|
For editable/development installs:
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Troubleshooting {#sec-troubleshooting}
|
## Troubleshooting {#sec-troubleshooting}
|
||||||
|
|||||||
@@ -320,10 +320,8 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: dpo
|
rl: ipo
|
||||||
dpo_loss_type: ["ipo"]
|
|
||||||
```
|
```
|
||||||
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
|
|
||||||
|
|
||||||
### ORPO
|
### ORPO
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have a compatible version of Pytorch installed
|
# Ensure you have a compatible version of Pytorch installed
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run one of the finetuning examples below.
|
2. Run one of the finetuning examples below.
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
uv pip install --no-build-isolation -e '.'
|
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -13,11 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
uv pip install --no-build-isolation -e '.'
|
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -36,7 +36,12 @@
|
|||||||
"id": "msOCO4NRmRLa"
|
"id": "msOCO4NRmRLa"
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
|
"source": [
|
||||||
|
"%%capture\n",
|
||||||
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
uv pip install --no-build-isolation -e '.'
|
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ Tencent released a family of opensource models called HunYuan with varying param
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
uv pip install --no-build-isolation -e '.'
|
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have a compatible version of Pytorch installed
|
# Ensure you have a compatible version of Pytorch installed
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
|
|
||||||
# Install Cut Cross Entropy
|
# Install Cut Cross Entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have a compatible version of Pytorch installed
|
# Ensure you have a compatible version of Pytorch installed
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install an extra dependency:
|
2. Install an extra dependency:
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Please install the below.
|
2. Please install the below.
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ requires-python = ">=3.10"
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core ML stack
|
# Core ML stack
|
||||||
"torch>=2.9.1",
|
"torch>=2.6.0",
|
||||||
"packaging==26.0",
|
"packaging==26.0",
|
||||||
"huggingface_hub>=1.1.7",
|
"huggingface_hub>=1.1.7",
|
||||||
"peft>=0.19.1,<0.20.0",
|
"peft>=0.19.1,<0.20.0",
|
||||||
@@ -79,7 +79,7 @@ dependencies = [
|
|||||||
# Platform-specific (Linux only)
|
# Platform-specific (Linux only)
|
||||||
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
|
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
|
||||||
"triton>=3.4.0 ; sys_platform != 'darwin'",
|
"triton>=3.4.0 ; sys_platform != 'darwin'",
|
||||||
"xformers>=0.0.33.post2 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
|
||||||
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
|
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
|
||||||
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
||||||
|
|
||||||
|
|||||||
479
scripts/build_scattermoe_lora_kernel.py
Normal file
479
scripts/build_scattermoe_lora_kernel.py
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Build a disposable Hugging Face Kernel Hub package for ScatterMoE LoRA.
|
||||||
|
|
||||||
|
This script does not move or edit the in-tree Axolotl kernel sources. It copies
|
||||||
|
``src/axolotl/integrations/kernels/libs/scattermoe_lora`` into an ignored
|
||||||
|
build directory and emits a universal HF kernels project that can be pushed to
|
||||||
|
the Hub.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import fnmatch
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from importlib import metadata
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
PACKAGE_NAME = "scattermoe_lora"
|
||||||
|
BUILD_VARIANT = "torch-universal"
|
||||||
|
DEFAULT_REPO_ID = "kernels-community/scattermoe-lora"
|
||||||
|
HF_REPO_TYPE = "kernel"
|
||||||
|
HF_KERNEL_URL_PREFIX = "https://hf.co/kernels"
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
DEFAULT_SOURCE_DIR = (
|
||||||
|
REPO_ROOT / "src" / "axolotl" / "integrations" / "kernels" / "libs" / PACKAGE_NAME
|
||||||
|
)
|
||||||
|
DEFAULT_OUTPUT_DIR = REPO_ROOT / "build" / "hf-kernels" / PACKAGE_NAME
|
||||||
|
|
||||||
|
EXCLUDED_DIRS = {
|
||||||
|
"__pycache__",
|
||||||
|
".mypy_cache",
|
||||||
|
".pytest_cache",
|
||||||
|
".ruff_cache",
|
||||||
|
}
|
||||||
|
EXCLUDED_FILE_PATTERNS = {
|
||||||
|
"*.pyc",
|
||||||
|
"*.pyo",
|
||||||
|
"*.so",
|
||||||
|
".DS_Store",
|
||||||
|
}
|
||||||
|
|
||||||
|
TEXT_REPLACEMENTS = {
|
||||||
|
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import": (
|
||||||
|
"from .selective_dequant import"
|
||||||
|
),
|
||||||
|
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import": (
|
||||||
|
"from .selective_dequant_kernel import"
|
||||||
|
),
|
||||||
|
"from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import": (
|
||||||
|
"from .ops import"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the kernel export/upload script."""
    parser = argparse.ArgumentParser(
        description=(
            "Copy Axolotl's ScatterMoE LoRA Triton kernels into a disposable "
            "HF Kernel Hub universal package."
        )
    )
    parser.add_argument(
        "--source-dir",
        type=Path,
        default=DEFAULT_SOURCE_DIR,
        help=f"ScatterMoE LoRA source package to copy. Default: {DEFAULT_SOURCE_DIR}",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help=f"Destination build/dist directory. Default: {DEFAULT_OUTPUT_DIR}",
    )
    parser.add_argument(
        "--repo-id",
        default=DEFAULT_REPO_ID,
        help=f"HF Hub repo id to write into build.toml. Default: {DEFAULT_REPO_ID}",
    )
    parser.add_argument(
        "--version",
        type=int,
        default=1,
        help="Kernel major version written to build.toml and metadata.json.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Delete the output directory first if it already exists.",
    )
    parser.add_argument(
        "--no-source-layout",
        action="store_true",
        help="Only write the shippable build/ tree, not torch-ext/ sources.",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help=(
            "Upload the generated universal kernel package with huggingface_hub. "
            "This bypasses kernel-builder and is intended for pure Python/Triton "
            "universal kernels."
        ),
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Create the HF Hub repo as private when used with --upload.",
    )
    parser.add_argument(
        "--skip-version-branch",
        action="store_true",
        help="With --upload, only upload main and skip the v<version> branch.",
    )
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def should_skip_file(path: Path) -> bool:
    """Return True when *path*'s name matches an excluded file pattern.

    Patterns come from the module-level ``EXCLUDED_FILE_PATTERNS`` set
    (byte-compiled artifacts, shared objects, OS metadata files).
    """
    for pattern in EXCLUDED_FILE_PATTERNS:
        if fnmatch.fnmatch(path.name, pattern):
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def iter_source_files(source_dir: Path) -> list[Path]:
    """Collect every shippable file under *source_dir* in deterministic order.

    Cache/tool directories are pruned in place so ``os.walk`` never descends
    into them; directory and file names are sorted so the result is stable
    across filesystems (this ordering feeds the content hash).
    """
    collected: list[Path] = []
    for dirpath, dirnames, filenames in os.walk(source_dir):
        # In-place pruning controls os.walk's descent; sorting keeps the
        # traversal order deterministic.
        dirnames[:] = sorted(name for name in dirnames if name not in EXCLUDED_DIRS)
        base = Path(dirpath)
        collected.extend(
            base / name
            for name in sorted(filenames)
            if not should_skip_file(base / name)
        )
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def content_hash(source_dir: Path) -> str:
    """Return a short, stable fingerprint of the package contents.

    Each file's source-relative POSIX path and raw bytes feed the digest,
    separated by NUL bytes, so both renames and edits change the hash.
    SHA-1 is used purely as a content fingerprint, not for security.
    """
    hasher = hashlib.sha1()
    for file_path in iter_source_files(source_dir):
        relative = file_path.relative_to(source_dir).as_posix()
        hasher.update(relative.encode("utf-8") + b"\0")
        hasher.update(file_path.read_bytes() + b"\0")
    return hasher.hexdigest()[:10]
|
||||||
|
|
||||||
|
|
||||||
|
def git_revision() -> str:
    """Return the short git SHA of the repo HEAD, or ``"unknown"`` on failure.

    Failures (git missing, not a work tree) are swallowed deliberately —
    the revision is informational metadata only.
    """
    try:
        completed = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            cwd=REPO_ROOT,
            check=True,
            capture_output=True,
            text=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return "unknown"
    revision = completed.stdout.strip()
    return revision if revision else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def transform_python_source(text: str, rel_path: Path, op_namespace: str) -> str:
    """Rewrite a Python source for standalone export.

    Converts absolute in-tree Axolotl imports to package-relative ones,
    strips the Axolotl-only Gemma4 patch entry point, and renames the torch
    custom-op namespace to the export-unique *op_namespace*.
    """
    # Absolute in-tree imports -> relative imports within the exported package.
    for old, new in TEXT_REPLACEMENTS.items():
        text = text.replace(old, new)

    if rel_path.as_posix() == "gemma4_experts.py":
        # The Gemma4 patch needs Axolotl internals; replace its import with a
        # RuntimeError so standalone users get a clear message instead of an
        # ImportError deep inside the call.
        # NOTE(review): the leading whitespace inside these literals must match
        # the source file's exact indentation for the replace to fire — verify.
        text = text.replace(
            "        from axolotl.integrations.kernels.constants import resolve_experts_class",
            (
                "        raise RuntimeError(\n"
                '            "patch_gemma4_scattermoe is only available from the in-tree Axolotl "\n'
                '            "integration. Use register_scattermoe_experts() with the standalone "\n'
                '            "HF kernel package."\n'
                "        )"
            ),
        )

    # A unique per-export op namespace avoids clashes between installed versions.
    return text.replace("scattermoe::", f"{op_namespace}::")
|
||||||
|
|
||||||
|
|
||||||
|
def copy_package(source_dir: Path, package_dir: Path, op_namespace: str) -> None:
    """Mirror the kernel package into *package_dir*, rewriting Python sources.

    Python files pass through ``transform_python_source`` (import and op
    namespace rewrites); every other file is copied verbatim with metadata.
    Finally a package-local ``_ops.py`` shim is emitted.
    """
    for src_file in iter_source_files(source_dir):
        relative = src_file.relative_to(source_dir)
        target = package_dir / relative
        target.parent.mkdir(parents=True, exist_ok=True)

        if src_file.suffix != ".py":
            shutil.copy2(src_file, target)
            continue

        transformed = transform_python_source(
            src_file.read_text(encoding="utf-8"), relative, op_namespace
        )
        target.write_text(transformed, encoding="utf-8")

    write_ops_module(package_dir / "_ops.py", op_namespace)
|
||||||
|
|
||||||
|
|
||||||
|
def write_ops_module(path: Path, op_namespace: str) -> None:
    """Emit the package-local ``_ops.py`` shim.

    The shim exposes ``ops`` (the namespaced ``torch.ops`` proxy) and a
    helper that prefixes op names with the export-unique namespace.
    """
    module_text = (
        "import torch\n"
        "\n"
        f"ops = torch.ops.{op_namespace}\n"
        "\n"
        "\n"
        "def add_op_namespace_prefix(op_name: str) -> str:\n"
        f'    return f"{op_namespace}::{{op_name}}"\n'
    )
    path.write_text(module_text, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def write_build_toml(path: Path, repo_id: str, version: int) -> None:
    """Write the kernel-builder ``build.toml`` manifest.

    Declares a universal (pure Python/Triton) kernel; an optional
    ``[general.hub]`` table records the target Hub repo when *repo_id*
    is non-empty.
    """
    sections = [
        "[general]",
        f'name = "{PACKAGE_NAME}"',
        "universal = true",
        f"version = {version}",
        "",
    ]
    if repo_id:
        sections += [
            "[general.hub]",
            f'repo-id = "{repo_id}"',
            "",
        ]
    path.write_text("\n".join(sections), encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def write_flake(path: Path) -> None:
    """Write a minimal Nix flake so kernel-builder can rebuild the package."""
    # NOTE(review): the flake body below is reconstructed with conventional
    # two-space Nix indentation; verify against the original file if exact
    # bytes matter (indentation is not semantic for Nix here).
    path.write_text(
        """{
  description = "Flake for scattermoe_lora kernel";

  inputs = {
    builder.url = "github:huggingface/kernels";
  };

  outputs =
    {
      self,
      builder,
    }:
    builder.lib.genKernelFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
""",
        encoding="utf-8",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def write_readme(path: Path, repo_id: str, source_hash: str, op_namespace: str) -> None:
    """Write the Hub README: front matter, usage snippet, and export provenance."""
    # Placeholder keeps the snippet copy-pasteable even without a repo id.
    repo_display = repo_id or "<your-org>/scattermoe-lora"
    path.write_text(
        f"""---
library_name: kernels
license: apache-2.0
tags:
- kernel
- kernels
---

# ScatterMoE LoRA

Standalone Hugging Face Kernel Hub package for Axolotl's ScatterMoE LoRA Triton kernels.

This package is generated from Axolotl's in-tree `scattermoe_lora` sources and is exported as a universal kernel because the implementation is Python/Triton rather than a precompiled C++/CUDA extension.

```python
from kernels import get_kernel

scattermoe_lora = get_kernel("{repo_display}")
```

Export metadata:

- source package: `src/axolotl/integrations/kernels/libs/scattermoe_lora`
- source revision: `{git_revision()}`
- source content hash: `{source_hash}`
- torch custom op namespace: `{op_namespace}`

The generated `build/torch-universal/{PACKAGE_NAME}` directory is the shippable Hub artifact. `torch-ext/{PACKAGE_NAME}` is included so `kernel-builder build-and-copy` can regenerate the universal build tree if desired.
""",
        encoding="utf-8",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def write_metadata(path: Path, version: int) -> None:
    """Write the kernels ``metadata.json`` that records the export version."""
    payload = json.dumps({"version": version}, indent=2, sort_keys=True)
    path.write_text(payload + "\n", encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_output_dir(output_dir: Path, force: bool) -> None:
    """Create a fresh, empty *output_dir*.

    Refuses to clobber an existing directory unless *force* is set, in which
    case the old tree is removed first.

    Raises:
        FileExistsError: the directory exists and *force* is False.
    """
    if output_dir.exists() and not force:
        raise FileExistsError(
            f"{output_dir} already exists. Re-run with --force to replace it."
        )
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
|
||||||
|
def build_package(args: argparse.Namespace) -> Path:
    """Generate the HF kernel project tree and return the output directory.

    Validates the source package, derives a content-hashed torch op
    namespace, then writes build.toml / flake.nix / README plus one or two
    copies of the package (optional torch-ext sources and the shippable
    build tree with metadata.json).

    Raises:
        FileNotFoundError: the source dir is missing or not a Python package.
        FileExistsError: the output dir exists and --force was not given.
    """
    src = args.source_dir.resolve()
    out = args.output_dir.resolve()

    if not src.is_dir():
        raise FileNotFoundError(f"source package does not exist: {src}")
    if not (src / "__init__.py").is_file():
        raise FileNotFoundError(f"source package is missing __init__.py: {src}")

    # Namespacing the torch ops by content hash keeps multiple installed
    # versions of the kernel from clashing at op-registration time.
    source_hash = content_hash(src)
    op_namespace = f"_{PACKAGE_NAME}_{source_hash}"

    prepare_output_dir(out, args.force)

    write_build_toml(out / "build.toml", args.repo_id, args.version)
    write_flake(out / "flake.nix")
    write_readme(out / "README.md", args.repo_id, source_hash, op_namespace)

    if not args.no_source_layout:
        copy_package(src, out / "torch-ext" / PACKAGE_NAME, op_namespace)

    variant_dir = out / "build" / BUILD_VARIANT / PACKAGE_NAME
    copy_package(src, variant_dir, op_namespace)
    write_metadata(variant_dir.parent / "metadata.json", args.version)

    return out
|
||||||
|
|
||||||
|
|
||||||
|
def upload_package(args: argparse.Namespace, output_dir: Path) -> None:
    """Upload the generated universal kernel package to the HF Hub.

    Pushes *output_dir* to ``main`` and (unless ``--skip-version-branch``)
    mirrors it into a ``v<version>`` branch so versioned ``get_kernel``
    lookups resolve.

    Raises:
        ValueError: no --repo-id was provided.
        RuntimeError: huggingface_hub is missing or too old for kernel repos.
    """
    if not args.repo_id:
        raise ValueError("--repo-id is required when using --upload")

    try:
        from huggingface_hub import HfApi, constants as hf_constants
    except ImportError as exc:
        raise RuntimeError(
            "--upload requires huggingface_hub. Install it or run the upload "
            "manually with the Hugging Face CLI."
        ) from exc

    # Best-effort version string, used in error messages only.
    try:
        hub_version = metadata.version("huggingface_hub")
    except metadata.PackageNotFoundError:
        hub_version = "unknown"

    # Newer hub clients expose REPO_TYPES_WITH_KERNEL; fall back to the plain
    # REPO_TYPES list on older releases so the capability check still runs.
    accepted_repo_types = getattr(
        hf_constants,
        "REPO_TYPES_WITH_KERNEL",
        getattr(hf_constants, "REPO_TYPES", ()),
    )
    if HF_REPO_TYPE not in accepted_repo_types:
        raise RuntimeError(
            "Your huggingface_hub installation does not support "
            f"repo_type={HF_REPO_TYPE!r} (found huggingface_hub {hub_version}). "
            f"Upgrade this interpreter with: {sys.executable} -m pip install --upgrade "
            "'huggingface_hub>=1.10.0'"
        )

    # huggingface_hub 1.11.0 has partial kernel support: create_repo accepts
    # "kernel", but upload_folder/create_commit still validate against the
    # older REPO_TYPES list. Extend it in-process so those helpers use the
    # /api/kernels/... endpoints until upstream broadens that check.
    if HF_REPO_TYPE not in hf_constants.REPO_TYPES:
        hf_constants.REPO_TYPES.append(HF_REPO_TYPE)

    api = HfApi()
    try:
        repo_id = api.create_repo(
            repo_id=args.repo_id,
            repo_type=HF_REPO_TYPE,
            private=args.private,
            exist_ok=True,
        ).repo_id
    except ValueError as exc:
        # Surface a targeted hint when the failure is the known repo-type gap.
        if "Invalid repo type" in str(exc):
            raise RuntimeError(
                "huggingface_hub rejected repo_type='kernel'. "
                f"This usually means the command is running with an older Hub "
                f"client than expected (found huggingface_hub {hub_version} at "
                f"{sys.executable}). Upgrade with: {sys.executable} -m pip "
                "install --upgrade 'huggingface_hub>=1.10.0'"
            ) from exc
        raise

    # Remove stale files from previous uploads of the same layout.
    delete_patterns = [
        "build/**",
        "torch-ext/**",
        "build.toml",
        "flake.nix",
        "README.md",
    ]

    api.upload_folder(
        repo_id=repo_id,
        repo_type=HF_REPO_TYPE,
        folder_path=output_dir,
        revision="main",
        delete_patterns=delete_patterns,
        commit_message="Upload ScatterMoE LoRA universal kernel",
    )
    print(f"Uploaded main branch: {HF_KERNEL_URL_PREFIX}/{repo_id}")

    if args.skip_version_branch:
        return

    # Mirror main into a v<version> branch for versioned get_kernel() lookups.
    version_branch = f"v{args.version}"
    api.create_branch(
        repo_id=repo_id,
        repo_type=HF_REPO_TYPE,
        branch=version_branch,
        revision="main",
        exist_ok=True,
    )
    api.upload_folder(
        repo_id=repo_id,
        repo_type=HF_REPO_TYPE,
        folder_path=output_dir,
        revision=version_branch,
        delete_patterns=delete_patterns,
        commit_message=f"Upload ScatterMoE LoRA universal kernel {version_branch}",
    )
    print(
        f"Uploaded version branch: "
        f"{HF_KERNEL_URL_PREFIX}/{repo_id}/tree/{version_branch}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """CLI entry point: build the kernel package and optionally upload it.

    Returns 0 on success, 1 on any build/upload failure.
    """
    # NOTE(review): leading spaces inside the printed instruction strings are
    # reconstructed; verify against the original file.
    args = parse_args()
    try:
        output_dir = build_package(args)
        if args.upload:
            upload_package(args, output_dir)
    except Exception as exc:  # CLI boundary: report and exit non-zero.
        print(f"error: {exc}", file=sys.stderr)
        return 1

    print(f"Wrote ScatterMoE LoRA HF kernel package to: {output_dir}")
    print(f"Shippable artifact: {output_dir / 'build' / BUILD_VARIANT / PACKAGE_NAME}")
    if args.upload:
        print(f'Load it with: get_kernel("{args.repo_id}", version={args.version})')
        print(f"Uploaded as Hugging Face repo_type={HF_REPO_TYPE!r}.")
        return 0

    # Not uploading: print follow-up instructions instead.
    print("Next step:")
    print("  upload this universal Python/Triton kernel directly:")
    print(
        f"    python3 {Path(__file__).as_posix()} "
        f"--repo-id {args.repo_id} --force --upload"
    )
    if shutil.which("kernel-builder") is None:
        print("  optional: install kernel-builder for full Nix-based builds:")
        print(
            "    curl -fsSL "
            "https://raw.githubusercontent.com/huggingface/kernels/main/install.sh "
            "| bash"
        )
    else:
        print("  optional: upload with kernel-builder:")
        print(f"    cd {output_dir}")
        print("    kernel-builder build-and-upload")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code, same as the
    # explicit raise form.
    sys.exit(main())
|
||||||
@@ -370,7 +370,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True, # True/"longest" is the default
|
||||||
}
|
}
|
||||||
multiple = getattr(self.cfg, "pad_to_multiple_of", None) or 64
|
multiple = 64
|
||||||
if self.cfg.pad_to_sequence_len:
|
if self.cfg.pad_to_sequence_len:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||||
self.cfg.sequence_len / multiple
|
self.cfg.sequence_len / multiple
|
||||||
|
|||||||
@@ -228,47 +228,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
return training_args, trainer_kwargs
|
return training_args, trainer_kwargs
|
||||||
|
|
||||||
    def build_collator(self, **kwargs):
        """Build a data collator for preference-tuning trainers.

        Returns None for RL types that provide their own collator (e.g. GRPO,
        KTO), letting the trainer construct its default. For DPO/IPO/ORPO/SIMPO
        returns an ``AxolotlDPODataCollatorWithPadding`` when
        ``pad_to_multiple_of`` is set, otherwise None (so the trainer
        falls back to the TRL default).
        """
        if self.cfg.rl not in (
            RLType.DPO,
            RLType.IPO,
            RLType.ORPO,
            RLType.SIMPO,
        ):
            return None

        pad_to_multiple_of = getattr(self.cfg, "pad_to_multiple_of", None)
        if not pad_to_multiple_of:
            return None

        # Lazy import keeps the collator dependency off the non-DPO paths.
        from axolotl.utils.collators.dpo import AxolotlDPODataCollatorWithPadding

        LOG.info(
            f"Using AxolotlDPODataCollatorWithPadding with pad_to_multiple_of="
            f"{pad_to_multiple_of}"
        )
        # NOTE(review): assumes self.model and self.tokenizer are populated by
        # the time the collator is built — confirm against the builder lifecycle.
        is_enc_dec = getattr(self.model.config, "is_encoder_decoder", False)
        return AxolotlDPODataCollatorWithPadding(
            pad_token_id=self.tokenizer.pad_token_id,
            is_encoder_decoder=is_enc_dec,
            pad_to_multiple_of=pad_to_multiple_of,
            **kwargs,
        )
|
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
|
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
|
||||||
|
|
||||||
if (data_collator := self.build_collator()) is not None:
|
|
||||||
trainer_kwargs["data_collator"] = data_collator
|
|
||||||
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -20,16 +20,8 @@ class DPOStrategy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
training_args_kwargs = {}
|
training_args_kwargs = {}
|
||||||
if cfg.rl is RLType.DPO:
|
|
||||||
if cfg.dpo_loss_type is not None:
|
|
||||||
training_args_kwargs["loss_type"] = cfg.dpo_loss_type
|
|
||||||
|
|
||||||
if cfg.dpo_loss_weights is not None:
|
|
||||||
training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights
|
|
||||||
|
|
||||||
if cfg.rl is RLType.IPO:
|
if cfg.rl is RLType.IPO:
|
||||||
training_args_kwargs["loss_type"] = ["ipo"]
|
training_args_kwargs["loss_type"] = ["ipo"]
|
||||||
|
|
||||||
# Label smoothing is not compatible with IPO
|
# Label smoothing is not compatible with IPO
|
||||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||||
|
|||||||
@@ -242,85 +242,6 @@ class ProducerConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _GroupShardedSampler:
|
|
||||||
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
|
|
||||||
|
|
||||||
``RepeatSampler`` yields ``num_generations`` consecutive copies of
|
|
||||||
each prompt, forming a GRPO group. For distributed training each
|
|
||||||
rank must see a disjoint slice of prompts (otherwise every rank
|
|
||||||
dogpiles on the first 1/world_size of the batch) while keeping each
|
|
||||||
group intact on a single rank so advantage normalization sees all
|
|
||||||
peer generations.
|
|
||||||
|
|
||||||
``accelerator.prepare(DataLoader)`` does not handle this correctly
|
|
||||||
for custom samplers with ``split_batches=False`` (the default): it
|
|
||||||
leaves the sampler alone and every rank replays identical indices.
|
|
||||||
This wrapper fixes that by consuming the inner sampler's full
|
|
||||||
output, chunking it into ``num_generations``-sized groups, and
|
|
||||||
round-robining whole groups across ranks.
|
|
||||||
|
|
||||||
Intended to be used ONLY when distributed training is active
|
|
||||||
(``num_replicas > 1``); for single-rank it is a no-op but still
|
|
||||||
correct.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
inner: Any,
|
|
||||||
num_generations: int,
|
|
||||||
rank: int,
|
|
||||||
num_replicas: int,
|
|
||||||
):
|
|
||||||
if num_generations < 1:
|
|
||||||
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
|
|
||||||
if num_replicas < 1:
|
|
||||||
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
|
||||||
if not (0 <= rank < num_replicas):
|
|
||||||
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
|
|
||||||
self.inner = inner
|
|
||||||
self.num_generations = num_generations
|
|
||||||
self.rank = rank
|
|
||||||
self.num_replicas = num_replicas
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
all_indices = list(self.inner)
|
|
||||||
if len(all_indices) % self.num_generations != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"inner sampler yielded {len(all_indices)} indices, "
|
|
||||||
f"not a multiple of num_generations={self.num_generations}"
|
|
||||||
)
|
|
||||||
# Chunk the flat index sequence into groups of num_generations
|
|
||||||
# consecutive indices. ``RepeatSampler`` guarantees that each
|
|
||||||
# group contains num_generations copies of the same prompt id.
|
|
||||||
groups = [
|
|
||||||
all_indices[i : i + self.num_generations]
|
|
||||||
for i in range(0, len(all_indices), self.num_generations)
|
|
||||||
]
|
|
||||||
# Round-robin whole groups across ranks. Round-robin (vs.
|
|
||||||
# contiguous chunking) preserves approximate shuffled order on
|
|
||||||
# each rank even when the group count is small relative to the
|
|
||||||
# world size.
|
|
||||||
for group in groups[self.rank :: self.num_replicas]:
|
|
||||||
yield from group
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
try:
|
|
||||||
inner_len = len(self.inner)
|
|
||||||
except TypeError:
|
|
||||||
# Non-sized inner sampler — we can't know the per-rank
|
|
||||||
# length without materializing. Return 0 as a hint that the
|
|
||||||
# DataLoader should fall back to iteration.
|
|
||||||
return 0
|
|
||||||
total_groups = inner_len // self.num_generations
|
|
||||||
# Ceiling division for the trailing groups that don't divide
|
|
||||||
# evenly — extra groups go to the first ``total_groups %
|
|
||||||
# num_replicas`` ranks, matching the round-robin above.
|
|
||||||
my_groups = (
|
|
||||||
total_groups + self.num_replicas - self.rank - 1
|
|
||||||
) // self.num_replicas
|
|
||||||
return my_groups * self.num_generations
|
|
||||||
|
|
||||||
|
|
||||||
class DataProducer(ABC):
|
class DataProducer(ABC):
|
||||||
"""Abstract base class for online data producers.
|
"""Abstract base class for online data producers.
|
||||||
|
|
||||||
@@ -635,34 +556,6 @@ class GRPODataProducer(BaseDataProducer):
|
|||||||
seed=self._seed,
|
seed=self._seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Shard the sampler across distributed ranks so each rank sees
|
|
||||||
# a disjoint slice of prompts. ``RepeatSampler`` groups each
|
|
||||||
# prompt with ``num_generations`` consecutive copies — our
|
|
||||||
# wrapper round-robins WHOLE groups across ranks so all
|
|
||||||
# generations of a given prompt stay on the same rank (needed
|
|
||||||
# for GRPO advantage normalization within a group).
|
|
||||||
#
|
|
||||||
# Without this, ``accelerator.prepare(dl)`` with the default
|
|
||||||
# ``split_batches=False`` leaves the custom sampler alone, so
|
|
||||||
# every rank iterates the identical index sequence and the
|
|
||||||
# cluster dogpiles on the first 1/world_size of the prompts.
|
|
||||||
num_replicas = max(1, trainer.accelerator.num_processes)
|
|
||||||
if num_replicas > 1:
|
|
||||||
sampler = _GroupShardedSampler(
|
|
||||||
inner=sampler,
|
|
||||||
num_generations=self._num_generations,
|
|
||||||
rank=trainer.accelerator.process_index,
|
|
||||||
num_replicas=num_replicas,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"[RANK:%d] _GroupShardedSampler active "
|
|
||||||
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
|
|
||||||
trainer.accelerator.process_index,
|
|
||||||
num_replicas,
|
|
||||||
self._num_generations,
|
|
||||||
self._generation_batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use identity collator (same as stock GRPOTrainer)
|
# Use identity collator (same as stock GRPOTrainer)
|
||||||
def _identity(x):
|
def _identity(x):
|
||||||
return x
|
return x
|
||||||
@@ -681,11 +574,12 @@ class GRPODataProducer(BaseDataProducer):
|
|||||||
rank=trainer.args.process_index,
|
rank=trainer.args.process_index,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
# Skip accelerator.prepare — we're handling per-rank sharding
|
self._prompt_dl = trainer.accelerator.prepare(dl)
|
||||||
# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
|
|
||||||
# otherwise try to wrap the DataLoader with its own sharding
|
# Don't let accelerator track this dataloader
|
||||||
# logic which does not understand our group structure.
|
acc_dls = trainer.accelerator._dataloaders
|
||||||
self._prompt_dl = dl
|
if self._prompt_dl in acc_dls:
|
||||||
|
acc_dls.remove(self._prompt_dl)
|
||||||
|
|
||||||
self._prompt_iter = iter(self._prompt_dl)
|
self._prompt_iter = iter(self._prompt_dl)
|
||||||
|
|
||||||
@@ -1209,22 +1103,11 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
|
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
|
||||||
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
|
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
|
||||||
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
|
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
|
||||||
|
|
||||||
This is the canonical sync trigger and runs in BOTH async and
|
|
||||||
synchronous modes from ``_prepare_inputs_with_data_producer`` /
|
|
||||||
``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
|
|
||||||
patch is a parallel backup for non-data-producer paths (vanilla
|
|
||||||
GRPO without NeMo Gym), where the data producer is bypassed
|
|
||||||
entirely and TRL's stock generate-then-sync flow is used instead.
|
|
||||||
"""
|
"""
|
||||||
if not self.use_vllm:
|
if not (self.use_vllm and self.args.async_prefetch):
|
||||||
return
|
return
|
||||||
step = self.state.global_step
|
step = self.state.global_step
|
||||||
# Default to syncing every step when no interval is configured —
|
interval = self.args.vllm_sync_interval
|
||||||
# otherwise ``step % None`` would TypeError, and the previous
|
|
||||||
# behavior of crashing on the first sync was strictly worse than
|
|
||||||
# the standard "sync every optimizer step".
|
|
||||||
interval = self.args.vllm_sync_interval or 1
|
|
||||||
if step != self._last_synced_step and step % interval == 0:
|
if step != self._last_synced_step and step % interval == 0:
|
||||||
if step == 0:
|
if step == 0:
|
||||||
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
||||||
@@ -1319,42 +1202,13 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
|||||||
|
|
||||||
# Permanently replace vllm_generation.sync_weights with our custom
|
# Permanently replace vllm_generation.sync_weights with our custom
|
||||||
# sync to avoid merge_adapter (fails on FP8 / races with training).
|
# sync to avoid merge_adapter (fails on FP8 / races with training).
|
||||||
#
|
# For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
|
||||||
# The design has two modes that have to be threaded carefully:
|
# handles the sync with proper interval tracking.
|
||||||
#
|
|
||||||
# - Async prefetch ON: BG generation thread can't safely call
|
|
||||||
# sync_weights mid-rollout (it races with the trainer's optimizer
|
|
||||||
# step and can corrupt weights). We no-op the stock sync hook and
|
|
||||||
# drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
|
|
||||||
# optimizer step on the main thread.
|
|
||||||
#
|
|
||||||
# - Async prefetch OFF (synchronous mode): TRL's stock
|
|
||||||
# ``_generate_single_turn`` calls ``sync_weights`` once per step
|
|
||||||
# boundary. There's no BG thread to race with, and
|
|
||||||
# ``_maybe_sync_vllm_weights`` short-circuits with
|
|
||||||
# ``if not async_prefetch: return``, so we MUST wire the stock
|
|
||||||
# hook directly to our LoRA sync helper — otherwise nothing ever
|
|
||||||
# pushes weights to vLLM and the trainer becomes a no-op (vLLM
|
|
||||||
# keeps serving the base model, every rollout in every group
|
|
||||||
# produces identical outputs, advantages are zero, optimizer
|
|
||||||
# step gets skipped, repeat).
|
|
||||||
if not getattr(self, "_patched_sync_weights", False):
|
if not getattr(self, "_patched_sync_weights", False):
|
||||||
if self.use_vllm and hasattr(self, "vllm_generation"):
|
if self.use_vllm and hasattr(self, "vllm_generation"):
|
||||||
if getattr(self.args, "vllm_lora_sync", False):
|
if getattr(self.args, "vllm_lora_sync", False):
|
||||||
if getattr(self.args, "async_prefetch", False):
|
# No-op: LoRA sync is driven by _maybe_sync_vllm_weights
|
||||||
# Async: drive sync from main thread via
|
self.vllm_generation.sync_weights = lambda: None
|
||||||
# _maybe_sync_vllm_weights instead.
|
|
||||||
self.vllm_generation.sync_weights = lambda: None
|
|
||||||
else:
|
|
||||||
# Sync mode: TRL's _generate_single_turn already
|
|
||||||
# calls sync_weights once per step boundary. Wire
|
|
||||||
# it directly to our LoRA filesystem sync helper.
|
|
||||||
sync_helper = self._sync_lora_adapter
|
|
||||||
|
|
||||||
def _lora_filesystem_sync():
|
|
||||||
sync_helper()
|
|
||||||
|
|
||||||
self.vllm_generation.sync_weights = _lora_filesystem_sync
|
|
||||||
self._patched_sync_weights = True
|
self._patched_sync_weights = True
|
||||||
else:
|
else:
|
||||||
from accelerate.utils import is_peft_model
|
from accelerate.utils import is_peft_model
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Hatchery/Tinker remote training integration for Axolotl.
|
|
||||||
|
|
||||||
Routes axolotl's preprocessed data to a remote training API (Tinker or
|
|
||||||
Hatchery) instead of running forward/backward locally. The remote
|
|
||||||
service handles model weights, LoRA adapters, and gradient updates.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .args import HatcheryArgs, HatcheryConfig
|
|
||||||
from .plugin import HatcheryPlugin
|
|
||||||
|
|
||||||
__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"]
|
|
||||||
|
|
||||||
# Usage:
|
|
||||||
# plugins:
|
|
||||||
# - axolotl.integrations.hatchery.HatcheryPlugin
|
|
||||||
#
|
|
||||||
# hatchery:
|
|
||||||
# backend: tinker # or "hatchery"
|
|
||||||
# lora_rank: 32
|
|
||||||
# loss_fn: cross_entropy # SFT
|
|
||||||
# # loss_fn: ppo # RL (auto-selects HatcheryRLTrainer)
|
|
||||||
#
|
|
||||||
# learning_rate: 1e-4 # top-level, not under hatchery:
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Pydantic config schema for the Hatchery integration."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Literal, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class HatcheryConfig(BaseModel):
    """Nested config under `hatchery:` in the axolotl YAML.

    Only hatchery-specific knobs live here. Standard training params
    (learning_rate, weight_decay, adam_beta1/2, max_grad_norm,
    gradient_accumulation_steps) are read from axolotl's top-level config.
    """

    # Backend selection and connection details for the remote API.
    backend: Literal["tinker", "hatchery"] = "tinker"
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    project_id: Optional[str] = None

    # LoRA adapter shape requested from the remote service.
    lora_rank: int = Field(32, ge=1, le=256)
    train_attn: bool = True
    train_mlp: bool = True
    train_unembed: bool = True

    # Remote loss function: cross_entropy drives SFT, the others drive RL.
    loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = (
        "cross_entropy"
    )
    loss_fn_config: Optional[dict[str, Any]] = None

    # Pipelining: submit the next batch before awaiting the previous result.
    pipeline: bool = True

    # Sampling parameters used by the RL flows.
    max_sample_tokens: int = 256
    sample_temperature: float = 1.0
    num_samples: int = 4

    # Reward functions (RL only) — fully qualified names, resolved at runtime.
    reward_funcs: Optional[list[str]] = None

    # Remote checkpointing cadence and name prefix.
    save_steps: Optional[int] = None
    save_name_prefix: str = "checkpoint"

    # Seconds to wait on each remote future before giving up.
    future_timeout: float = 600.0
|
|
||||||
|
|
||||||
class HatcheryArgs(BaseModel):
    """Top-level mixin contributing the nested `hatchery:` field to axolotl's config."""

    # Absent unless the user declares a `hatchery:` section in the YAML.
    hatchery: Optional[HatcheryConfig] = None
|
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Convert axolotl batch tensors to Tinker/Hatchery Datum format.
|
|
||||||
|
|
||||||
Both Tinker and Hatchery expect the client to apply the causal LM shift:
|
|
||||||
|
|
||||||
Original tokens: [t0, t1, t2, ..., t_{L-1}]
|
|
||||||
model_input: [t0, t1, ..., t_{L-2}] (last token dropped)
|
|
||||||
target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped)
|
|
||||||
weights: [w1, w2, ..., w_{L-1}] (aligned to targets)
|
|
||||||
|
|
||||||
At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]:
|
|
||||||
"""Serialize a tensor to the TensorData wire dict."""
|
|
||||||
flat = t.detach().cpu().flatten()
|
|
||||||
dtype_map = {
|
|
||||||
torch.float32: "float32",
|
|
||||||
torch.float16: "float16",
|
|
||||||
torch.bfloat16: "bfloat16",
|
|
||||||
torch.int64: "int64",
|
|
||||||
torch.int32: "int32",
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"dtype": dtype_map.get(flat.dtype, "float32"),
|
|
||||||
"shape": list(t.shape),
|
|
||||||
"data": flat.tolist(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _make_datum(
    tokens: list[int],
    loss_fn_inputs: dict[str, torch.Tensor],
) -> dict[str, Any]:
    """Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery)."""
    model_input = {"chunks": [{"type": "encoded_text", "tokens": tokens}]}
    wire_inputs = {
        name: _tensor_to_wire(value) for name, value in loss_fn_inputs.items()
    }
    return {"model_input": model_input, "loss_fn_inputs": wire_inputs}
|
|
||||||
|
|
||||||
|
|
||||||
def datums_to_tinker(datums: list[dict[str, Any]]):
    """Wrap plain-dict datums into tinker.types.Datum objects.

    Both the Tinker SDK and updated Hatchery client accept these.
    """
    import tinker.types as tt

    wrapped = []
    for datum in datums:
        token_ids = datum["model_input"]["chunks"][0]["tokens"]
        inputs = {
            name: tt.TensorData(
                data=wire["data"],
                dtype=wire["dtype"],
                shape=wire["shape"],
            )
            for name, wire in datum["loss_fn_inputs"].items()
        }
        wrapped.append(
            tt.Datum(
                model_input=tt.ModelInput.from_ints(token_ids),
                loss_fn_inputs=inputs,
            )
        )
    return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
def batch_to_datums_sft(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
    """Convert an axolotl SFT batch to Datum dicts with the causal shift applied.

    Positions whose (shifted) label is -100 receive weight 0 and a dummy
    target of 0 so the remote loss ignores them.
    """
    datums: list[dict[str, Any]] = []

    for row in range(input_ids.size(0)):
        ids = input_ids[row]
        row_labels = labels[row]

        # Trim padding using the attention mask, when one is provided.
        if attention_mask is not None:
            valid = int(attention_mask[row].sum().item())
            ids = ids[:valid]
            row_labels = row_labels[:valid]

        # Causal shift: model sees t_i and predicts t_{i+1}.
        shifted = row_labels[1:]
        targets = shifted.clone()
        keep = (shifted != -100).float()
        targets[targets == -100] = 0

        datums.append(
            _make_datum(
                ids[:-1].tolist(),
                {
                    "target_tokens": targets,
                    "weights": keep,
                },
            )
        )

    return datums
|
|
||||||
|
|
||||||
|
|
||||||
def batch_to_datums_rl(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    logprobs: torch.Tensor,
    advantages: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
    """Convert an RL batch to importance_sampling/ppo Datum dicts with causal shift."""
    datums: list[dict[str, Any]] = []

    for row in range(input_ids.size(0)):
        ids = input_ids[row]
        row_labels = labels[row]

        # Determine the unpadded length for this row.
        if attention_mask is not None:
            valid = int(attention_mask[row].sum().item())
        else:
            valid = ids.size(0)
        ids = ids[:valid]
        row_labels = row_labels[:valid]
        row_logprobs = logprobs[row, :valid]
        row_advantages = advantages[row, :valid]

        # Causal shift: model input drops the last token, targets drop the first.
        targets = row_labels[1:].clone()
        targets[targets == -100] = 0

        datums.append(
            _make_datum(
                ids[:-1].tolist(),
                {
                    "target_tokens": targets,
                    "logprobs": row_logprobs[1:],
                    "advantages": row_advantages[1:],
                },
            )
        )

    return datums
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Prepare hendrycks_math for RL training with Hatchery/Tinker.
|
|
||||||
|
|
||||||
Creates a dataset with chat-formatted prompts that include
|
|
||||||
a hidden gold answer tag for the reward function.
|
|
||||||
|
|
||||||
Run:
|
|
||||||
python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
from datasets import Dataset, load_dataset
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def extract_boxed(text: str) -> str:
    """Return the contents of the first \\boxed{...}, honoring nested braces.

    Returns "" when no balanced boxed answer is present.
    """
    opener = re.search(r"\\boxed\{", text)
    if opener is None:
        return ""
    body_start = opener.end()
    open_braces = 1
    for pos in range(body_start, len(text)):
        ch = text[pos]
        if ch == "{":
            open_braces += 1
        elif ch == "}":
            open_braces -= 1
            if open_braces == 0:
                return text[body_start:pos]
    # Ran off the end with unbalanced braces.
    return ""
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Build an RL prompt dataset from hendrycks_math and save it to disk."""
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)

    ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test")
    level = os.environ.get("MATH_LEVEL", "Level 1")
    selected = [x for x in ds if x["level"] == level]
    print(f"{level} algebra: {len(selected)} problems")

    rows = []
    for sample in selected:
        gold = extract_boxed(sample["solution"])
        # Skip problems whose solution has no extractable boxed answer.
        if not gold:
            continue

        # Chat prompt with the gold answer embedded in a hidden tag so the
        # reward function can recover it later.
        prompt = (
            f"Solve the following math problem. "
            f"Show your work and put your final answer in \\boxed{{}}.\n\n"
            f"{sample['problem']}"
            f"<|gold|>{gold}<|/gold|>"
        )

        rendered = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt_ids = tokenizer.encode(rendered, add_special_tokens=False)

        # Labels are all -100: the prompt itself is never a training target.
        rows.append(
            {
                "input_ids": prompt_ids,
                "labels": [-100] * len(prompt_ids),
                "attention_mask": [1] * len(prompt_ids),
            }
        )

    out = Dataset.from_list(rows)
    out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}"
    out.save_to_disk(out_dir)
    print(f"Saved {len(out)} examples to {out_dir}")
    if rows:
        print(
            f"Prompt length range: {min(len(r['input_ids']) for r in rows)}"
            f"-{max(len(r['input_ids']) for r in rows)}"
        )
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":  # script entry point
    main()
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B
#
# Prep:
#   python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
#
# Run:
#   export TINKER_API_KEY="your-key"
#   axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml

base_model: Qwen/Qwen3-8B

plugins:
  - axolotl.integrations.hatchery.HatcheryPlugin

# Hatchery-specific settings; importance_sampling selects the RL trainer.
hatchery:
  backend: tinker
  lora_rank: 16
  loss_fn: importance_sampling
  max_sample_tokens: 2048
  sample_temperature: 0.7
  num_samples: 4
  pipeline: true
  save_steps: 5
  reward_funcs:
    - axolotl.integrations.hatchery.rewards.math_reward.math_reward

# Prompts produced by prep_math_rl.py (arrow on disk).
datasets:
  - path: ./data/math_rl_level1
    ds_type: arrow
    type: completion

sequence_len: 2048

# Standard optimizer params live at the top level, not under hatchery:.
learning_rate: 5.0e-5
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0

max_steps: 10
num_epochs: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
logging_steps: 1

output_dir: ./outputs/tinker-rl-math
||||||
@@ -1,42 +0,0 @@
|
|||||||
# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B
#
# Usage:
#   export TINKER_API_KEY="your-key"
#   axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml

base_model: Qwen/Qwen3-8B

plugins:
  - axolotl.integrations.hatchery.HatcheryPlugin

# Hatchery-specific settings; cross_entropy selects the SFT trainer.
hatchery:
  backend: tinker
  lora_rank: 16
  loss_fn: cross_entropy
  pipeline: true
  save_steps: 10

datasets:
  - path: TeichAI/kimi-k2-thinking-1000x
    split: train[:50]
    type: chat_template
    chat_template: qwen3
    split_thinking: true

chat_template: qwen3
sequence_len: 2048

# Standard optimizer params live at the top level, not under hatchery:.
learning_rate: 3.0e-4
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0

num_epochs: 1
max_steps: 20
micro_batch_size: 2
gradient_accumulation_steps: 1
logging_steps: 1

output_dir: ./outputs/tinker-sft
||||||
@@ -1,147 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Axolotl plugin that routes training to a remote Hatchery/Tinker API."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from peft import PeftModel
|
|
||||||
from transformers import AutoConfig, PreTrainedModel, Trainer
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HatcheryPlugin(BasePlugin):
    """Plugin that replaces local training with remote API calls.

    Activated by adding to the axolotl YAML:

    plugins:
      - axolotl.integrations.hatchery.HatcheryPlugin

    hatchery:
      backend: tinker  # or "hatchery"
      lora_rank: 32
      loss_fn: cross_entropy
      # ... see HatcheryConfig for full options
    """

    def get_input_args(self) -> str:
        """Name of the pydantic mixin that validates the `hatchery:` section."""
        return "axolotl.integrations.hatchery.args.HatcheryArgs"

    def register(self, cfg: dict):
        """Auto-set config values needed for remote training."""
        # The remote trainer consumes raw dataset columns directly.
        if cfg.get("remove_unused_columns") is None:
            cfg["remove_unused_columns"] = False

    def pre_model_load(self, cfg: DictDefault):
        """Replace model loading with a tiny stub (weights live on the remote)."""
        hatchery_cfg = cfg.hatchery or {}
        if isinstance(hatchery_cfg, dict):
            backend = hatchery_cfg.get("backend", "tinker")
        else:
            backend = getattr(hatchery_cfg, "backend", "tinker")
        LOG.info(
            f"Hatchery plugin active: training dispatched to remote "
            f"{backend} API. Skipping local model weight loading."
        )

        from axolotl.loaders import ModelLoader

        def _stub_build_model(loader_self) -> bool:
            # Only the model config is fetched (for vocab size etc.); no weights.
            base_model = loader_self.cfg.base_model
            LOG.info(f"Skipping model weight loading for: {base_model}")

            config = AutoConfig.from_pretrained(
                base_model,
                trust_remote_code=loader_self.cfg.get("trust_remote_code", False),
            )

            class _Stub(PreTrainedModel):
                # Minimal PreTrainedModel so downstream plumbing has something
                # model-shaped to hold.
                config_class = type(config)
                _no_split_modules: list[str] = []
                supports_gradient_checkpointing = False

                def __init__(self, cfg):
                    super().__init__(cfg)
                    # vocab_size fallback is 32000 when the config omits it.
                    vocab_size = getattr(cfg, "vocab_size", 32000)
                    self.embed_tokens = torch.nn.Embedding(vocab_size, 1)

                def get_input_embeddings(self):
                    return self.embed_tokens

                def set_input_embeddings(self, value):
                    pass

                def get_output_embeddings(self):
                    return None

            loader_self.model = _Stub(config)
            return True

        ModelLoader._build_model = _stub_build_model  # type: ignore[method-assign,assignment]

    def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
        """Return the remote trainer class matching the configured loss."""
        hatchery_cfg = cfg.hatchery
        loss_fn = (
            getattr(hatchery_cfg, "loss_fn", "cross_entropy")
            if hatchery_cfg
            else "cross_entropy"
        )

        # RL-style losses route to the sampling+reward trainer.
        if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"):
            from .rl_trainer import HatcheryRLTrainer

            return HatcheryRLTrainer

        from .trainer import HatcheryTrainer

        return HatcheryTrainer

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        # Marker flag so other code can detect remote-training mode.
        model._hatchery_remote = True

    def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        LOG.info(
            "Hatchery: skipping local model save (weights are on remote API). "
            "Use `tinker checkpoint download` or hatchery CLI to retrieve."
        )

    def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
        """Inject hatchery config + axolotl training params into the trainer."""
        from .args import HatcheryConfig
        from .rl_trainer import HatcheryRLTrainer
        from .trainer import HatcheryTrainer

        if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)):
            return

        raw = cfg.hatchery
        if isinstance(raw, dict):
            hatchery_config = HatcheryConfig(**raw)
        elif raw is None:
            hatchery_config = HatcheryConfig()
        else:
            hatchery_config = raw

        trainer.hatchery_args = hatchery_config
        trainer._base_model_name = cfg.base_model

        # Standard optimizer params come from axolotl's top-level config so
        # they don't need to be duplicated under hatchery:.
        def _or_default(value, default):
            return value if value is not None else default

        trainer._optim_params = {
            "learning_rate": _or_default(cfg.learning_rate, 1e-4),
            "beta1": _or_default(cfg.adam_beta1, 0.9),
            "beta2": _or_default(cfg.adam_beta2, 0.95),
            "eps": _or_default(cfg.adam_epsilon, 1e-12),
            "weight_decay": _or_default(cfg.weight_decay, 0.0),
            "grad_clip_norm": _or_default(cfg.max_grad_norm, 0.0),
        }
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Math reward function for hendrycks_math GRPO training.
|
|
||||||
|
|
||||||
Uses math_verify for robust answer comparison. Falls back to
|
|
||||||
exact string match of \\boxed{} content only when math_verify
|
|
||||||
is unavailable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_boxed(text: str) -> str | None:
|
|
||||||
"""Extract \\boxed{...} answer handling nested braces."""
|
|
||||||
match = re.search(r"\\boxed\{", text)
|
|
||||||
if not match:
|
|
||||||
return None
|
|
||||||
start = match.end()
|
|
||||||
depth = 1
|
|
||||||
i = start
|
|
||||||
while i < len(text) and depth > 0:
|
|
||||||
if text[i] == "{":
|
|
||||||
depth += 1
|
|
||||||
elif text[i] == "}":
|
|
||||||
depth -= 1
|
|
||||||
i += 1
|
|
||||||
return text[start : i - 1] if depth == 0 else None
|
|
||||||
|
|
||||||
|
|
||||||
def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
    """Score completions: 1.0 when the \\boxed{} answer matches the gold answer.

    The gold answer is extracted from the prompt (appended as a hidden
    tag by the dataset preprocessing). Format:
        ... <|gold|>ANSWER<|/gold|>
    Verification prefers math_verify; falls back to exact string match.
    """
    scores: list[float] = []
    for prompt, completion in zip(prompts, completions, strict=True):
        gold_tag = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt)
        if gold_tag is None:
            # No gold tag in the prompt: nothing to score against.
            scores.append(0.0)
            continue

        gold = gold_tag.group(1).strip()
        predicted = extract_boxed(completion)
        if predicted is None:
            # Completion never produced a boxed answer.
            scores.append(0.0)
            continue

        outcome = None
        try:
            from math_verify import parse, verify

            outcome = verify(parse(gold), parse(predicted))
        except Exception:
            LOG.debug(
                "math_verify unavailable or failed, using string fallback",
                exc_info=True,
            )

        if outcome is not None:
            scores.append(1.0 if outcome else 0.0)
        else:
            # Fallback: exact (whitespace-stripped) string comparison.
            scores.append(1.0 if predicted.strip() == gold.strip() else 0.0)

    return scores
|
|
||||||
@@ -1,409 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API.
|
|
||||||
|
|
||||||
Full RL loop per step:
|
|
||||||
1. Extract prompts from dataset batch
|
|
||||||
2. Sample N completions per prompt via remote SamplingClient
|
|
||||||
3. Score completions with local reward functions
|
|
||||||
4. Compute GRPO-style advantages (per-group normalization)
|
|
||||||
5. Send (prompt+completion, logprobs, advantages) as forward_backward
|
|
||||||
6. Optimizer step
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from typing import Any, Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers.trainer_utils import TrainOutput
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .args import HatcheryConfig
|
|
||||||
from .data import batch_to_datums_rl, datums_to_tinker
|
|
||||||
from .trainer import _create_training_client
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_reward_func(fqn: str) -> Callable:
|
|
||||||
"""Load a reward function from a fully qualified name like 'module.func'."""
|
|
||||||
module_path = ".".join(fqn.split(".")[:-1])
|
|
||||||
func_name = fqn.split(".")[-1]
|
|
||||||
mod = importlib.import_module(module_path)
|
|
||||||
func = getattr(mod, func_name)
|
|
||||||
if len(inspect.signature(func).parameters) < 2:
|
|
||||||
raise ValueError(f"Reward function {fqn} must accept (prompts, completions)")
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class HatcheryRLTrainer(AxolotlTrainer):
|
|
||||||
"""Remote RL trainer using Tinker/Hatchery for sampling and training."""
|
|
||||||
|
|
||||||
hatchery_args: Optional[HatcheryConfig]
|
|
||||||
_base_model_name: Optional[str]
|
|
||||||
_training_client: Any
|
|
||||||
_reward_functions: list[Callable]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.hatchery_args = None
|
|
||||||
self._base_model_name = None
|
|
||||||
self._training_client = None
|
|
||||||
self._reward_functions = []
|
|
||||||
|
|
||||||
def _ensure_reward_functions(self):
|
|
||||||
if self._reward_functions:
|
|
||||||
return
|
|
||||||
args = self.hatchery_args
|
|
||||||
if not args or not args.reward_funcs:
|
|
||||||
raise ValueError(
|
|
||||||
"No reward functions configured. Set hatchery.reward_funcs "
|
|
||||||
"in YAML, e.g. reward_funcs: ['my_module.my_reward']"
|
|
||||||
)
|
|
||||||
for fqn in args.reward_funcs:
|
|
||||||
self._reward_functions.append(_load_reward_func(fqn))
|
|
||||||
LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)")
|
|
||||||
|
|
||||||
def _get_training_client(self):
|
|
||||||
if self._training_client is not None:
|
|
||||||
return self._training_client
|
|
||||||
|
|
||||||
self._training_client = _create_training_client(
|
|
||||||
self.hatchery_args, self._base_model_name
|
|
||||||
)
|
|
||||||
LOG.info(
|
|
||||||
f"Remote RL session created: backend={self.hatchery_args.backend}, "
|
|
||||||
f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}"
|
|
||||||
)
|
|
||||||
return self._training_client
|
|
||||||
|
|
||||||
def _sample_completions(self, prompt_ids_list: list[list[int]]):
|
|
||||||
"""Sample completions for prompts via remote API."""
|
|
||||||
import tinker.types as tt
|
|
||||||
|
|
||||||
tc = self._get_training_client()
|
|
||||||
args = self.hatchery_args
|
|
||||||
assert args is not None # validated by _get_training_client
|
|
||||||
results = []
|
|
||||||
|
|
||||||
sc = tc.save_weights_and_get_sampling_client()
|
|
||||||
|
|
||||||
for prompt_ids in prompt_ids_list:
|
|
||||||
if hasattr(sc, "sampling_session_id"):
|
|
||||||
sample_result = sc.sample(
|
|
||||||
prompt_ids,
|
|
||||||
max_tokens=args.max_sample_tokens,
|
|
||||||
temperature=args.sample_temperature,
|
|
||||||
n=args.num_samples,
|
|
||||||
).result(timeout=args.future_timeout)
|
|
||||||
else:
|
|
||||||
mi = tt.ModelInput.from_ints(prompt_ids)
|
|
||||||
sp = tt.SamplingParams(
|
|
||||||
max_tokens=args.max_sample_tokens,
|
|
||||||
temperature=args.sample_temperature,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=-1,
|
|
||||||
)
|
|
||||||
sample_result = sc.sample(
|
|
||||||
prompt=mi,
|
|
||||||
num_samples=args.num_samples,
|
|
||||||
sampling_params=sp,
|
|
||||||
).result(timeout=args.future_timeout)
|
|
||||||
|
|
||||||
sequences = (
|
|
||||||
sample_result.sequences
|
|
||||||
if hasattr(sample_result, "sequences")
|
|
||||||
else sample_result.get("sequences", [])
|
|
||||||
)
|
|
||||||
for seq in sequences:
|
|
||||||
tokens = (
|
|
||||||
list(seq.tokens)
|
|
||||||
if hasattr(seq, "tokens")
|
|
||||||
else seq.get("tokens", [])
|
|
||||||
)
|
|
||||||
logprobs = (
|
|
||||||
list(seq.logprobs)
|
|
||||||
if hasattr(seq, "logprobs") and seq.logprobs
|
|
||||||
else seq.get("logprobs", [])
|
|
||||||
)
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"tokens": list(prompt_ids) + tokens,
|
|
||||||
"completion_tokens": tokens,
|
|
||||||
"logprobs": logprobs,
|
|
||||||
"prompt_len": len(prompt_ids),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def _compute_rewards(
|
|
||||||
self, prompts: list[str], completions: list[str]
|
|
||||||
) -> list[float]:
|
|
||||||
total_rewards = [0.0] * len(completions)
|
|
||||||
for reward_fn in self._reward_functions:
|
|
||||||
rewards = reward_fn(prompts, completions)
|
|
||||||
for i, r in enumerate(rewards):
|
|
||||||
total_rewards[i] += r
|
|
||||||
return total_rewards
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_advantages(rewards: list[float], group_size: int) -> list[float]:
|
|
||||||
advantages = []
|
|
||||||
for i in range(0, len(rewards), group_size):
|
|
||||||
group = rewards[i : i + group_size]
|
|
||||||
mean = sum(group) / len(group)
|
|
||||||
var = sum((r - mean) ** 2 for r in group) / max(len(group), 1)
|
|
||||||
std = var**0.5 if var > 1e-8 else 1.0
|
|
||||||
advantages.extend([(r - mean) / std for r in group])
|
|
||||||
return advantages
|
|
||||||
|
|
||||||
def _do_optim_step(self):
|
|
||||||
import tinker.types as tt
|
|
||||||
|
|
||||||
tc = self._get_training_client()
|
|
||||||
return tc.optim_step(tt.AdamParams(**self._optim_params))
|
|
||||||
|
|
||||||
def train(
|
|
||||||
self,
|
|
||||||
resume_from_checkpoint: Optional[str] = None,
|
|
||||||
trial: Any = None,
|
|
||||||
ignore_keys_for_eval: Optional[list[str]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> TrainOutput:
|
|
||||||
args = self.hatchery_args
|
|
||||||
if args is None:
|
|
||||||
raise RuntimeError("hatchery_args not configured")
|
|
||||||
|
|
||||||
self._ensure_reward_functions()
|
|
||||||
|
|
||||||
train_dataloader = self.get_train_dataloader()
|
|
||||||
num_train_epochs = int(self.args.num_train_epochs)
|
|
||||||
max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"Remote RL training: max_steps={max_steps}, "
|
|
||||||
f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.state.max_steps = max_steps
|
|
||||||
self.state.num_train_epochs = num_train_epochs
|
|
||||||
self.state.is_local_process_zero = True
|
|
||||||
self.state.is_world_process_zero = True
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_begin(
|
|
||||||
self.args,
|
|
||||||
self.state,
|
|
||||||
self.control, # type: ignore[has-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer = self.processing_class
|
|
||||||
global_step = 0
|
|
||||||
total_loss = 0.0
|
|
||||||
total_reward = 0.0
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _epoch in range(num_train_epochs):
|
|
||||||
if global_step >= max_steps:
|
|
||||||
break
|
|
||||||
|
|
||||||
for batch in train_dataloader:
|
|
||||||
if global_step >= max_steps:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_step_begin(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_ids_batch = batch["input_ids"]
|
|
||||||
# Full prompt text (with gold tag) for reward scoring
|
|
||||||
prompt_texts = tokenizer.batch_decode(
|
|
||||||
prompt_ids_batch, skip_special_tokens=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Strip <|gold|>...<|/gold|> from token ids before
|
|
||||||
# sending to the model for sampling — the gold answer
|
|
||||||
# must only be visible to the local reward function.
|
|
||||||
sampling_prompts = []
|
|
||||||
for prompt_text in prompt_texts:
|
|
||||||
clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text)
|
|
||||||
clean_ids = tokenizer.encode(clean, add_special_tokens=False)
|
|
||||||
sampling_prompts.append(clean_ids)
|
|
||||||
|
|
||||||
# 1. Sample completions (without gold answer)
|
|
||||||
t0 = time.time()
|
|
||||||
samples = self._sample_completions(sampling_prompts)
|
|
||||||
t_sample = time.time() - t0
|
|
||||||
|
|
||||||
if not samples:
|
|
||||||
LOG.warning("No samples generated, skipping step")
|
|
||||||
continue
|
|
||||||
LOG.info(
|
|
||||||
f"Sampled {len(samples)} completions, "
|
|
||||||
f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Decode and score
|
|
||||||
completion_texts = [
|
|
||||||
tokenizer.decode(s["completion_tokens"], skip_special_tokens=False)
|
|
||||||
for s in samples
|
|
||||||
]
|
|
||||||
sample_prompts = []
|
|
||||||
for prompt_text in prompt_texts:
|
|
||||||
sample_prompts.extend([prompt_text] * args.num_samples)
|
|
||||||
|
|
||||||
rewards = self._compute_rewards(sample_prompts, completion_texts)
|
|
||||||
|
|
||||||
# 3. GRPO advantages
|
|
||||||
advantages_list = self._compute_advantages(
|
|
||||||
rewards, group_size=args.num_samples
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Build training data
|
|
||||||
all_datums = []
|
|
||||||
for i, sample in enumerate(samples):
|
|
||||||
full_tokens = sample["tokens"]
|
|
||||||
prompt_len = sample["prompt_len"]
|
|
||||||
seq_len = len(full_tokens)
|
|
||||||
|
|
||||||
input_ids = torch.tensor([full_tokens], dtype=torch.long)
|
|
||||||
labels = torch.full((1, seq_len), -100, dtype=torch.long)
|
|
||||||
labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:])
|
|
||||||
|
|
||||||
logprobs_t = torch.zeros(1, seq_len)
|
|
||||||
if sample["logprobs"]:
|
|
||||||
lp = sample["logprobs"][: seq_len - prompt_len]
|
|
||||||
logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor(
|
|
||||||
lp
|
|
||||||
)
|
|
||||||
|
|
||||||
adv_t = torch.zeros(1, seq_len)
|
|
||||||
adv_t[0, prompt_len:] = advantages_list[i]
|
|
||||||
|
|
||||||
all_datums.extend(
|
|
||||||
batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Forward backward (one datum at a time for memory) + optim
|
|
||||||
t0 = time.time()
|
|
||||||
tc = self._get_training_client()
|
|
||||||
step_loss = 0.0
|
|
||||||
for datum in all_datums:
|
|
||||||
fb_future = tc.forward_backward(
|
|
||||||
datums_to_tinker([datum]),
|
|
||||||
loss_fn=args.loss_fn,
|
|
||||||
loss_fn_config=args.loss_fn_config,
|
|
||||||
)
|
|
||||||
fb_result = fb_future.result(timeout=args.future_timeout)
|
|
||||||
if hasattr(fb_result, "metrics"):
|
|
||||||
step_loss += float(
|
|
||||||
(fb_result.metrics or {}).get("loss:sum", 0.0)
|
|
||||||
)
|
|
||||||
elif isinstance(fb_result, dict):
|
|
||||||
step_loss += float(
|
|
||||||
fb_result.get("metrics", {}).get("loss:sum", 0.0)
|
|
||||||
)
|
|
||||||
optim_future = self._do_optim_step()
|
|
||||||
if not args.pipeline:
|
|
||||||
optim_future.result(timeout=args.future_timeout)
|
|
||||||
t_train = time.time() - t0
|
|
||||||
|
|
||||||
mean_reward = sum(rewards) / len(rewards)
|
|
||||||
accuracy = sum(1 for r in rewards if r > 0) / len(rewards)
|
|
||||||
mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list)
|
|
||||||
global_step += 1
|
|
||||||
total_loss += step_loss
|
|
||||||
total_reward += mean_reward
|
|
||||||
self.state.global_step = global_step
|
|
||||||
|
|
||||||
log_interval = self.args.logging_steps or 1
|
|
||||||
if global_step % log_interval == 0:
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
LOG.info(
|
|
||||||
f"[step {global_step}/{max_steps}] "
|
|
||||||
f"acc={accuracy:.2f} reward={mean_reward:.3f} "
|
|
||||||
f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} "
|
|
||||||
f"sample={t_sample:.1f}s train={t_train:.1f}s "
|
|
||||||
f"{elapsed / global_step:.1f}s/step"
|
|
||||||
)
|
|
||||||
self.log(
|
|
||||||
{
|
|
||||||
"loss": step_loss,
|
|
||||||
"reward": mean_reward,
|
|
||||||
"accuracy": accuracy,
|
|
||||||
"mean_abs_advantage": mean_adv,
|
|
||||||
"learning_rate": self._optim_params["learning_rate"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.save_steps and global_step % args.save_steps == 0:
|
|
||||||
self._save_remote_checkpoint(global_step)
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_step_end(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
if self.control.should_training_stop:
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.control.should_training_stop:
|
|
||||||
break
|
|
||||||
|
|
||||||
if global_step > 0:
|
|
||||||
self._save_remote_checkpoint(global_step, name="final")
|
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
avg_loss = total_loss / max(global_step, 1)
|
|
||||||
avg_reward = total_reward / max(global_step, 1)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"RL training complete: {global_step} steps, {elapsed:.1f}s, "
|
|
||||||
f"avg_reward={avg_reward:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_end(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
return TrainOutput(
|
|
||||||
global_step=global_step,
|
|
||||||
training_loss=avg_loss,
|
|
||||||
metrics={
|
|
||||||
"train_loss": avg_loss,
|
|
||||||
"train_reward": avg_reward,
|
|
||||||
"train_runtime": elapsed,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
|
|
||||||
tc = self._get_training_client()
|
|
||||||
args = self.hatchery_args
|
|
||||||
assert args is not None # validated by _get_training_client
|
|
||||||
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
|
|
||||||
try:
|
|
||||||
future = tc.save_state(ckpt_name)
|
|
||||||
future.result(timeout=args.future_timeout)
|
|
||||||
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
|
|
||||||
if name == "final":
|
|
||||||
raise
|
|
||||||
|
|
||||||
def save_model(self, output_dir=None, _internal_call=False):
|
|
||||||
self._save_remote_checkpoint(
|
|
||||||
step=self.state.global_step,
|
|
||||||
name=output_dir or "hf-save",
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"HatcheryRLTrainer uses remote API; compute_loss not called locally."
|
|
||||||
)
|
|
||||||
@@ -1,327 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Remote trainer that dispatches to Tinker or Hatchery API."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers.trainer_utils import TrainOutput
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .args import HatcheryConfig
|
|
||||||
from .data import batch_to_datums_sft, datums_to_tinker
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_loss(result) -> float:
|
|
||||||
"""Extract loss:sum from a forward_backward result.
|
|
||||||
|
|
||||||
Tinker's cross_entropy (and other losses) return the SUM of per-token
|
|
||||||
losses, not the mean. This is by design — it lets users control
|
|
||||||
normalization via the weights tensor. The trainer logs this raw sum;
|
|
||||||
users who want per-token loss should divide by number of active tokens.
|
|
||||||
"""
|
|
||||||
if hasattr(result, "metrics"):
|
|
||||||
metrics = result.metrics or {}
|
|
||||||
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
|
||||||
if isinstance(result, dict):
|
|
||||||
metrics = result.get("metrics", {})
|
|
||||||
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def _create_training_client(args: HatcheryConfig, base_model: str):
|
|
||||||
"""Create a training client for either Tinker or Hatchery backend."""
|
|
||||||
if args.backend == "tinker":
|
|
||||||
import tinker
|
|
||||||
|
|
||||||
api_key = args.api_key or os.environ.get("TINKER_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
raise ValueError(
|
|
||||||
"Tinker API key required. Set `hatchery.api_key` in config "
|
|
||||||
"or TINKER_API_KEY env var."
|
|
||||||
)
|
|
||||||
os.environ["TINKER_API_KEY"] = api_key
|
|
||||||
|
|
||||||
service = tinker.ServiceClient(project_id=args.project_id)
|
|
||||||
return service.create_lora_training_client(
|
|
||||||
base_model=base_model,
|
|
||||||
rank=args.lora_rank,
|
|
||||||
train_mlp=args.train_mlp,
|
|
||||||
train_attn=args.train_attn,
|
|
||||||
train_unembed=args.train_unembed,
|
|
||||||
)
|
|
||||||
|
|
||||||
from hatchery.core.client import HatcheryClient
|
|
||||||
|
|
||||||
base_url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420")
|
|
||||||
token = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev")
|
|
||||||
|
|
||||||
client = HatcheryClient(base_url=base_url, token=token, timeout=args.future_timeout)
|
|
||||||
return client.create_lora_training_client(
|
|
||||||
base_model=base_model,
|
|
||||||
rank=args.lora_rank,
|
|
||||||
train_attn=args.train_attn,
|
|
||||||
train_mlp=args.train_mlp,
|
|
||||||
train_unembed=args.train_unembed,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HatcheryTrainer(AxolotlTrainer):
|
|
||||||
"""Trainer that sends preprocessed batches to a remote training API.
|
|
||||||
|
|
||||||
Replaces local forward/backward with remote API calls to Tinker or
|
|
||||||
Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization,
|
|
||||||
chat templates, packing, etc.) but offloads compute to remote GPUs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
hatchery_args: Optional[HatcheryConfig]
|
|
||||||
_base_model_name: Optional[str]
|
|
||||||
_training_client: Any
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.hatchery_args = None
|
|
||||||
self._base_model_name = None
|
|
||||||
self._training_client = None
|
|
||||||
|
|
||||||
def _get_training_client(self):
|
|
||||||
"""Lazily create the remote training session."""
|
|
||||||
if self._training_client is not None:
|
|
||||||
return self._training_client
|
|
||||||
|
|
||||||
args = self.hatchery_args
|
|
||||||
if args is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"HatcheryTrainer.hatchery_args not set. "
|
|
||||||
"Ensure the HatcheryPlugin is registered."
|
|
||||||
)
|
|
||||||
|
|
||||||
base_model = self._base_model_name
|
|
||||||
if not base_model:
|
|
||||||
raise RuntimeError("HatcheryTrainer._base_model_name not set.")
|
|
||||||
|
|
||||||
self._training_client = _create_training_client(args, base_model)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"Remote training session created: backend={args.backend}, "
|
|
||||||
f"model={base_model}, rank={args.lora_rank}"
|
|
||||||
)
|
|
||||||
return self._training_client
|
|
||||||
|
|
||||||
def _send_batch(self, batch: dict[str, torch.Tensor]):
|
|
||||||
"""Convert batch to datums and send forward_backward to remote.
|
|
||||||
|
|
||||||
Returns (future, n_active_tokens) where n_active_tokens counts
|
|
||||||
the completion tokens in this batch (for loss normalization).
|
|
||||||
"""
|
|
||||||
input_ids = batch["input_ids"]
|
|
||||||
labels = batch["labels"]
|
|
||||||
attention_mask = batch.get("attention_mask")
|
|
||||||
|
|
||||||
n_active = int((labels[:, 1:] != -100).sum().item())
|
|
||||||
datums = batch_to_datums_sft(input_ids, labels, attention_mask)
|
|
||||||
|
|
||||||
tc = self._get_training_client()
|
|
||||||
args = self.hatchery_args
|
|
||||||
assert args is not None # validated by _get_training_client
|
|
||||||
send_datums = datums_to_tinker(datums)
|
|
||||||
|
|
||||||
future = tc.forward_backward(
|
|
||||||
send_datums,
|
|
||||||
loss_fn=args.loss_fn,
|
|
||||||
loss_fn_config=args.loss_fn_config,
|
|
||||||
)
|
|
||||||
return future, n_active
|
|
||||||
|
|
||||||
def _do_optim_step(self):
|
|
||||||
"""Send optimizer step to remote using axolotl's training params."""
|
|
||||||
import tinker.types as tt
|
|
||||||
|
|
||||||
tc = self._get_training_client()
|
|
||||||
return tc.optim_step(tt.AdamParams(**self._optim_params))
|
|
||||||
|
|
||||||
def train(
|
|
||||||
self,
|
|
||||||
resume_from_checkpoint: Optional[str] = None,
|
|
||||||
trial: Any = None,
|
|
||||||
ignore_keys_for_eval: Optional[list[str]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> TrainOutput:
|
|
||||||
"""Main training loop — sends batches to remote API."""
|
|
||||||
args = self.hatchery_args
|
|
||||||
if args is None:
|
|
||||||
raise RuntimeError("hatchery_args not configured")
|
|
||||||
|
|
||||||
train_dataloader = self.get_train_dataloader()
|
|
||||||
num_batches = len(train_dataloader)
|
|
||||||
|
|
||||||
grad_accum = self.args.gradient_accumulation_steps
|
|
||||||
num_train_epochs = int(self.args.num_train_epochs)
|
|
||||||
steps_per_epoch = max(num_batches // grad_accum, 1)
|
|
||||||
max_steps = (
|
|
||||||
self.args.max_steps
|
|
||||||
if self.args.max_steps > 0
|
|
||||||
else steps_per_epoch * num_train_epochs
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"Remote training: {num_batches} batches/epoch, "
|
|
||||||
f"{grad_accum} grad_accum, {max_steps} max steps, "
|
|
||||||
f"{num_train_epochs} epochs"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.state.max_steps = max_steps
|
|
||||||
self.state.num_train_epochs = num_train_epochs
|
|
||||||
self.state.is_local_process_zero = True
|
|
||||||
self.state.is_world_process_zero = True
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_begin(
|
|
||||||
self.args,
|
|
||||||
self.state,
|
|
||||||
self.control, # type: ignore[has-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
global_step = 0
|
|
||||||
total_loss = 0.0
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _epoch in range(num_train_epochs):
|
|
||||||
if global_step >= max_steps:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_epoch_begin(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
pending_fb_futures = []
|
|
||||||
accum_count = 0
|
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dataloader):
|
|
||||||
if global_step >= max_steps:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_step_begin(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_future, n_active = self._send_batch(batch)
|
|
||||||
pending_fb_futures.append((fb_future, n_active))
|
|
||||||
accum_count += 1
|
|
||||||
|
|
||||||
if accum_count >= grad_accum:
|
|
||||||
step_loss_sum = 0.0
|
|
||||||
step_active = 0
|
|
||||||
for fut, n_act in pending_fb_futures:
|
|
||||||
result = fut.result(timeout=args.future_timeout)
|
|
||||||
step_loss_sum += _extract_loss(result)
|
|
||||||
step_active += n_act
|
|
||||||
|
|
||||||
optim_future = self._do_optim_step()
|
|
||||||
if not args.pipeline:
|
|
||||||
optim_future.result(timeout=args.future_timeout)
|
|
||||||
|
|
||||||
step_loss = (
|
|
||||||
step_loss_sum / step_active
|
|
||||||
if step_active > 0
|
|
||||||
else step_loss_sum
|
|
||||||
)
|
|
||||||
|
|
||||||
global_step += 1
|
|
||||||
total_loss += step_loss
|
|
||||||
self.state.global_step = global_step
|
|
||||||
self.state.epoch = _epoch + (batch_idx + 1) / num_batches
|
|
||||||
|
|
||||||
log_interval = self.args.logging_steps or 1
|
|
||||||
if global_step % log_interval == 0:
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
avg_loss = total_loss / global_step
|
|
||||||
LOG.info(
|
|
||||||
f"[step {global_step}/{max_steps}] "
|
|
||||||
f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} "
|
|
||||||
f"active={step_active} "
|
|
||||||
f"{elapsed / global_step:.2f}s/step"
|
|
||||||
)
|
|
||||||
self.log(
|
|
||||||
{
|
|
||||||
"loss": step_loss,
|
|
||||||
"learning_rate": self._optim_params["learning_rate"],
|
|
||||||
"epoch": self.state.epoch,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.save_steps and global_step % args.save_steps == 0:
|
|
||||||
self._save_remote_checkpoint(global_step)
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_step_end(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
pending_fb_futures = []
|
|
||||||
accum_count = 0
|
|
||||||
|
|
||||||
if self.control.should_training_stop:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_epoch_end(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
if self.control.should_training_stop:
|
|
||||||
break
|
|
||||||
|
|
||||||
if global_step > 0:
|
|
||||||
self._save_remote_checkpoint(global_step, name="final")
|
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
avg_loss = total_loss / max(global_step, 1)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"Training complete: {global_step} steps, {elapsed:.1f}s total, "
|
|
||||||
f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_end(
|
|
||||||
self.args, self.state, self.control
|
|
||||||
)
|
|
||||||
|
|
||||||
return TrainOutput(
|
|
||||||
global_step=global_step,
|
|
||||||
training_loss=avg_loss,
|
|
||||||
metrics={"train_loss": avg_loss, "train_runtime": elapsed},
|
|
||||||
)
|
|
||||||
|
|
||||||
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
|
|
||||||
"""Save a checkpoint on the remote service."""
|
|
||||||
tc = self._get_training_client()
|
|
||||||
args = self.hatchery_args
|
|
||||||
assert args is not None # validated by _get_training_client
|
|
||||||
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
|
|
||||||
try:
|
|
||||||
future = tc.save_state(ckpt_name)
|
|
||||||
future.result(timeout=args.future_timeout)
|
|
||||||
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
|
|
||||||
except Exception:
|
|
||||||
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
|
|
||||||
if name == "final":
|
|
||||||
raise
|
|
||||||
|
|
||||||
def save_model(self, output_dir=None, _internal_call=False):
|
|
||||||
"""Delegate to remote checkpoint save so HF callbacks create checkpoints."""
|
|
||||||
self._save_remote_checkpoint(
|
|
||||||
step=self.state.global_step,
|
|
||||||
name=output_dir or "hf-save",
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"HatcheryTrainer uses remote API; compute_loss should not be called."
|
|
||||||
)
|
|
||||||
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
|
|||||||
kd_alpha: 0.9
|
kd_alpha: 0.9
|
||||||
kd_temperature: 1.0
|
kd_temperature: 1.0
|
||||||
|
|
||||||
torch_compile: True # recommended to reduce vram
|
torch_compile: True # torch>=2.6.0, recommended to reduce vram
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
|
|||||||
@@ -110,36 +110,11 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
item["agent_ref"] = full_item["agent_ref"]
|
item["agent_ref"] = full_item["agent_ref"]
|
||||||
dataset_items.append(item)
|
dataset_items.append(item)
|
||||||
|
|
||||||
# NOTE: do NOT re-expand by num_generations here.
|
# Expand by num_generations (agent produces one rollout per call)
|
||||||
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
|
expanded_items = []
|
||||||
# yields ``num_generations`` consecutive copies of each unique
|
for item in dataset_items:
|
||||||
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
|
for _ in range(self._num_generations):
|
||||||
# num_generations)`` items — one entry per rollout. Expanding
|
expanded_items.append(item)
|
||||||
# again here would fire ``num_generations^2`` rollouts per
|
|
||||||
# prompt per rank and make every step dogpile on a handful of
|
|
||||||
# tasks.
|
|
||||||
expanded_items = dataset_items
|
|
||||||
|
|
||||||
# Diagnostic: log what this rank is about to fire.
|
|
||||||
try:
|
|
||||||
import collections
|
|
||||||
|
|
||||||
iid_counts: collections.Counter[str | None] = collections.Counter()
|
|
||||||
for it in dataset_items:
|
|
||||||
iid_counts[
|
|
||||||
(it.get("responses_create_params", {}).get("metadata") or {}).get(
|
|
||||||
"instance_id"
|
|
||||||
)
|
|
||||||
] += 1
|
|
||||||
LOG.info(
|
|
||||||
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
|
|
||||||
trainer.accelerator.process_index,
|
|
||||||
len(dataset_items),
|
|
||||||
len(iid_counts),
|
|
||||||
list(iid_counts.most_common(5)),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Call NeMo Gym agents
|
# Call NeMo Gym agents
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
@@ -165,7 +140,6 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
logprobs_list = []
|
logprobs_list = []
|
||||||
rewards_list = []
|
rewards_list = []
|
||||||
|
|
||||||
num_turns_list: list[int] = []
|
|
||||||
for resp in responses:
|
for resp in responses:
|
||||||
parsed = _parse_agent_response(resp, eos_token_id)
|
parsed = _parse_agent_response(resp, eos_token_id)
|
||||||
prompt_ids_list.append(parsed["prompt_ids"])
|
prompt_ids_list.append(parsed["prompt_ids"])
|
||||||
@@ -173,7 +147,6 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
env_mask_list.append(parsed["env_mask"])
|
env_mask_list.append(parsed["env_mask"])
|
||||||
logprobs_list.append(parsed["logprobs"])
|
logprobs_list.append(parsed["logprobs"])
|
||||||
rewards_list.append(parsed["reward"])
|
rewards_list.append(parsed["reward"])
|
||||||
num_turns_list.append(parsed.get("num_turns", 0))
|
|
||||||
|
|
||||||
# Pad to tensors
|
# Pad to tensors
|
||||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
||||||
@@ -206,48 +179,22 @@ class NemoGymDataProducer(GRPODataProducer):
|
|||||||
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
|
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
|
||||||
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
|
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
|
||||||
|
|
||||||
# Inject per-rollout reward + num_turns into each input. Since
|
# Inject rewards into inputs so _compute_deferred_scores can use them
|
||||||
# ``RepeatSampler`` already yields ``num_generations`` copies of
|
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
|
||||||
# each prompt, ``inputs`` has ONE entry per rollout (matching
|
# Our passthrough reward_fn reads "env_reward" from kwargs.
|
||||||
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
|
|
||||||
# GRPO advantage normalization is the trainer's job downstream.
|
|
||||||
assert len(inputs) == len(rewards_list), (
|
|
||||||
f"rewards/inputs length mismatch: "
|
|
||||||
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
|
|
||||||
)
|
|
||||||
for i, inp in enumerate(inputs):
|
for i, inp in enumerate(inputs):
|
||||||
inp["env_reward"] = rewards_list[i]
|
# Each input gets rewards for its num_generations rollouts
|
||||||
inp["num_turns"] = num_turns_list[i]
|
start = i * self._num_generations
|
||||||
|
end = start + self._num_generations
|
||||||
|
inp["env_reward"] = rewards_list[start:end]
|
||||||
|
|
||||||
# One expanded_input per rollout (already correct count because
|
# Expand inputs to match expanded rollouts (num_generations copies)
|
||||||
# inputs has num_generations copies baked in by the sampler).
|
expanded_inputs = []
|
||||||
expanded_inputs = [dict(inp) for inp in inputs]
|
for inp in inputs:
|
||||||
|
for g in range(self._num_generations):
|
||||||
# Log rollout-level stats to wandb from rank 0. These are the
|
expanded_inp = dict(inp)
|
||||||
# true agent-side metrics (not the tokenized TRL view) — so
|
expanded_inp["env_reward"] = inp["env_reward"][g]
|
||||||
# num_turns reflects how many /run iterations each rollout
|
expanded_inputs.append(expanded_inp)
|
||||||
# actually took before finishing or hitting max_turns.
|
|
||||||
if is_main and num_turns_list:
|
|
||||||
try:
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
if wandb.run is not None:
|
|
||||||
import statistics as _stats
|
|
||||||
|
|
||||||
nonzero = sum(1 for r in rewards_list if r > 0)
|
|
||||||
log_payload = {
|
|
||||||
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
|
|
||||||
"rollout/num_turns/min": float(min(num_turns_list)),
|
|
||||||
"rollout/num_turns/max": float(max(num_turns_list)),
|
|
||||||
"rollout/reward/mean": float(_stats.mean(rewards_list)),
|
|
||||||
"rollout/reward/nonzero_frac": (
|
|
||||||
nonzero / len(rewards_list) if rewards_list else 0.0
|
|
||||||
),
|
|
||||||
"rollout/n_samples": float(len(rewards_list)),
|
|
||||||
}
|
|
||||||
wandb.log(log_payload, commit=False)
|
|
||||||
except Exception as exc: # never let metric logging break training
|
|
||||||
LOG.warning("rollout wandb log failed: %s", exc)
|
|
||||||
|
|
||||||
# Decode completions for reward functions
|
# Decode completions for reward functions
|
||||||
completions = trainer.processing_class.batch_decode(
|
completions = trainer.processing_class.batch_decode(
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ Supports two modes:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
@@ -31,107 +30,6 @@ if TYPE_CHECKING:
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ---- vLLM weight-sync transport probe ------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VLLMWeightSyncCapabilities:
|
|
||||||
"""What weight-sync routes a vLLM server actually exposes.
|
|
||||||
|
|
||||||
Discovered once at ``pre_model_load`` time by fetching the server's
|
|
||||||
``/openapi.json``. Drives the transport-selection table below.
|
|
||||||
"""
|
|
||||||
|
|
||||||
nccl: bool = False # /init_communicator/ + /update_named_param/
|
|
||||||
lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native)
|
|
||||||
lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension)
|
|
||||||
http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension)
|
|
||||||
probed: bool = False
|
|
||||||
probe_error: str | None = None
|
|
||||||
routes: list[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def any_full_param_sync(self) -> bool:
|
|
||||||
"""True if at least one transport can push full-model weights."""
|
|
||||||
return self.nccl or self.http_full
|
|
||||||
|
|
||||||
@property
|
|
||||||
def any_lora_sync(self) -> bool:
|
|
||||||
"""True if at least one transport can push LoRA adapters."""
|
|
||||||
return self.lora_filesystem or self.lora_axolotl or self.nccl
|
|
||||||
|
|
||||||
|
|
||||||
def probe_vllm_weight_sync(
|
|
||||||
base_url: str, timeout: float = 5.0
|
|
||||||
) -> VLLMWeightSyncCapabilities:
|
|
||||||
"""Detect which weight-sync routes the configured vLLM server exposes.
|
|
||||||
|
|
||||||
Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
|
|
||||||
we care about is mounted as a POST route there. Falls back to all-False
|
|
||||||
on any error so the caller can still decide what to do (typically: raise
|
|
||||||
a clear error rather than silently no-op).
|
|
||||||
"""
|
|
||||||
import requests
|
|
||||||
|
|
||||||
caps = VLLMWeightSyncCapabilities()
|
|
||||||
try:
|
|
||||||
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
|
|
||||||
r.raise_for_status()
|
|
||||||
spec = r.json()
|
|
||||||
routes = sorted((spec.get("paths") or {}).keys())
|
|
||||||
caps.routes = routes
|
|
||||||
caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
|
|
||||||
caps.lora_filesystem = "/v1/load_lora_adapter" in routes
|
|
||||||
caps.lora_axolotl = "/set_lora_adapter/" in routes
|
|
||||||
caps.http_full = "/http_update_weights/" in routes
|
|
||||||
caps.probed = True
|
|
||||||
except Exception as exc:
|
|
||||||
caps.probe_error = f"{type(exc).__name__}: {exc}"
|
|
||||||
LOG.warning(
|
|
||||||
"NeMo Gym: failed to probe vLLM /openapi.json at %s — %s. "
|
|
||||||
"Will fall back to LoRA-only behavior.",
|
|
||||||
base_url,
|
|
||||||
caps.probe_error,
|
|
||||||
)
|
|
||||||
return caps
|
|
||||||
|
|
||||||
|
|
||||||
def select_weight_sync_transport(
|
|
||||||
caps: VLLMWeightSyncCapabilities,
|
|
||||||
*,
|
|
||||||
has_lora: bool,
|
|
||||||
vllm_lora_sync_pref: bool,
|
|
||||||
) -> str:
|
|
||||||
"""Pick the right transport for a (server caps, model type) combo.
|
|
||||||
|
|
||||||
Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
|
|
||||||
``"none"``. The caller decides what to do with ``"none"`` (typically:
|
|
||||||
raise an error explaining the misconfiguration).
|
|
||||||
|
|
||||||
Selection table:
|
|
||||||
LoRA model + lora endpoint + lora-sync pref → lora_filesystem
|
|
||||||
LoRA model + lora endpoint → lora_filesystem
|
|
||||||
LoRA model + nccl endpoint → nccl (broadcast merged adapter)
|
|
||||||
Full model + nccl endpoint → nccl
|
|
||||||
Full model + http endpoint → http_full
|
|
||||||
anything else → none
|
|
||||||
"""
|
|
||||||
if has_lora:
|
|
||||||
if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
|
|
||||||
return "lora_filesystem"
|
|
||||||
if caps.lora_filesystem or caps.lora_axolotl:
|
|
||||||
return "lora_filesystem"
|
|
||||||
if caps.nccl:
|
|
||||||
return "nccl"
|
|
||||||
return "none"
|
|
||||||
# Full-parameter model
|
|
||||||
if caps.nccl:
|
|
||||||
return "nccl"
|
|
||||||
if caps.http_full:
|
|
||||||
return "http_full"
|
|
||||||
return "none"
|
|
||||||
|
|
||||||
|
|
||||||
class NemoGymPlugin(BasePlugin):
|
class NemoGymPlugin(BasePlugin):
|
||||||
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
|
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
|
||||||
|
|
||||||
@@ -152,69 +50,37 @@ class NemoGymPlugin(BasePlugin):
|
|||||||
self._reward_fn = None
|
self._reward_fn = None
|
||||||
self._dataset_lookup = None
|
self._dataset_lookup = None
|
||||||
self._agent_servers = {}
|
self._agent_servers = {}
|
||||||
self._vllm_caps: VLLMWeightSyncCapabilities | None = None
|
|
||||||
|
|
||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.nemo_gym.NemoGymArgs"
|
return "axolotl.integrations.nemo_gym.NemoGymArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
"""Probe vLLM weight-sync routes and conditionally bypass NCCL init.
|
"""Apply monkeypatches before trainer creation."""
|
||||||
|
|
||||||
Replaces the previous unconditional ``init_communicator`` monkey-patch
|
|
||||||
with a probe of the configured vLLM server's ``/openapi.json``. We only
|
|
||||||
bypass NCCL init when the server we're talking to actually lacks the
|
|
||||||
``/init_communicator/`` route (i.e. stock ``vllm serve``); against
|
|
||||||
TRL/axolotl serve modules that DO expose NCCL routes, we leave the
|
|
||||||
standard TRL flow alone so full-finetune training can sync weights.
|
|
||||||
"""
|
|
||||||
if not cfg.nemo_gym_enabled:
|
if not cfg.nemo_gym_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Always skip NCCL communicator init in NeMo Gym mode.
|
||||||
|
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
|
||||||
|
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
|
||||||
trl_cfg = getattr(cfg, "trl", None)
|
trl_cfg = getattr(cfg, "trl", None)
|
||||||
if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
|
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
|
||||||
return
|
|
||||||
|
|
||||||
host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
|
|
||||||
port = getattr(trl_cfg, "vllm_server_port", None) or 8000
|
|
||||||
base_url = f"http://{host}:{port}"
|
|
||||||
self._vllm_caps = probe_vllm_weight_sync(base_url)
|
|
||||||
|
|
||||||
if self._vllm_caps.probed:
|
|
||||||
LOG.info(
|
|
||||||
"NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
|
|
||||||
"lora_axolotl=%s http_full=%s",
|
|
||||||
base_url,
|
|
||||||
self._vllm_caps.nccl,
|
|
||||||
self._vllm_caps.lora_filesystem,
|
|
||||||
self._vllm_caps.lora_axolotl,
|
|
||||||
self._vllm_caps.http_full,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only bypass NCCL init when the server doesn't speak it. If NCCL is
|
|
||||||
# available we leave VLLMClient.init_communicator alone so the
|
|
||||||
# standard TRL sync flow can run for full-parameter training.
|
|
||||||
if not self._vllm_caps.nccl:
|
|
||||||
self._patch_skip_nccl_init()
|
self._patch_skip_nccl_init()
|
||||||
|
|
||||||
def _patch_skip_nccl_init(self):
|
def _patch_skip_nccl_init(self):
|
||||||
"""Monkeypatch VLLMClient.init_communicator to no-op.
|
"""Monkeypatch VLLMClient.init_communicator to no-op.
|
||||||
|
|
||||||
Only called when the configured vLLM server doesn't expose
|
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
|
||||||
``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
|
serve script). The NCCL communicator is not needed and fails with both
|
||||||
TRL's standard ``init_communicator`` would 404 inside trainer
|
vLLM V1 engine and standard OpenAI server mode.
|
||||||
construction; we no-op it so the LoRA filesystem path can install
|
|
||||||
its own sync in ``post_trainer_create``.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from trl.generation.vllm_client import VLLMClient
|
from trl.generation.vllm_client import VLLMClient
|
||||||
|
|
||||||
VLLMClient._original_init_communicator = VLLMClient.init_communicator
|
VLLMClient._original_init_communicator = VLLMClient.init_communicator
|
||||||
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
|
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
|
||||||
"Skipping NCCL init_communicator (server has no /init_communicator/)"
|
"Skipping NCCL init_communicator (LoRA sync mode)"
|
||||||
)
|
|
||||||
LOG.info(
|
|
||||||
"Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
|
|
||||||
)
|
)
|
||||||
|
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOG.warning(f"Failed to patch VLLMClient: {exc}")
|
LOG.warning(f"Failed to patch VLLMClient: {exc}")
|
||||||
|
|
||||||
@@ -368,80 +234,30 @@ class NemoGymPlugin(BasePlugin):
|
|||||||
verify_timeout = cfg.nemo_gym_verify_timeout or 30
|
verify_timeout = cfg.nemo_gym_verify_timeout or 30
|
||||||
multi_turn = cfg.nemo_gym_multi_turn or False
|
multi_turn = cfg.nemo_gym_multi_turn or False
|
||||||
|
|
||||||
# Pick a weight-sync transport based on what the configured vLLM
|
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
|
||||||
# server actually exposes (see ``pre_model_load`` probe) and what
|
# - Install LoRA sync (when vllm_lora_sync=True)
|
||||||
# kind of model we're training. The selection table is documented
|
# - Or no-op sync_weights (when using standard vLLM server)
|
||||||
# in ``select_weight_sync_transport``.
|
|
||||||
trl_cfg = getattr(cfg, "trl", None)
|
trl_cfg = getattr(cfg, "trl", None)
|
||||||
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
|
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
|
||||||
vllm_gen = trainer.vllm_generation
|
vllm_gen = trainer.vllm_generation
|
||||||
adapter = getattr(cfg, "adapter", None)
|
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
|
||||||
has_lora = adapter in ("lora", "qlora")
|
|
||||||
vllm_lora_sync_pref = bool(
|
|
||||||
trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
|
|
||||||
)
|
|
||||||
caps = self._vllm_caps or VLLMWeightSyncCapabilities()
|
|
||||||
transport = select_weight_sync_transport(
|
|
||||||
caps,
|
|
||||||
has_lora=has_lora,
|
|
||||||
vllm_lora_sync_pref=vllm_lora_sync_pref,
|
|
||||||
)
|
|
||||||
|
|
||||||
if transport == "lora_filesystem":
|
|
||||||
self._setup_lora_sync(trainer)
|
self._setup_lora_sync(trainer)
|
||||||
|
# Verify the vLLM server supports runtime LoRA loading
|
||||||
self._check_lora_endpoint(vllm_gen)
|
self._check_lora_endpoint(vllm_gen)
|
||||||
LOG.info("NeMo Gym weight sync: LoRA filesystem")
|
else:
|
||||||
elif transport == "nccl":
|
# No NCCL, no LoRA sync — skip all weight sync paths
|
||||||
# Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
|
vllm_gen.sync_weights = lambda: LOG.debug(
|
||||||
# alone (pre_model_load only patched it when the probe found no
|
"Weight sync skipped (NeMo Gym mode)"
|
||||||
# NCCL route) so the trainer's normal weight-sync flow runs.
|
|
||||||
LOG.info(
|
|
||||||
"NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
|
|
||||||
)
|
)
|
||||||
elif transport == "http_full":
|
type(vllm_gen).sync_weights = lambda self: LOG.debug(
|
||||||
# Full-parameter HTTP sync — implementation lands in step 3.
|
"Weight sync skipped (NeMo Gym mode)"
|
||||||
# For now, fail loudly so users know the path is detected but
|
|
||||||
# not yet wired up, instead of silently no-oping like before.
|
|
||||||
raise NotImplementedError(
|
|
||||||
"NeMo Gym + full fine-tune + HTTP weight sync is detected "
|
|
||||||
"but the client-side sync helper is not yet implemented "
|
|
||||||
"(planned). Use `adapter: lora|qlora` for now, or use a "
|
|
||||||
"vLLM serve module that exposes /init_communicator/ for "
|
|
||||||
"NCCL sync."
|
|
||||||
)
|
)
|
||||||
else: # transport == "none"
|
# Also patch the async trainer's internal sync method
|
||||||
# No viable sync path. Build a precise error so the user knows
|
if hasattr(trainer, "_maybe_sync_vllm_weights"):
|
||||||
# exactly what's missing and how to fix it.
|
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
|
||||||
if not caps.probed:
|
"Async weight sync skipped (NeMo Gym mode)"
|
||||||
msg = (
|
|
||||||
"could not probe the vLLM server's "
|
|
||||||
f"/openapi.json: {caps.probe_error}. "
|
|
||||||
"Verify that vLLM is reachable at "
|
|
||||||
f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
|
|
||||||
f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
|
|
||||||
)
|
)
|
||||||
elif has_lora:
|
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
|
||||||
msg = (
|
|
||||||
"the vLLM server has neither NCCL routes "
|
|
||||||
"(/init_communicator/) nor a LoRA-loading route "
|
|
||||||
"(/v1/load_lora_adapter or /set_lora_adapter/). "
|
|
||||||
"Restart vLLM with `--enable-lora --max-lora-rank N "
|
|
||||||
"VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
|
|
||||||
"server, or use `axolotl vllm-serve` for the "
|
|
||||||
"NCCL-capable serve module."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
msg = (
|
|
||||||
"the vLLM server exposes no full-parameter sync route "
|
|
||||||
"(/init_communicator/ for NCCL or /http_update_weights/ "
|
|
||||||
"for HTTP). Use `axolotl vllm-serve` (which has both) "
|
|
||||||
"or set `adapter: lora|qlora`."
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"NeMo Gym: no usable weight-sync transport — {msg} Without "
|
|
||||||
"weight sync the trainer's gradient updates never reach the "
|
|
||||||
"rollout policy (functionally a no-op trainer)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if multi_turn:
|
if multi_turn:
|
||||||
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
|
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
|
||||||
|
|||||||
@@ -130,41 +130,21 @@ def start_servers(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_server_configs(head_port: int = 11000, timeout: float = 30.0) -> dict:
|
def get_server_configs(head_port: int = 11000) -> dict:
|
||||||
"""Fetch the global config from the NeMo Gym head server.
|
"""Fetch the global config from the NeMo Gym head server.
|
||||||
|
|
||||||
Retries up to 3 times with exponential backoff. The default per-attempt
|
|
||||||
timeout is 30s (raised from the original 5s) because head servers can
|
|
||||||
be slow to respond when they're concurrently serving rollouts from a
|
|
||||||
prior training run. A 5s timeout was empirically too tight to survive
|
|
||||||
a kill-and-relaunch cycle.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping server_name -> server config.
|
Dict mapping server_name -> server config.
|
||||||
"""
|
"""
|
||||||
url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"
|
response = requests.get(
|
||||||
last_exc: Exception | None = None
|
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
|
||||||
for attempt in (1, 2, 3):
|
|
||||||
try:
|
|
||||||
response = requests.get(url, timeout=timeout)
|
|
||||||
response.raise_for_status()
|
|
||||||
result = yaml.safe_load(response.text)
|
|
||||||
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
|
||||||
if isinstance(result, str):
|
|
||||||
result = yaml.safe_load(result)
|
|
||||||
return result
|
|
||||||
except (requests.exceptions.RequestException, OSError) as exc:
|
|
||||||
last_exc = exc
|
|
||||||
LOG.warning(
|
|
||||||
"NeMo Gym head probe attempt %d/3 failed: %s. Retrying...",
|
|
||||||
attempt,
|
|
||||||
type(exc).__name__,
|
|
||||||
)
|
|
||||||
if attempt < 3:
|
|
||||||
time.sleep(2.0 * attempt)
|
|
||||||
raise RuntimeError(
|
|
||||||
f"NeMo Gym head server at {url} did not respond after 3 attempts: {last_exc}"
|
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = yaml.safe_load(response.text)
|
||||||
|
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
||||||
|
if isinstance(result, str):
|
||||||
|
result = yaml.safe_load(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_agent_servers(
|
def get_agent_servers(
|
||||||
|
|||||||
@@ -53,7 +53,6 @@ def _rms_norm_rope_forward_kernel(
|
|||||||
RSTD_ptr,
|
RSTD_ptr,
|
||||||
RSTD_row_stride,
|
RSTD_row_stride,
|
||||||
n_cols,
|
n_cols,
|
||||||
n_rot,
|
|
||||||
n_heads,
|
n_heads,
|
||||||
eps,
|
eps,
|
||||||
HAS_WEIGHT: tl.constexpr,
|
HAS_WEIGHT: tl.constexpr,
|
||||||
@@ -61,35 +60,28 @@ def _rms_norm_rope_forward_kernel(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Fused forward:
|
Fused forward:
|
||||||
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
|
x_norm = x / rms(x) [* weight] (RMSNorm)
|
||||||
y[..., :n_rot] = rope(x_norm[..., :n_rot])
|
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
|
||||||
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
|
|
||||||
|
|
||||||
rotate_half swaps first/second halves and negates the first, restricted
|
rotate_half swaps first/second halves and negates the first:
|
||||||
to the rotary span [0, n_rot):
|
rotate_half([a, b]) = [-b, a]
|
||||||
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
|
|
||||||
|
|
||||||
For the partial-rotary pass-through region we load cos with default 1.0
|
|
||||||
and sin with default 0.0 outside [0, n_rot), so the same formula
|
|
||||||
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
|
|
||||||
|
|
||||||
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
|
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
|
||||||
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
|
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
|
||||||
"""
|
"""
|
||||||
row_idx = tl.program_id(0).to(tl.int64)
|
row_idx = tl.program_id(0).to(tl.int64)
|
||||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
|
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
|
||||||
cs_row_idx = row_idx // n_heads
|
cs_row_idx = row_idx // n_heads
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
mask = col_offsets < n_cols
|
mask = col_offsets < n_cols
|
||||||
rot_mask_col = col_offsets < n_rot
|
half_dim = n_cols // 2
|
||||||
half_rot = n_rot // 2
|
|
||||||
|
|
||||||
# Load input row
|
# Load input row
|
||||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||||
X_dtype = X_row.dtype
|
X_dtype = X_row.dtype
|
||||||
X_fp32 = X_row.to(tl.float32)
|
X_fp32 = X_row.to(tl.float32)
|
||||||
|
|
||||||
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
|
# RMSNorm: compute 1/rms
|
||||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||||
rstd = rsqrt(mean_sq + eps)
|
rstd = rsqrt(mean_sq + eps)
|
||||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||||
@@ -102,38 +94,33 @@ def _rms_norm_rope_forward_kernel(
|
|||||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||||
X_norm = X_norm * W_row
|
X_norm = X_norm * W_row
|
||||||
|
|
||||||
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
|
# RoPE: load cos/sin (broadcast across heads)
|
||||||
# cos=1, sin=0 so the formula leaves X_norm untouched.
|
|
||||||
cos_row = tl.load(
|
cos_row = tl.load(
|
||||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||||
mask=rot_mask_col,
|
|
||||||
other=1.0,
|
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
sin_row = tl.load(
|
sin_row = tl.load(
|
||||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
|
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
|
||||||
mask=rot_mask_col,
|
|
||||||
other=0.0,
|
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
|
|
||||||
# rotate_half within [0, n_rot):
|
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
|
||||||
# for col < half_rot: take -X_norm[col + half_rot]
|
# for col >= half_dim, take X_norm[col - half_dim]
|
||||||
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
|
|
||||||
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
|
|
||||||
rot_offsets = tl.where(
|
rot_offsets = tl.where(
|
||||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||||
)
|
)
|
||||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
rot_mask = rot_offsets < n_cols
|
||||||
X_rot = tl.load(
|
X_rot = tl.load(
|
||||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
|
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
# Re-normalize the rotated values
|
# Re-normalize the rotated values
|
||||||
X_rot_norm = X_rot * rstd
|
X_rot_norm = X_rot * rstd
|
||||||
if HAS_WEIGHT:
|
if HAS_WEIGHT:
|
||||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
|
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
|
||||||
|
tl.float32
|
||||||
|
)
|
||||||
X_rot_norm = X_rot_norm * W_rot
|
X_rot_norm = X_rot_norm * W_rot
|
||||||
|
|
||||||
# Negate the first half (rotate_half negates x2, which becomes the first half)
|
# Negate the first half (rotate_half negates x2, which becomes the first half)
|
||||||
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
|
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
|
||||||
X_rot_norm = X_rot_norm * sign
|
X_rot_norm = X_rot_norm * sign
|
||||||
|
|
||||||
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||||
@@ -166,21 +153,13 @@ def _rms_norm_rope_backward_kernel(
|
|||||||
dW_row_stride,
|
dW_row_stride,
|
||||||
n_rows,
|
n_rows,
|
||||||
n_cols,
|
n_cols,
|
||||||
n_rot,
|
|
||||||
n_heads,
|
n_heads,
|
||||||
rows_per_program,
|
rows_per_program,
|
||||||
HAS_WEIGHT: tl.constexpr,
|
HAS_WEIGHT: tl.constexpr,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
|
Backward for Y = RoPE(RMSNorm(X, W))
|
||||||
(`n_rot <= n_cols`).
|
|
||||||
|
|
||||||
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
|
|
||||||
output is just the normalized row, so dN[col] = dY[col] (achieved by
|
|
||||||
loading cos with default 1.0 and forcing the rotate-half contribution
|
|
||||||
to zero outside the rotary span).
|
|
||||||
|
|
||||||
cos/sin indexed by row_idx // n_heads for per-head broadcast.
|
cos/sin indexed by row_idx // n_heads for per-head broadcast.
|
||||||
"""
|
"""
|
||||||
row_block_id = tl.program_id(0).to(tl.int64)
|
row_block_id = tl.program_id(0).to(tl.int64)
|
||||||
@@ -188,8 +167,7 @@ def _rms_norm_rope_backward_kernel(
|
|||||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
mask = col_offsets < n_cols
|
mask = col_offsets < n_cols
|
||||||
rot_mask_col = col_offsets < n_rot
|
half_dim = n_cols // 2
|
||||||
half_rot = n_rot // 2
|
|
||||||
|
|
||||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||||
|
|
||||||
@@ -208,37 +186,33 @@ def _rms_norm_rope_backward_kernel(
|
|||||||
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||||
|
|
||||||
cos_row = tl.load(
|
cos_row = tl.load(
|
||||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||||
mask=rot_mask_col,
|
|
||||||
other=1.0,
|
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
|
|
||||||
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
|
# dN = dY * cos + rotate_half^T(dY * sin)
|
||||||
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
|
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
|
||||||
#
|
#
|
||||||
# For col >= n_rot the formula must collapse to dN = dY (since the
|
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
|
||||||
# forward is just a pass-through). cos defaults to 1.0 above; the
|
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
|
||||||
# rotate-half contribution is masked to zero below.
|
# This is equivalent to rotating (dY * sin) because the rotation
|
||||||
|
# just permutes which elements are multiplied.
|
||||||
rot_offsets = tl.where(
|
rot_offsets = tl.where(
|
||||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||||
)
|
)
|
||||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
rot_mask = rot_offsets < n_cols
|
||||||
dY_rot = tl.load(
|
dY_rot = tl.load(
|
||||||
dY_ptr + row_idx * dY_row_stride + rot_offsets,
|
dY_ptr + row_idx * dY_row_stride + rot_offsets,
|
||||||
mask=rot_load_mask,
|
mask=rot_mask & mask,
|
||||||
other=0,
|
other=0,
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
sin_rot = tl.load(
|
sin_rot = tl.load(
|
||||||
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
|
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
|
||||||
mask=rot_load_mask,
|
mask=rot_mask & mask,
|
||||||
other=0,
|
other=0,
|
||||||
).to(tl.float32)
|
).to(tl.float32)
|
||||||
|
|
||||||
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
|
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
|
||||||
rotate_term = dY_rot * sin_rot * adj_sign
|
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
|
||||||
# Zero out rotate-half contribution outside the rotary span.
|
|
||||||
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
|
|
||||||
dN = dY_row * cos_row + rotate_term
|
|
||||||
|
|
||||||
# Pre-weight normalized: n = rstd * x
|
# Pre-weight normalized: n = rstd * x
|
||||||
n = X_row * rstd
|
n = X_row * rstd
|
||||||
@@ -267,17 +241,15 @@ def _rms_norm_rope_backward_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
|
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
|
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
|
||||||
W: (head_dim,) or None — RMSNorm weight
|
W: (head_dim,) or None — RMSNorm weight
|
||||||
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
|
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||||
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
|
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||||
eps: float
|
eps: float
|
||||||
n_heads: int — number of attention heads (for cos/sin indexing)
|
n_heads: int — number of attention heads (for cos/sin indexing)
|
||||||
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
|
|
||||||
partial rotary). Must be even and ``<= head_dim``.
|
|
||||||
Returns:
|
Returns:
|
||||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
|
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
|
||||||
"""
|
"""
|
||||||
@@ -301,7 +273,6 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
|
|||||||
RSTD,
|
RSTD,
|
||||||
RSTD.stride(0),
|
RSTD.stride(0),
|
||||||
n_cols,
|
n_cols,
|
||||||
n_rot,
|
|
||||||
n_heads,
|
n_heads,
|
||||||
eps,
|
eps,
|
||||||
HAS_WEIGHT=has_weight,
|
HAS_WEIGHT=has_weight,
|
||||||
@@ -311,9 +282,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
|
|||||||
return Y, X, RSTD, BLOCK_SIZE, num_warps
|
return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||||
|
|
||||||
|
|
||||||
def rms_norm_rope_backward(
|
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
|
||||||
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
|
|
||||||
):
|
|
||||||
n_rows, n_cols = dY.shape
|
n_rows, n_cols = dY.shape
|
||||||
has_weight = W is not None
|
has_weight = W is not None
|
||||||
|
|
||||||
@@ -346,7 +315,6 @@ def rms_norm_rope_backward(
|
|||||||
_dW.stride(0),
|
_dW.stride(0),
|
||||||
n_rows,
|
n_rows,
|
||||||
n_cols,
|
n_cols,
|
||||||
n_rot,
|
|
||||||
n_heads,
|
n_heads,
|
||||||
rows_per_program,
|
rows_per_program,
|
||||||
HAS_WEIGHT=has_weight,
|
HAS_WEIGHT=has_weight,
|
||||||
@@ -361,14 +329,13 @@ def rms_norm_rope_backward(
|
|||||||
class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ensure_contiguous
|
@ensure_contiguous
|
||||||
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
|
def forward(ctx, X, W, cos, sin, eps, n_heads):
|
||||||
"""
|
"""
|
||||||
X: (B*S*H, head_dim)
|
X: (B*S*H, head_dim)
|
||||||
W: (head_dim,) or None
|
W: (head_dim,) or None
|
||||||
cos: (B*S, n_rot) — broadcast across heads
|
cos: (B*S, head_dim) — broadcast across heads
|
||||||
sin: (B*S, n_rot) — broadcast across heads
|
sin: (B*S, head_dim) — broadcast across heads
|
||||||
n_heads: int
|
n_heads: int
|
||||||
n_rot: int — rotary dim (<= head_dim)
|
|
||||||
"""
|
"""
|
||||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
|
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
|
||||||
X,
|
X,
|
||||||
@@ -377,13 +344,11 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
|||||||
sin,
|
sin,
|
||||||
eps,
|
eps,
|
||||||
n_heads,
|
n_heads,
|
||||||
n_rot,
|
|
||||||
)
|
)
|
||||||
ctx.eps = eps
|
ctx.eps = eps
|
||||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||||
ctx.num_warps = num_warps
|
ctx.num_warps = num_warps
|
||||||
ctx.n_heads = n_heads
|
ctx.n_heads = n_heads
|
||||||
ctx.n_rot = n_rot
|
|
||||||
ctx.has_weight = W is not None
|
ctx.has_weight = W is not None
|
||||||
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
|
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
|
||||||
return Y
|
return Y
|
||||||
@@ -400,26 +365,21 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
|||||||
sin,
|
sin,
|
||||||
RSTD,
|
RSTD,
|
||||||
ctx.n_heads,
|
ctx.n_heads,
|
||||||
ctx.n_rot,
|
|
||||||
ctx.BLOCK_SIZE,
|
ctx.BLOCK_SIZE,
|
||||||
ctx.num_warps,
|
ctx.num_warps,
|
||||||
)
|
)
|
||||||
return dX, dW, None, None, None, None, None
|
return dX, dW, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||||
"""
|
"""
|
||||||
Apply fused RMSNorm + (partial) RoPE.
|
Apply fused RMSNorm + RoPE.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: (batch, seq_len, num_heads, head_dim) — after projection + view
|
x: (batch, seq_len, num_heads, head_dim) — after projection + view
|
||||||
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
|
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
|
||||||
cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
|
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||||
must be even and ``<= head_dim``. When ``n_rot < head_dim``
|
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||||
the trailing ``head_dim - n_rot`` columns are RMSNorm-only
|
|
||||||
(partial-rotary pass-through), matching stock Gemma 4 with
|
|
||||||
``partial_rotary_factor < 1.0``.
|
|
||||||
sin: (batch, seq_len, n_rot) — same shape as ``cos``
|
|
||||||
eps: float — RMSNorm epsilon
|
eps: float — RMSNorm epsilon
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -427,38 +387,14 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
|||||||
"""
|
"""
|
||||||
shape = x.shape # (B, S, H, D)
|
shape = x.shape # (B, S, H, D)
|
||||||
B, S, H, D = shape
|
B, S, H, D = shape
|
||||||
n_rot = cos.shape[-1]
|
|
||||||
if sin.shape[-1] != n_rot:
|
|
||||||
raise ValueError(
|
|
||||||
f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
|
|
||||||
f"sin={sin.shape[-1]}"
|
|
||||||
)
|
|
||||||
if n_rot > D:
|
|
||||||
raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
|
|
||||||
if n_rot % 2 != 0:
|
|
||||||
raise ValueError(f"rotary dim must be even, got {n_rot}")
|
|
||||||
|
|
||||||
# Flatten to 2D: (B*S*H, D)
|
# Flatten to 2D: (B*S*H, D)
|
||||||
x_flat = x.reshape(-1, D).contiguous()
|
x_flat = x.reshape(-1, D).contiguous()
|
||||||
# cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
|
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
|
||||||
# all sequences share the same rotary positions). The kernel needs a
|
# by dividing the row_idx by H to get the cos/sin row
|
||||||
# dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
|
cos_flat = cos.reshape(B * S, D).contiguous()
|
||||||
# onto a single (b, s) pair, so expand-then-contiguous to materialize
|
sin_flat = sin.reshape(B * S, D).contiguous()
|
||||||
# the per-batch broadcast. Expand is a no-op when B == cos.shape[0].
|
|
||||||
if cos.shape[0] != B:
|
|
||||||
if cos.shape[0] != 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
|
|
||||||
f"to x batch dim ({B})"
|
|
||||||
)
|
|
||||||
cos = cos.expand(B, S, n_rot)
|
|
||||||
sin = sin.expand(B, S, n_rot)
|
|
||||||
cos_flat = cos.reshape(B * S, n_rot).contiguous()
|
|
||||||
sin_flat = sin.reshape(B * S, n_rot).contiguous()
|
|
||||||
|
|
||||||
y_flat = FusedRMSNormRoPEFunction.apply(
|
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
|
||||||
x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
|
|
||||||
)
|
|
||||||
return y_flat.view(shape)
|
return y_flat.view(shape)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -156,14 +156,6 @@ class PatchManager:
|
|||||||
# which would clobber any earlier fix.
|
# which would clobber any earlier fix.
|
||||||
self._fix_nemotron_h_conversion_mapping()
|
self._fix_nemotron_h_conversion_mapping()
|
||||||
|
|
||||||
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
|
|
||||||
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
|
|
||||||
# override needs to walk the raw model tree, which is broken by
|
|
||||||
# the post-load PEFT wrapping. The accompanying
|
|
||||||
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and
|
|
||||||
# installation-time-independent, so both halves of the fix live
|
|
||||||
# cleanly in the same call even though one is instance-scoped
|
|
||||||
# and the other is module-scoped.
|
|
||||||
self._apply_gemma_hybrid_attention(model)
|
self._apply_gemma_hybrid_attention(model)
|
||||||
self._finalize_moe_expert_quantization(model)
|
self._finalize_moe_expert_quantization(model)
|
||||||
|
|
||||||
@@ -180,23 +172,12 @@ class PatchManager:
|
|||||||
which exceeds flash attention's supported size. This patch loads the model
|
which exceeds flash attention's supported size. This patch loads the model
|
||||||
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
||||||
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
||||||
|
|
||||||
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
|
|
||||||
which fixes the corresponding mask construction inside
|
|
||||||
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
|
|
||||||
override is not enough — the forward still builds a 2D FA2-format mask
|
|
||||||
at the model level and the SDPA layers crash at long context lengths
|
|
||||||
with ``RuntimeError: The expanded size of the tensor ... must match``.
|
|
||||||
"""
|
"""
|
||||||
if not self.cfg.gemma4_hybrid_attn_impl:
|
if not self.cfg.gemma4_hybrid_attn_impl:
|
||||||
return
|
return
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
|
||||||
|
|
||||||
patch_gemma4_hybrid_mask()
|
|
||||||
|
|
||||||
# Navigate to the module that has 'layers' - varies by model structure:
|
# Navigate to the module that has 'layers' - varies by model structure:
|
||||||
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
||||||
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
|
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
|
||||||
@@ -410,14 +391,6 @@ class PatchManager:
|
|||||||
patch_qwen3_5_vlm_flash_attention()
|
patch_qwen3_5_vlm_flash_attention()
|
||||||
|
|
||||||
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||||
# The fused attn path is now compatible with
|
|
||||||
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
|
|
||||||
# rotary (cos.shape[-1] < head_dim) and the fused forward
|
|
||||||
# mirrors the current ``Gemma4TextAttention.forward`` API
|
|
||||||
# for shared kv (read from / write to
|
|
||||||
# ``past_key_values.shared_layers``). See
|
|
||||||
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
|
|
||||||
# for the history.
|
|
||||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
patch_gemma4_fused_attn,
|
patch_gemma4_fused_attn,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,8 +23,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
if cfg.revision_of_model:
|
if cfg.revision_of_model:
|
||||||
processor_kwargs["revision"] = cfg.revision_of_model
|
processor_kwargs["revision"] = cfg.revision_of_model
|
||||||
if cfg.processor_kwargs:
|
|
||||||
processor_kwargs.update(cfg.processor_kwargs)
|
|
||||||
|
|
||||||
if cfg.tokenizer_use_mistral_common:
|
if cfg.tokenizer_use_mistral_common:
|
||||||
|
|
||||||
|
|||||||
@@ -1,115 +0,0 @@
|
|||||||
"""Hybrid attention mask fix for Gemma 4.
|
|
||||||
|
|
||||||
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
|
|
||||||
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
|
|
||||||
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
|
|
||||||
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
|
|
||||||
``self_attn.config``, leaving sliding-window layers on FA2.
|
|
||||||
|
|
||||||
The per-layer config override alone is insufficient, however:
|
|
||||||
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
|
|
||||||
using the **model-level** config and passes the mapped mask to each
|
|
||||||
decoder layer. With FA2 still set at the model level, the ``full_attention``
|
|
||||||
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
|
|
||||||
The global layers then fail with::
|
|
||||||
|
|
||||||
RuntimeError: The expanded size of the tensor (S) must match the existing
|
|
||||||
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
|
|
||||||
sizes: [B, S]
|
|
||||||
|
|
||||||
...when the sequence length grows past roughly 7k tokens.
|
|
||||||
|
|
||||||
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
|
|
||||||
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
|
|
||||||
the original in ``masking_utils``. The wrapper forces
|
|
||||||
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
|
|
||||||
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
|
|
||||||
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
|
|
||||||
alone, so sliding-window layers continue to receive FA2-format masks.
|
|
||||||
|
|
||||||
The patch is idempotent. Install once per process, before any Gemma 4
|
|
||||||
forward pass runs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import copy
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
_PATCH_APPLIED = False
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma4_hybrid_mask() -> bool:
|
|
||||||
"""Install the Gemma 4 hybrid-attention mask fix.
|
|
||||||
|
|
||||||
Returns ``True`` if the patch was installed (or was already installed),
|
|
||||||
``False`` if the target module could not be imported (e.g. transformers
|
|
||||||
version predates Gemma 4) — in which case nothing is done and the
|
|
||||||
caller can continue unaffected.
|
|
||||||
"""
|
|
||||||
global _PATCH_APPLIED
|
|
||||||
if _PATCH_APPLIED:
|
|
||||||
return True
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.models.gemma4 import modeling_gemma4
|
|
||||||
except ImportError:
|
|
||||||
LOG.debug(
|
|
||||||
"gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
|
|
||||||
"skipping. This is fine for non-Gemma4 training."
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not hasattr(modeling_gemma4, "create_causal_mask"):
|
|
||||||
LOG.warning(
|
|
||||||
"gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
|
|
||||||
"binding, skipping. Transformers API may have changed."
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
original = modeling_gemma4.create_causal_mask
|
|
||||||
|
|
||||||
def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
|
|
||||||
"""Wrapper that forces SDPA format for the full-attention mask.
|
|
||||||
|
|
||||||
The global layers were patched to SDPA by
|
|
||||||
``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
|
|
||||||
original ``create_causal_mask`` dispatches on
|
|
||||||
``config._attn_implementation``; we shadow that with a local
|
|
||||||
override.
|
|
||||||
"""
|
|
||||||
sdpa_config = copy.copy(config)
|
|
||||||
sdpa_config._attn_implementation = "sdpa"
|
|
||||||
return original(sdpa_config, *args, **kwargs)
|
|
||||||
|
|
||||||
# Preserve the original reference on the wrapper for tests / teardown.
|
|
||||||
hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
|
|
||||||
_PATCH_APPLIED = True
|
|
||||||
LOG.info(
|
|
||||||
"gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
|
|
||||||
"force SDPA-format masks for full-attention layers"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def unpatch_gemma4_hybrid_mask() -> None:
|
|
||||||
"""Restore the original ``create_causal_mask``. Useful for tests."""
|
|
||||||
global _PATCH_APPLIED
|
|
||||||
if not _PATCH_APPLIED:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
from transformers.models.gemma4 import modeling_gemma4
|
|
||||||
except ImportError:
|
|
||||||
_PATCH_APPLIED = False
|
|
||||||
return
|
|
||||||
current = modeling_gemma4.create_causal_mask
|
|
||||||
original = getattr(current, "_axolotl_original", None)
|
|
||||||
if original is not None:
|
|
||||||
modeling_gemma4.create_causal_mask = original
|
|
||||||
_PATCH_APPLIED = False
|
|
||||||
@@ -24,15 +24,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||||
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
|
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
||||||
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
|
|
||||||
# language-side module is separated from the vision tower. Try
|
|
||||||
# both names before giving up.
|
|
||||||
mlp_cls = getattr(
|
|
||||||
module,
|
|
||||||
f"{model_cls_prefix}MLP",
|
|
||||||
None,
|
|
||||||
) or getattr(module, f"{model_cls_prefix}TextMLP")
|
|
||||||
|
|
||||||
if use_original_mlp:
|
if use_original_mlp:
|
||||||
mlp_forward = mlp_cls.forward
|
mlp_forward = mlp_cls.forward
|
||||||
|
|||||||
@@ -407,10 +407,7 @@ def selective_log_softmax(logits, index) -> torch.Tensor:
|
|||||||
K = index.shape[-1]
|
K = index.shape[-1]
|
||||||
original_index_shape = index.shape
|
original_index_shape = index.shape
|
||||||
|
|
||||||
try:
|
flat_logits = logits.reshape(-1, V).contiguous()
|
||||||
flat_logits = logits.view(-1, V)
|
|
||||||
except RuntimeError:
|
|
||||||
flat_logits = logits.reshape(-1, V).contiguous()
|
|
||||||
flat_index = index.reshape(-1, K).contiguous()
|
flat_index = index.reshape(-1, K).contiguous()
|
||||||
|
|
||||||
BLOCK_V = 4096
|
BLOCK_V = 4096
|
||||||
|
|||||||
@@ -394,8 +394,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||||
try:
|
try:
|
||||||
return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
|
return all(isinstance(v, list) for v in prompt.values()) and all(
|
||||||
isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
|
isinstance(v, list) for v in prompt[self.prompter.field_messages]
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return False
|
return False
|
||||||
@@ -1004,13 +1004,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if tools is None:
|
if tools is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Some datasets have tools set to str
|
|
||||||
if isinstance(tools, str):
|
|
||||||
try:
|
|
||||||
tools = json.loads(tools)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
|
|
||||||
raise
|
|
||||||
if isinstance(tools, list):
|
if isinstance(tools, list):
|
||||||
# Process each tool to handle JSON string parameters
|
# Process each tool to handle JSON string parameters
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
@@ -1041,22 +1034,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if messages is None:
|
if messages is None:
|
||||||
raise ValueError("Messages is null. Please check `field_messages`.")
|
raise ValueError("Messages is null. Please check `field_messages`.")
|
||||||
|
|
||||||
if isinstance(messages, str):
|
|
||||||
try:
|
|
||||||
messages = json.loads(messages)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
LOG.error(f"Error parsing messages as JSON. Error: {e}")
|
|
||||||
raise
|
|
||||||
assert isinstance(messages, list), (
|
|
||||||
f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extra check here to make sure decoded json is a list of dicts.
|
|
||||||
for i, message in enumerate(messages):
|
|
||||||
assert isinstance(message, dict), (
|
|
||||||
f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(messages, list):
|
if isinstance(messages, list):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|||||||
@@ -320,15 +320,6 @@ def main(script_args: ScriptArguments):
|
|||||||
# --- Active LoRA state (shared across endpoints via closure) ---
|
# --- Active LoRA state (shared across endpoints via closure) ---
|
||||||
active_lora: dict = {"request": None}
|
active_lora: dict = {"request": None}
|
||||||
|
|
||||||
# Serializes access to the worker pipe. The underlying
|
|
||||||
# multiprocessing.Connection is a single full-duplex stream shared
|
|
||||||
# across all HTTP handlers; concurrent requests interleave bytes on
|
|
||||||
# the wire and corrupt the pickle framing (seen as
|
|
||||||
# ``UnpicklingError: pickle data was truncated``). Any endpoint that
|
|
||||||
# does ``conn.send(...); conn.recv()`` MUST hold this lock across
|
|
||||||
# the round-trip so only one inflight call at a time per pipe.
|
|
||||||
worker_pipe_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# LoRA-specific endpoints
|
# LoRA-specific endpoints
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -640,150 +631,6 @@ def main(script_args: ScriptArguments):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
|
||||||
async def openai_completions(request_body: dict):
|
|
||||||
"""OpenAI-compatible text-completions endpoint.
|
|
||||||
|
|
||||||
Accepts either a string ``prompt`` or a list-of-int
|
|
||||||
``prompt_token_ids`` (as the text-completions spec allows). Routes
|
|
||||||
to the internal vLLM generate method with the active LoRA adapter
|
|
||||||
and returns an OpenAI /v1/completions-shaped response including
|
|
||||||
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
|
|
||||||
``generation_log_probs`` for NeMo Gym agents that need raw
|
|
||||||
tokens + logprobs.
|
|
||||||
"""
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
prompt_raw = request_body.get("prompt")
|
|
||||||
temperature = request_body.get("temperature", 1.0)
|
|
||||||
max_tokens = request_body.get("max_tokens", 512)
|
|
||||||
top_p = request_body.get("top_p", 1.0)
|
|
||||||
n = request_body.get("n", 1)
|
|
||||||
logprobs = request_body.get("logprobs") or 0
|
|
||||||
stop_token_ids = request_body.get("stop_token_ids") or None
|
|
||||||
|
|
||||||
# Accept either a string or a list[int] token id prompt. Lists
|
|
||||||
# must contain ints only (raise on lists of strings so callers get
|
|
||||||
# a clear error). Also accept [[int, int, ...]] nesting for the
|
|
||||||
# rare case callers pass a single-prompt batch.
|
|
||||||
if (
|
|
||||||
isinstance(prompt_raw, list)
|
|
||||||
and prompt_raw
|
|
||||||
and isinstance(prompt_raw[0], list)
|
|
||||||
):
|
|
||||||
prompt_raw = prompt_raw[0]
|
|
||||||
|
|
||||||
prompt_dict: dict[str, Any] = {}
|
|
||||||
if isinstance(prompt_raw, list):
|
|
||||||
prompt_dict = {"prompt_token_ids": prompt_raw}
|
|
||||||
elif isinstance(prompt_raw, str):
|
|
||||||
prompt_dict = {"prompt": prompt_raw}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"error": {
|
|
||||||
"message": ("prompt must be a string or a list of token ids"),
|
|
||||||
"type": "invalid_request",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
generation_kwargs: dict[str, Any] = {
|
|
||||||
"n": n,
|
|
||||||
"temperature": temperature,
|
|
||||||
"top_p": top_p,
|
|
||||||
"max_tokens": max_tokens,
|
|
||||||
"logprobs": logprobs,
|
|
||||||
}
|
|
||||||
if stop_token_ids:
|
|
||||||
generation_kwargs["stop_token_ids"] = stop_token_ids
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
**{k: v for k, v in generation_kwargs.items() if v is not None}
|
|
||||||
)
|
|
||||||
|
|
||||||
chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
|
|
||||||
|
|
||||||
# Hold the pipe lock across send+recv — concurrent requests would
|
|
||||||
# otherwise interleave pickle frames on the worker connection.
|
|
||||||
async with worker_pipe_lock:
|
|
||||||
for conn, chunk in zip(connections, chunked, strict=True):
|
|
||||||
if not chunk:
|
|
||||||
chunk = [{"prompt": "<placeholder>"}]
|
|
||||||
kwargs = {
|
|
||||||
"prompts": chunk,
|
|
||||||
"sampling_params": sampling_params,
|
|
||||||
"lora_request": active_lora["request"],
|
|
||||||
}
|
|
||||||
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
all_outputs = await asyncio.gather(
|
|
||||||
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
|
||||||
)
|
|
||||||
|
|
||||||
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
|
||||||
for o in all_outputs:
|
|
||||||
if isinstance(o, dict) and "error" in o:
|
|
||||||
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
|
||||||
all_outputs = list(chain.from_iterable(all_outputs))
|
|
||||||
|
|
||||||
if not all_outputs:
|
|
||||||
return {"choices": [], "model": script_args.model}
|
|
||||||
|
|
||||||
choices = []
|
|
||||||
for i, output in enumerate(all_outputs):
|
|
||||||
for j, out in enumerate(output.outputs):
|
|
||||||
text = out.text
|
|
||||||
# OpenAI-style `logprobs` block for text-completions:
|
|
||||||
# { "tokens": [...], "token_logprobs": [...] }
|
|
||||||
lp_block = None
|
|
||||||
if out.logprobs:
|
|
||||||
tokens_str: list[str] = []
|
|
||||||
token_lps: list[float] = []
|
|
||||||
for step in out.logprobs:
|
|
||||||
chosen = next(iter(step.values()))
|
|
||||||
tokens_str.append(getattr(chosen, "decoded_token", "") or "")
|
|
||||||
token_lps.append(float(chosen.logprob))
|
|
||||||
lp_block = {
|
|
||||||
"tokens": tokens_str,
|
|
||||||
"token_logprobs": token_lps,
|
|
||||||
}
|
|
||||||
|
|
||||||
choice = {
|
|
||||||
"index": i * n + j,
|
|
||||||
"text": text,
|
|
||||||
"finish_reason": "stop"
|
|
||||||
if out.finish_reason == "stop"
|
|
||||||
else "length",
|
|
||||||
"logprobs": lp_block,
|
|
||||||
# NeMo-Gym / retrace agent extras — preserved on the
|
|
||||||
# choice so callers with raw-token pipelines don't
|
|
||||||
# have to re-tokenize.
|
|
||||||
"prompt_token_ids": output.prompt_token_ids,
|
|
||||||
"generation_token_ids": list(out.token_ids),
|
|
||||||
"generation_log_probs": (
|
|
||||||
[float(next(iter(lp.values())).logprob) for lp in out.logprobs]
|
|
||||||
if out.logprobs
|
|
||||||
else []
|
|
||||||
),
|
|
||||||
}
|
|
||||||
choices.append(choice)
|
|
||||||
|
|
||||||
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
|
|
||||||
completion_tokens = sum(
|
|
||||||
len(out.token_ids) for o in all_outputs for out in o.outputs
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
|
|
||||||
"object": "text_completion",
|
|
||||||
"model": script_args.model,
|
|
||||||
"choices": choices,
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": prompt_tokens,
|
|
||||||
"completion_tokens": completion_tokens,
|
|
||||||
"total_tokens": prompt_tokens + completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
||||||
|
|
||||||
@app.post("/init_communicator/")
|
@app.post("/init_communicator/")
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from .batching import (
|
|||||||
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from .dpo import AxolotlDPODataCollatorWithPadding
|
|
||||||
from .mamba import MambaDataCollator
|
from .mamba import MambaDataCollator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -14,6 +13,5 @@ __all__ = [
|
|||||||
"BatchSamplerDataCollatorForSeq2Seq",
|
"BatchSamplerDataCollatorForSeq2Seq",
|
||||||
"V2BatchSamplerDataCollatorForSeq2Seq",
|
"V2BatchSamplerDataCollatorForSeq2Seq",
|
||||||
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
|
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
|
||||||
"AxolotlDPODataCollatorWithPadding",
|
|
||||||
"MambaDataCollator",
|
"MambaDataCollator",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,128 +0,0 @@
|
|||||||
"""DPO/ORPO/IPO/KTO data collator with pad_to_multiple_of support.
|
|
||||||
|
|
||||||
Extends TRL's DPODataCollatorWithPadding to round padded sequence lengths
|
|
||||||
up to a fixed multiple. This stabilizes Triton autotune caches for kernels
|
|
||||||
that key on sequence length (e.g. fla's linear attention kernels used by
|
|
||||||
Qwen3.5), which otherwise re-autotune on every distinct batch length.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
from trl.experimental.utils import DPODataCollatorWithPadding
|
|
||||||
from trl.trainer.utils import pad
|
|
||||||
|
|
||||||
|
|
||||||
def _round_up(length: int, multiple: int) -> int:
|
|
||||||
return ((length + multiple - 1) // multiple) * multiple
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlDPODataCollatorWithPadding(DPODataCollatorWithPadding):
|
|
||||||
"""DPO data collator that pads to a multiple of ``pad_to_multiple_of``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pad_token_id: Tokenizer pad token id (inherited).
|
|
||||||
is_encoder_decoder: Whether the model is encoder-decoder (inherited).
|
|
||||||
pad_to_multiple_of: If set, padded lengths are rounded up to this
|
|
||||||
multiple. Helps stabilize Triton autotune caches.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pad_to_multiple_of: int | None = None
|
|
||||||
|
|
||||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
|
|
||||||
pad_to_mult = self.pad_to_multiple_of
|
|
||||||
|
|
||||||
padded_batch: dict[str, Any] = {}
|
|
||||||
for k in features[0].keys():
|
|
||||||
if k.endswith(
|
|
||||||
("_input_ids", "_attention_mask", "_labels", "_pixel_values")
|
|
||||||
):
|
|
||||||
if self.is_encoder_decoder:
|
|
||||||
if k.endswith("_pixel_values"):
|
|
||||||
to_pad = [
|
|
||||||
torch.tensor(ex[k], dtype=torch.float32) for ex in features
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
to_pad = [torch.LongTensor(ex[k]) for ex in features]
|
|
||||||
|
|
||||||
if k.startswith("prompt") and k.endswith("input_ids"):
|
|
||||||
if self.pad_token_id is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Padding is enabled, but the tokenizer is not configured with a padding token."
|
|
||||||
)
|
|
||||||
padding_value = self.pad_token_id
|
|
||||||
elif k.endswith("_attention_mask"):
|
|
||||||
padding_value = 0
|
|
||||||
elif k.endswith("_pixel_values"):
|
|
||||||
padding_value = 0
|
|
||||||
elif (
|
|
||||||
k.startswith(("chosen", "rejected", "completion"))
|
|
||||||
or "decoder" in k
|
|
||||||
):
|
|
||||||
padding_value = -100
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected key in batch '{k}'")
|
|
||||||
|
|
||||||
padded = pad_sequence(
|
|
||||||
to_pad, batch_first=True, padding_value=padding_value
|
|
||||||
)
|
|
||||||
if pad_to_mult:
|
|
||||||
cur = padded.shape[1]
|
|
||||||
target = _round_up(cur, pad_to_mult)
|
|
||||||
if target > cur:
|
|
||||||
extra = target - cur
|
|
||||||
pad_shape = list(padded.shape)
|
|
||||||
pad_shape[1] = extra
|
|
||||||
filler = torch.full(
|
|
||||||
pad_shape,
|
|
||||||
padding_value,
|
|
||||||
dtype=padded.dtype,
|
|
||||||
device=padded.device,
|
|
||||||
)
|
|
||||||
padded = torch.cat([padded, filler], dim=1)
|
|
||||||
padded_batch[k] = padded
|
|
||||||
else:
|
|
||||||
if k.endswith("_input_ids"):
|
|
||||||
if self.pad_token_id is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Padding is enabled, but the tokenizer is not configured with a padding token."
|
|
||||||
)
|
|
||||||
padding_value = self.pad_token_id
|
|
||||||
elif k.endswith("_labels"):
|
|
||||||
padding_value = -100
|
|
||||||
elif k.endswith("_attention_mask"):
|
|
||||||
padding_value = 0
|
|
||||||
elif k.endswith("_pixel_values"):
|
|
||||||
padding_value = 0
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected key in batch '{k}'")
|
|
||||||
|
|
||||||
padding_side = (
|
|
||||||
"left"
|
|
||||||
if k in ("prompt_input_ids", "prompt_attention_mask")
|
|
||||||
else "right"
|
|
||||||
)
|
|
||||||
|
|
||||||
dtype = (
|
|
||||||
torch.float32 if k.endswith("_pixel_values") else torch.int64
|
|
||||||
)
|
|
||||||
to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
|
|
||||||
|
|
||||||
# trl.pad() natively supports pad_to_multiple_of
|
|
||||||
padded_batch[k] = pad(
|
|
||||||
to_pad,
|
|
||||||
padding_value=padding_value,
|
|
||||||
padding_side=padding_side,
|
|
||||||
pad_to_multiple_of=pad_to_mult,
|
|
||||||
)
|
|
||||||
elif k.endswith("_logps"):
|
|
||||||
padded_batch[k] = torch.tensor([ex[k] for ex in features])
|
|
||||||
else:
|
|
||||||
padded_batch[k] = [ex[k] for ex in features]
|
|
||||||
|
|
||||||
return padded_batch
|
|
||||||
@@ -309,16 +309,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
dpo_padding_free: bool | None = None
|
dpo_padding_free: bool | None = None
|
||||||
|
|
||||||
dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "List of DPO losses to use."},
|
|
||||||
)
|
|
||||||
|
|
||||||
dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Weights for each DPO loss."},
|
|
||||||
)
|
|
||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[
|
list[
|
||||||
@@ -673,12 +663,6 @@ class AxolotlInputConfig(
|
|||||||
"description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
|
"description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pad_to_multiple_of: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": ("Pad each batch to a multiple of this value.")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
curriculum_sampling: bool | None = Field(
|
curriculum_sampling: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -1016,7 +1000,7 @@ class AxolotlInputConfig(
|
|||||||
torch_compile: Literal["auto"] | bool | None = Field(
|
torch_compile: Literal["auto"] | bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Whether to use torch.compile and which backend to use."
|
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
torch_compile_backend: str | None = Field(
|
torch_compile_backend: str | None = Field(
|
||||||
|
|||||||
@@ -64,12 +64,6 @@ class ModelInputConfig(BaseModel):
|
|||||||
processor_type: str | None = Field(
|
processor_type: str | None = Field(
|
||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
processor_kwargs: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
tokenizer_save_jinja_files: bool | None = Field(
|
tokenizer_save_jinja_files: bool | None = Field(
|
||||||
default=True, # match the default behavior from transformers
|
default=True, # match the default behavior from transformers
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -113,22 +107,6 @@ class ModelInputConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
return trust_remote_code
|
return trust_remote_code
|
||||||
|
|
||||||
@field_validator("processor_kwargs")
|
|
||||||
@classmethod
|
|
||||||
def reject_reserved_processor_kwargs(cls, processor_kwargs):
|
|
||||||
if not processor_kwargs:
|
|
||||||
return processor_kwargs
|
|
||||||
reserved = {"revision", "trust_remote_code"}
|
|
||||||
conflicts = reserved.intersection(processor_kwargs)
|
|
||||||
if conflicts:
|
|
||||||
raise ValueError(
|
|
||||||
"Do not set reserved keys "
|
|
||||||
f"{sorted(conflicts)} inside `processor_kwargs`; "
|
|
||||||
"use the top-level `revision_of_model` / `trust_remote_code` "
|
|
||||||
"config keys instead."
|
|
||||||
)
|
|
||||||
return processor_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputConfig(BaseModel):
|
class ModelOutputConfig(BaseModel):
|
||||||
"""model save configuration subset"""
|
"""model save configuration subset"""
|
||||||
|
|||||||
@@ -578,11 +578,6 @@ class TrainingValidationMixin:
|
|||||||
"Setting chat_template is not supported with mistral-common tokenizer"
|
"Setting chat_template is not supported with mistral-common tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
if data.get("processor_kwargs"):
|
|
||||||
raise ValueError(
|
|
||||||
"processor_kwargs is not supported with mistral-common tokenizer"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -765,122 +760,6 @@ class RLValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_dpo(cls, data):
|
|
||||||
dpo_loss_type = data.get("dpo_loss_type")
|
|
||||||
dpo_loss_weights = data.get("dpo_loss_weights")
|
|
||||||
rl = data.get("rl")
|
|
||||||
|
|
||||||
if rl == "ipo":
|
|
||||||
LOG.warning(
|
|
||||||
"rl: ipo will soon be deprecated. Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if rl == "dpo":
|
|
||||||
if dpo_loss_weights is not None and dpo_loss_type is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`dpo_loss_weights` requires `dpo_loss_type` to be set"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
dpo_loss_type is not None
|
|
||||||
and dpo_loss_weights is not None
|
|
||||||
and len(dpo_loss_type) != len(dpo_loss_weights)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
|
|
||||||
f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
|
|
||||||
)
|
|
||||||
elif dpo_loss_type is not None or dpo_loss_weights is not None:
|
|
||||||
raise ValueError(
|
|
||||||
f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only,"
|
|
||||||
f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_grpo_batch_size_divisibility(cls, data):
|
|
||||||
"""Surface GRPO batch-shape mismatches at config-parse time.
|
|
||||||
|
|
||||||
TRL's GRPOTrainer requires that the per-step generation batch size be
|
|
||||||
evenly divisible by ``num_generations`` so that every prompt can be
|
|
||||||
replicated exactly ``num_generations`` times. The runtime check inside
|
|
||||||
``GRPOTrainer.__init__`` only fires after the model has been loaded —
|
|
||||||
too late and too cryptic for the user. We replicate the check here so
|
|
||||||
the failure is immediate and actionable.
|
|
||||||
|
|
||||||
Also enforces:
|
|
||||||
- ``num_generations >= 2`` (group-relative advantage needs variance)
|
|
||||||
- ``effective_gbs >= num_generations * world_size`` when capabilities
|
|
||||||
indicate multiple ranks (each rank needs at least one full group)
|
|
||||||
"""
|
|
||||||
if data.get("rl") != "grpo":
|
|
||||||
return data
|
|
||||||
|
|
||||||
trl_cfg = data.get("trl") or {}
|
|
||||||
num_gen = trl_cfg.get("num_generations")
|
|
||||||
if num_gen is None:
|
|
||||||
# TRL's own default is 8 — but if the user didn't set it, we
|
|
||||||
# don't have enough info to validate anything. Let TRL's own
|
|
||||||
# init handle the default-vs-batch interaction.
|
|
||||||
return data
|
|
||||||
if num_gen < 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
|
|
||||||
"With num_generations=1, every group has zero advantage and "
|
|
||||||
"the policy never updates."
|
|
||||||
)
|
|
||||||
|
|
||||||
explicit_gbs = trl_cfg.get("generation_batch_size")
|
|
||||||
if explicit_gbs is not None:
|
|
||||||
effective_gbs = int(explicit_gbs)
|
|
||||||
gbs_source = "trl.generation_batch_size"
|
|
||||||
else:
|
|
||||||
mb = data.get("micro_batch_size") or 1
|
|
||||||
ga = data.get("gradient_accumulation_steps") or 1
|
|
||||||
effective_gbs = int(mb) * int(ga)
|
|
||||||
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
|
|
||||||
|
|
||||||
if effective_gbs % num_gen != 0:
|
|
||||||
# Suggest the smallest GA bump that fixes it for the common case
|
|
||||||
# where the user hasn't set generation_batch_size explicitly.
|
|
||||||
hint = ""
|
|
||||||
if explicit_gbs is None:
|
|
||||||
from math import gcd
|
|
||||||
|
|
||||||
mb_val = int(data.get("micro_batch_size") or 1)
|
|
||||||
# smallest GA such that mb*GA is a multiple of num_gen
|
|
||||||
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
|
|
||||||
suggested_ga = lcm // mb_val
|
|
||||||
hint = (
|
|
||||||
f" Smallest fix: set `gradient_accumulation_steps: "
|
|
||||||
f"{suggested_ga}` (so micro_batch_size * GA = "
|
|
||||||
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"GRPO: generation batch size must be divisible by "
|
|
||||||
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
|
|
||||||
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Multi-rank check: each rank must receive at least one full group
|
|
||||||
# per step. Without `capabilities` populated yet (mode='before'), we
|
|
||||||
# fall back to user-set distributed fields.
|
|
||||||
world_size = (
|
|
||||||
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
|
|
||||||
)
|
|
||||||
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"GRPO with world_size={world_size} requires effective_gbs "
|
|
||||||
f">= num_generations * world_size = {num_gen * world_size}, "
|
|
||||||
f"got {effective_gbs}. Increase gradient_accumulation_steps "
|
|
||||||
f"or micro_batch_size."
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizationValidationMixin:
|
class OptimizationValidationMixin:
|
||||||
"""Validation methods related to optimization and performance."""
|
"""Validation methods related to optimization and performance."""
|
||||||
|
|||||||
@@ -216,197 +216,5 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
|
|||||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||||
|
|
||||||
|
|
||||||
class TestVllmLoraSyncPatch(unittest.TestCase):
|
|
||||||
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
|
|
||||||
|
|
||||||
These tests exercise the patch-installation branch in isolation. They build
|
|
||||||
a stub trainer with just enough attributes to look like
|
|
||||||
``AsyncGRPOTrainer`` for the duration of the relevant code path.
|
|
||||||
|
|
||||||
Background — there are two correct behaviors and we historically had a bug
|
|
||||||
where both modes used the same one:
|
|
||||||
|
|
||||||
- Async prefetch ON: the BG generation thread can't safely call
|
|
||||||
sync_weights mid-rollout. We no-op the stock hook and drive sync from
|
|
||||||
the main thread via ``_maybe_sync_vllm_weights``.
|
|
||||||
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
|
|
||||||
calls ``sync_weights`` once per step boundary on the main thread. We
|
|
||||||
wire that hook directly to ``_sync_lora_adapter`` because
|
|
||||||
``_maybe_sync_vllm_weights`` short-circuits when async is off.
|
|
||||||
|
|
||||||
Before the fix, both modes installed ``lambda: None``, so sync mode never
|
|
||||||
pushed any LoRA adapter to vLLM and the trainer was a no-op.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
|
|
||||||
from axolotl.core.trainers.grpo.async_trainer import (
|
|
||||||
AsyncGRPOTrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
class FakeArgs:
|
|
||||||
pass
|
|
||||||
|
|
||||||
args = FakeArgs()
|
|
||||||
args.vllm_lora_sync = vllm_lora_sync
|
|
||||||
args.async_prefetch = async_prefetch
|
|
||||||
|
|
||||||
class FakeVllmGen:
|
|
||||||
sync_weights = staticmethod(lambda: None)
|
|
||||||
model = MagicMock()
|
|
||||||
|
|
||||||
# Use object.__new__ so we don't run __init__ (which needs a real
|
|
||||||
# model, dataset, etc.). We only need the `_generate_single_turn`
|
|
||||||
# method's patch branch to run, so we set up the minimum state.
|
|
||||||
trainer = object.__new__(AsyncGRPOTrainer)
|
|
||||||
trainer.args = args
|
|
||||||
trainer.use_vllm = True
|
|
||||||
trainer.vllm_generation = FakeVllmGen()
|
|
||||||
trainer._patched_sync_weights = False
|
|
||||||
# Spy on _sync_lora_adapter so we can assert it's the function the
|
|
||||||
# hook delegates to in sync mode.
|
|
||||||
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
|
|
||||||
trainer._sync_peft_weights_no_merge = MagicMock(
|
|
||||||
name="_sync_peft_weights_no_merge_spy"
|
|
||||||
)
|
|
||||||
return trainer
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _run_patch_branch(trainer):
|
|
||||||
"""Execute just the sync_weights-patching branch in isolation.
|
|
||||||
|
|
||||||
We can't easily call the real ``_generate_single_turn`` because it
|
|
||||||
does a full vLLM generate. Instead we copy the exact branch out of
|
|
||||||
the source so the test verifies the same logic the trainer runs.
|
|
||||||
"""
|
|
||||||
if not getattr(trainer, "_patched_sync_weights", False):
|
|
||||||
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
|
|
||||||
if getattr(trainer.args, "vllm_lora_sync", False):
|
|
||||||
if getattr(trainer.args, "async_prefetch", False):
|
|
||||||
trainer.vllm_generation.sync_weights = lambda: None
|
|
||||||
else:
|
|
||||||
sync_helper = trainer._sync_lora_adapter
|
|
||||||
|
|
||||||
def _lora_filesystem_sync():
|
|
||||||
sync_helper()
|
|
||||||
|
|
||||||
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
|
|
||||||
trainer._patched_sync_weights = True
|
|
||||||
|
|
||||||
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
|
|
||||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
|
||||||
self._run_patch_branch(trainer)
|
|
||||||
|
|
||||||
assert trainer._patched_sync_weights is True
|
|
||||||
# Trigger the patched hook — it must call _sync_lora_adapter.
|
|
||||||
trainer.vllm_generation.sync_weights()
|
|
||||||
trainer._sync_lora_adapter.assert_called_once()
|
|
||||||
|
|
||||||
def test_async_mode_with_lora_sync_installs_noop_hook(self):
|
|
||||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
|
|
||||||
self._run_patch_branch(trainer)
|
|
||||||
|
|
||||||
assert trainer._patched_sync_weights is True
|
|
||||||
# Hook must be a no-op so BG-thread generation doesn't fight the
|
|
||||||
# main-thread optimizer step over the model weights.
|
|
||||||
trainer.vllm_generation.sync_weights()
|
|
||||||
trainer._sync_lora_adapter.assert_not_called()
|
|
||||||
|
|
||||||
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
|
|
||||||
"""Installing the patch should not pre-emptively sync."""
|
|
||||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
|
||||||
self._run_patch_branch(trainer)
|
|
||||||
# _sync_lora_adapter should only be called when the patched hook
|
|
||||||
# itself is invoked (e.g., from TRL's _generate_single_turn).
|
|
||||||
trainer._sync_lora_adapter.assert_not_called()
|
|
||||||
|
|
||||||
def test_patch_is_idempotent(self):
|
|
||||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
|
||||||
self._run_patch_branch(trainer)
|
|
||||||
first_hook = trainer.vllm_generation.sync_weights
|
|
||||||
# Second call must not re-patch (otherwise we'd lose the original).
|
|
||||||
self._run_patch_branch(trainer)
|
|
||||||
assert trainer.vllm_generation.sync_weights is first_hook
|
|
||||||
|
|
||||||
|
|
||||||
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
|
|
||||||
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
|
|
||||||
|
|
||||||
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
|
|
||||||
on the very first call when ``vllm_sync_interval`` was ``None`` (which
|
|
||||||
is the default for any config that doesn't explicitly set it). We now
|
|
||||||
fall back to interval=1 so unset means "sync every step", matching the
|
|
||||||
behavior of TRL's own ``_generate_single_turn``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _make_stub_trainer(interval, async_prefetch):
|
|
||||||
from axolotl.core.trainers.grpo.async_trainer import (
|
|
||||||
AsyncGRPOTrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
class FakeArgs:
|
|
||||||
pass
|
|
||||||
|
|
||||||
args = FakeArgs()
|
|
||||||
args.async_prefetch = async_prefetch
|
|
||||||
args.vllm_sync_interval = interval
|
|
||||||
args.vllm_lora_sync = True
|
|
||||||
|
|
||||||
class FakeState:
|
|
||||||
global_step = 1
|
|
||||||
|
|
||||||
trainer = object.__new__(AsyncGRPOTrainer)
|
|
||||||
trainer.args = args
|
|
||||||
trainer.use_vllm = True
|
|
||||||
trainer.state = FakeState()
|
|
||||||
trainer._last_synced_step = 0
|
|
||||||
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
|
|
||||||
return trainer
|
|
||||||
|
|
||||||
def test_interval_none_in_async_mode_does_not_crash(self):
|
|
||||||
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
|
|
||||||
from axolotl.core.trainers.grpo.async_trainer import (
|
|
||||||
AsyncGRPOTrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should not raise TypeError — defaults to every-step sync
|
|
||||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
|
||||||
trainer._sync_lora_adapter.assert_called_once()
|
|
||||||
|
|
||||||
def test_sync_mode_drives_sync(self):
|
|
||||||
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
|
|
||||||
|
|
||||||
The previous behavior (early return when ``not async_prefetch``)
|
|
||||||
assumed TRL's stock ``_generate_single_turn`` would handle sync.
|
|
||||||
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
|
|
||||||
where the data producer bypasses ``_generate_single_turn``
|
|
||||||
entirely. Without this trigger no sync ever happens and the
|
|
||||||
trainer becomes a no-op.
|
|
||||||
"""
|
|
||||||
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
|
|
||||||
from axolotl.core.trainers.grpo.async_trainer import (
|
|
||||||
AsyncGRPOTrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
|
||||||
trainer._sync_lora_adapter.assert_called_once()
|
|
||||||
|
|
||||||
def test_async_mode_with_explicit_interval_respects_modulo(self):
|
|
||||||
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
|
|
||||||
from axolotl.core.trainers.grpo.async_trainer import (
|
|
||||||
AsyncGRPOTrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
|
|
||||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
|
||||||
trainer._sync_lora_adapter.assert_not_called()
|
|
||||||
|
|
||||||
# global_step=4 → 4 % 4 == 0 → sync
|
|
||||||
trainer.state.global_step = 4
|
|
||||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
|
||||||
trainer._sync_lora_adapter.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -96,8 +96,6 @@ def fixture_dpo_cfg(base_cfg):
|
|||||||
"dpo_use_weighting": True,
|
"dpo_use_weighting": True,
|
||||||
"dpo_label_smoothing": 0.1,
|
"dpo_label_smoothing": 0.1,
|
||||||
"beta": 0.1, # DPO beta
|
"beta": 0.1, # DPO beta
|
||||||
"dpo_loss_type": ["sigmoid", "sft"],
|
|
||||||
"dpo_loss_weights": [1.0, 0.5],
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return cfg
|
return cfg
|
||||||
@@ -166,8 +164,7 @@ def fixture_ipo_cfg(base_cfg):
|
|||||||
cfg = base_cfg.copy()
|
cfg = base_cfg.copy()
|
||||||
cfg.update(
|
cfg.update(
|
||||||
{
|
{
|
||||||
"rl": RLType.DPO,
|
"rl": RLType.IPO,
|
||||||
"dpo_loss_type": ["ipo"],
|
|
||||||
"dpo_label_smoothing": 0,
|
"dpo_label_smoothing": 0,
|
||||||
"beta": 0.1,
|
"beta": 0.1,
|
||||||
}
|
}
|
||||||
@@ -303,8 +300,6 @@ class TestHFRLTrainerBuilder:
|
|||||||
assert training_arguments.use_weighting is True
|
assert training_arguments.use_weighting is True
|
||||||
assert training_arguments.label_smoothing == 0.1
|
assert training_arguments.label_smoothing == 0.1
|
||||||
assert training_arguments.precompute_ref_log_probs is True
|
assert training_arguments.precompute_ref_log_probs is True
|
||||||
assert training_arguments.loss_type == ["sigmoid", "sft"]
|
|
||||||
assert training_arguments.loss_weights == [1.0, 0.5]
|
|
||||||
|
|
||||||
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
||||||
|
|||||||
@@ -116,58 +116,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_rpo(self, temp_dir):
|
|
||||||
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
|
|
||||||
# replaces loss_type="rpo", rpo_alpha=alpha.
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"tokenizer_type": "AutoTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 64,
|
|
||||||
"lora_alpha": 32,
|
|
||||||
"lora_dropout": 0.1,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"rl": "dpo",
|
|
||||||
"dpo_loss_type": ["sigmoid", "sft"],
|
|
||||||
"dpo_loss_weights": [1.0, 1.0],
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
|
||||||
"type": "chatml.ultra",
|
|
||||||
"split": "train",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "paged_adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 20,
|
|
||||||
"save_steps": 10,
|
|
||||||
"warmup_steps": 5,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
|
||||||
|
|
||||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_kto_pair_lora(self, temp_dir):
|
def test_kto_pair_lora(self, temp_dir):
|
||||||
@@ -233,8 +181,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"pad_token": "<|endoftext|>",
|
||||||
},
|
},
|
||||||
"rl": "dpo",
|
"rl": "ipo",
|
||||||
"dpo_loss_type": ["ipo"],
|
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||||
|
|||||||
@@ -361,329 +361,6 @@ class TestPluginDefaults(unittest.TestCase):
|
|||||||
assert cfg.dataloader_num_workers == 0
|
assert cfg.dataloader_num_workers == 0
|
||||||
|
|
||||||
|
|
||||||
class TestSelectWeightSyncTransport(unittest.TestCase):
|
|
||||||
"""Pure-logic table tests for ``select_weight_sync_transport``."""
|
|
||||||
|
|
||||||
def _caps(self, **kwargs):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
|
||||||
|
|
||||||
c = VLLMWeightSyncCapabilities(probed=True)
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
setattr(c, k, v)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def test_lora_with_native_endpoint(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
|
||||||
|
|
||||||
caps = self._caps(lora_filesystem=True)
|
|
||||||
assert (
|
|
||||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
|
||||||
== "lora_filesystem"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_lora_with_axolotl_endpoint(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
|
||||||
|
|
||||||
caps = self._caps(lora_axolotl=True)
|
|
||||||
assert (
|
|
||||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
|
||||||
== "lora_filesystem"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
|
||||||
|
|
||||||
caps = self._caps(nccl=True)
|
|
||||||
assert (
|
|
||||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
|
|
||||||
== "nccl"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_full_param_prefers_nccl(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
|
||||||
|
|
||||||
caps = self._caps(nccl=True, http_full=True)
|
|
||||||
assert (
|
|
||||||
select_weight_sync_transport(
|
|
||||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
|
||||||
)
|
|
||||||
== "nccl"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_full_param_falls_back_to_http(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
|
||||||
|
|
||||||
caps = self._caps(http_full=True)
|
|
||||||
assert (
|
|
||||||
select_weight_sync_transport(
|
|
||||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
|
||||||
)
|
|
||||||
== "http_full"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_full_param_no_routes_returns_none(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
|
||||||
|
|
||||||
caps = self._caps() # all False
|
|
||||||
assert (
|
|
||||||
select_weight_sync_transport(
|
|
||||||
caps, has_lora=False, vllm_lora_sync_pref=False
|
|
||||||
)
|
|
||||||
== "none"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_lora_no_routes_returns_none(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
|
|
||||||
|
|
||||||
caps = self._caps()
|
|
||||||
assert (
|
|
||||||
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
|
|
||||||
== "none"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProbeVllmWeightSync(unittest.TestCase):
|
|
||||||
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
|
|
||||||
|
|
||||||
def test_stock_vllm_with_lora_enabled(self):
|
|
||||||
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
|
||||||
|
|
||||||
spec = {
|
|
||||||
"paths": {
|
|
||||||
"/v1/models": {"get": {}},
|
|
||||||
"/v1/load_lora_adapter": {"post": {}},
|
|
||||||
"/v1/unload_lora_adapter": {"post": {}},
|
|
||||||
"/v1/completions": {"post": {}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
with patch("requests.get") as mock_get:
|
|
||||||
mock_get.return_value.raise_for_status = lambda: None
|
|
||||||
mock_get.return_value.json = lambda: spec
|
|
||||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
|
||||||
|
|
||||||
assert caps.probed is True
|
|
||||||
assert caps.lora_filesystem is True
|
|
||||||
assert caps.lora_axolotl is False
|
|
||||||
assert caps.nccl is False
|
|
||||||
assert caps.http_full is False
|
|
||||||
|
|
||||||
def test_axolotl_serve_lora_full_capabilities(self):
|
|
||||||
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
|
||||||
|
|
||||||
spec = {
|
|
||||||
"paths": {
|
|
||||||
"/init_communicator/": {"post": {}},
|
|
||||||
"/update_named_param/": {"post": {}},
|
|
||||||
"/batch_update_named_params/": {"post": {}},
|
|
||||||
"/set_lora_adapter/": {"post": {}},
|
|
||||||
"/clear_lora_adapter/": {"post": {}},
|
|
||||||
"/http_update_weights/": {"post": {}},
|
|
||||||
"/v1/load_lora_adapter": {"post": {}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
with patch("requests.get") as mock_get:
|
|
||||||
mock_get.return_value.raise_for_status = lambda: None
|
|
||||||
mock_get.return_value.json = lambda: spec
|
|
||||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
|
||||||
|
|
||||||
assert caps.probed is True
|
|
||||||
assert caps.nccl is True
|
|
||||||
assert caps.lora_axolotl is True
|
|
||||||
assert caps.lora_filesystem is True
|
|
||||||
assert caps.http_full is True
|
|
||||||
|
|
||||||
def test_trl_vllm_serve_nccl_only(self):
|
|
||||||
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
|
||||||
|
|
||||||
spec = {
|
|
||||||
"paths": {
|
|
||||||
"/init_communicator/": {"post": {}},
|
|
||||||
"/update_named_param/": {"post": {}},
|
|
||||||
"/batch_update_named_params/": {"post": {}},
|
|
||||||
"/close_communicator/": {"post": {}},
|
|
||||||
"/generate/": {"post": {}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
with patch("requests.get") as mock_get:
|
|
||||||
mock_get.return_value.raise_for_status = lambda: None
|
|
||||||
mock_get.return_value.json = lambda: spec
|
|
||||||
caps = probe_vllm_weight_sync("http://localhost:8000")
|
|
||||||
|
|
||||||
assert caps.probed is True
|
|
||||||
assert caps.nccl is True
|
|
||||||
assert caps.lora_filesystem is False
|
|
||||||
assert caps.lora_axolotl is False
|
|
||||||
assert caps.http_full is False
|
|
||||||
|
|
||||||
def test_unreachable_server_records_error(self):
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
|
|
||||||
|
|
||||||
with patch("requests.get") as mock_get:
|
|
||||||
mock_get.side_effect = ConnectionError("Connection refused")
|
|
||||||
caps = probe_vllm_weight_sync("http://localhost:9999")
|
|
||||||
|
|
||||||
assert caps.probed is False
|
|
||||||
assert caps.probe_error is not None
|
|
||||||
assert "ConnectionError" in caps.probe_error
|
|
||||||
assert caps.nccl is False
|
|
||||||
assert caps.lora_filesystem is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestPluginWeightSyncEnforcement(unittest.TestCase):
|
|
||||||
"""End-to-end test of post_trainer_create's transport-selection branch.
|
|
||||||
|
|
||||||
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
|
|
||||||
leaving the trainer learning in isolation while vLLM kept serving the
|
|
||||||
unmodified base model. After the fix:
|
|
||||||
|
|
||||||
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
|
|
||||||
- LoRA + only NCCL endpoint → uses NCCL broadcast
|
|
||||||
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
|
|
||||||
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
|
|
||||||
- No usable transport → raises ValueError with a precise diagnosis
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _fake_cfg(adapter, vllm_lora_sync):
|
|
||||||
class FakeTRL:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class FakeCfg:
|
|
||||||
pass
|
|
||||||
|
|
||||||
trl = FakeTRL()
|
|
||||||
trl.vllm_lora_sync = vllm_lora_sync
|
|
||||||
trl.vllm_server_host = "127.0.0.1"
|
|
||||||
trl.vllm_server_port = 8000
|
|
||||||
|
|
||||||
cfg = FakeCfg()
|
|
||||||
cfg.nemo_gym_enabled = True
|
|
||||||
cfg.nemo_gym_model_name = None
|
|
||||||
cfg.base_model = "test/model"
|
|
||||||
cfg.nemo_gym_verify_timeout = 30
|
|
||||||
cfg.nemo_gym_multi_turn = True
|
|
||||||
cfg.adapter = adapter
|
|
||||||
cfg.trl = trl
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _fake_trainer():
|
|
||||||
class FakeVLLMGen:
|
|
||||||
sync_weights = staticmethod(lambda: None)
|
|
||||||
|
|
||||||
class FakeTrainer:
|
|
||||||
vllm_generation = FakeVLLMGen()
|
|
||||||
|
|
||||||
return FakeTrainer()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _caps(**kwargs):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
|
|
||||||
|
|
||||||
c = VLLMWeightSyncCapabilities(probed=True)
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
setattr(c, k, v)
|
|
||||||
return c
|
|
||||||
|
|
||||||
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
||||||
|
|
||||||
plugin = NemoGymPlugin()
|
|
||||||
plugin._vllm_caps = self._caps(lora_filesystem=True)
|
|
||||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
|
||||||
trainer = self._fake_trainer()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(plugin, "_setup_lora_sync") as setup,
|
|
||||||
patch.object(plugin, "_check_lora_endpoint") as check,
|
|
||||||
patch.object(plugin, "_wire_multi_turn") as wire,
|
|
||||||
):
|
|
||||||
plugin.post_trainer_create(cfg, trainer)
|
|
||||||
setup.assert_called_once()
|
|
||||||
check.assert_called_once()
|
|
||||||
wire.assert_called_once()
|
|
||||||
|
|
||||||
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
||||||
|
|
||||||
plugin = NemoGymPlugin()
|
|
||||||
plugin._vllm_caps = self._caps() # all False, but probed
|
|
||||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
|
|
||||||
trainer = self._fake_trainer()
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as ctx:
|
|
||||||
plugin.post_trainer_create(cfg, trainer)
|
|
||||||
msg = str(ctx.exception)
|
|
||||||
assert "no-op trainer" in msg
|
|
||||||
assert "load_lora_adapter" in msg
|
|
||||||
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
|
|
||||||
|
|
||||||
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
||||||
|
|
||||||
plugin = NemoGymPlugin()
|
|
||||||
plugin._vllm_caps = self._caps(nccl=True)
|
|
||||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
|
||||||
trainer = self._fake_trainer()
|
|
||||||
|
|
||||||
with patch.object(plugin, "_wire_multi_turn") as wire:
|
|
||||||
plugin.post_trainer_create(cfg, trainer)
|
|
||||||
wire.assert_called_once()
|
|
||||||
|
|
||||||
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
||||||
|
|
||||||
plugin = NemoGymPlugin()
|
|
||||||
plugin._vllm_caps = self._caps(http_full=True)
|
|
||||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
|
||||||
trainer = self._fake_trainer()
|
|
||||||
with self.assertRaises(NotImplementedError) as ctx:
|
|
||||||
plugin.post_trainer_create(cfg, trainer)
|
|
||||||
assert "HTTP weight sync" in str(ctx.exception)
|
|
||||||
|
|
||||||
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
||||||
|
|
||||||
plugin = NemoGymPlugin()
|
|
||||||
plugin._vllm_caps = self._caps()
|
|
||||||
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
|
|
||||||
trainer = self._fake_trainer()
|
|
||||||
with self.assertRaises(ValueError) as ctx:
|
|
||||||
plugin.post_trainer_create(cfg, trainer)
|
|
||||||
msg = str(ctx.exception)
|
|
||||||
assert "no-op trainer" in msg
|
|
||||||
assert "init_communicator" in msg
|
|
||||||
assert "http_update_weights" in msg
|
|
||||||
|
|
||||||
def test_unprobed_caps_raises_with_probe_failure_message(self):
|
|
||||||
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
|
|
||||||
|
|
||||||
plugin = NemoGymPlugin()
|
|
||||||
# Plugin._vllm_caps left as default-None: the post_trainer_create
|
|
||||||
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
|
|
||||||
# probed=False, so the error path should mention probing.
|
|
||||||
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
|
|
||||||
trainer = self._fake_trainer()
|
|
||||||
with self.assertRaises(ValueError) as ctx:
|
|
||||||
plugin.post_trainer_create(cfg, trainer)
|
|
||||||
assert "could not probe" in str(ctx.exception)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNemoGymE2E(unittest.TestCase):
|
class TestNemoGymE2E(unittest.TestCase):
|
||||||
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
|
||||||
|
|
||||||
@@ -775,15 +452,19 @@ class TestNemoGymE2E(unittest.TestCase):
|
|||||||
trainer = self._make_mock_trainer()
|
trainer = self._make_mock_trainer()
|
||||||
producer._trainer = trainer
|
producer._trainer = trainer
|
||||||
|
|
||||||
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
|
# Mock the prompt iterator (returns a batch of 1 input)
|
||||||
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
|
producer._prompt_iter = iter(
|
||||||
# copies of each unique prompt — one entry per rollout.
|
[
|
||||||
_prompt_batch = [
|
[
|
||||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
{
|
||||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
producer._prompt_dl = [
|
||||||
|
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
|
||||||
]
|
]
|
||||||
producer._prompt_iter = iter([_prompt_batch])
|
|
||||||
producer._prompt_dl = [_prompt_batch]
|
|
||||||
|
|
||||||
# Call produce
|
# Call produce
|
||||||
result = producer.produce(model=MagicMock(), global_step=1)
|
result = producer.produce(model=MagicMock(), global_step=1)
|
||||||
@@ -849,13 +530,10 @@ class TestNemoGymE2E(unittest.TestCase):
|
|||||||
producer._request_timeout = 30
|
producer._request_timeout = 30
|
||||||
producer._num_generations = 2
|
producer._num_generations = 2
|
||||||
producer._trainer = self._make_mock_trainer()
|
producer._trainer = self._make_mock_trainer()
|
||||||
# RepeatSampler pre-expands by num_generations=2.
|
producer._prompt_iter = iter(
|
||||||
_prompt_batch = [
|
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
)
|
||||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||||
]
|
|
||||||
producer._prompt_iter = iter([_prompt_batch])
|
|
||||||
producer._prompt_dl = [_prompt_batch]
|
|
||||||
|
|
||||||
result = producer.produce(model=MagicMock(), global_step=1)
|
result = producer.produce(model=MagicMock(), global_step=1)
|
||||||
|
|
||||||
|
|||||||
@@ -38,30 +38,6 @@ def _reference_norm_noscale(x, eps):
|
|||||||
return norm(x)
|
return norm(x)
|
||||||
|
|
||||||
|
|
||||||
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
|
|
||||||
"""Reference: Gemma4RMSNorm over the full head_dim, then stock
|
|
||||||
``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with
|
|
||||||
the trailing columns passed through unchanged. Mirrors how Llama-style
|
|
||||||
partial rotary is layered on top of the stock RMSNorm + RoPE primitives.
|
|
||||||
"""
|
|
||||||
from transformers.models.gemma4.modeling_gemma4 import (
|
|
||||||
Gemma4RMSNorm,
|
|
||||||
apply_rotary_pos_emb,
|
|
||||||
)
|
|
||||||
|
|
||||||
D = x.shape[-1]
|
|
||||||
n_rot = cos.shape[-1]
|
|
||||||
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
|
|
||||||
norm.weight.data.copy_(weight)
|
|
||||||
normed = norm(x)
|
|
||||||
if n_rot == D:
|
|
||||||
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
|
|
||||||
x_rot = normed[..., :n_rot]
|
|
||||||
x_pass = normed[..., n_rot:]
|
|
||||||
rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2)
|
|
||||||
return torch.cat([rotated, x_pass], dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
params=[
|
params=[
|
||||||
(2, 64, 32, 256), # sliding window layer shape
|
(2, 64, 32, 256), # sliding window layer shape
|
||||||
@@ -218,172 +194,6 @@ class TestFusedRMSNormRoPEBackward:
|
|||||||
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
||||||
|
|
||||||
|
|
||||||
class TestFusedRMSNormRoPEPartialRotary:
|
|
||||||
"""Partial-rotary: cos/sin last dim is smaller than head_dim.
|
|
||||||
|
|
||||||
Compares against the original primitives (`Gemma4RMSNorm` +
|
|
||||||
`apply_rotary_pos_emb`) applied to the rotated slice with the trailing
|
|
||||||
columns passed through. Without the kernel fix this used to crash with
|
|
||||||
`RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"B,S,H,D,n_rot",
|
|
||||||
[
|
|
||||||
(2, 16, 4, 64, 32), # half rotary (Llama-style 0.5)
|
|
||||||
(2, 16, 4, 64, 16), # quarter rotary
|
|
||||||
(2, 32, 8, 128, 64), # half rotary, larger heads
|
|
||||||
(1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial
|
|
||||||
(1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path
|
|
||||||
],
|
|
||||||
ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
|
|
||||||
)
|
|
||||||
def test_forward_matches_reference(self, B, S, H, D, n_rot):
|
|
||||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
|
||||||
|
|
||||||
eps = 1e-6
|
|
||||||
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
|
||||||
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
|
|
||||||
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
|
|
||||||
|
|
||||||
assert y_fused.shape == y_ref.shape == (B, S, H, D)
|
|
||||||
cos_sim = torch.nn.functional.cosine_similarity(
|
|
||||||
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
|
||||||
)
|
|
||||||
assert cos_sim > 0.999, (
|
|
||||||
f"partial rotary forward cosine_sim={cos_sim:.6f} "
|
|
||||||
f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# The pass-through tail must equal the reference RMSNorm output bit-
|
|
||||||
# for-bit (any deviation would mean the kernel is touching it with a
|
|
||||||
# spurious rotation, which is the original bug class).
|
|
||||||
torch.testing.assert_close(
|
|
||||||
y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"B,S,H,D,n_rot",
|
|
||||||
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
|
|
||||||
ids=["half_64", "quarter_256"],
|
|
||||||
)
|
|
||||||
def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
|
|
||||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
|
||||||
|
|
||||||
eps = 1e-6
|
|
||||||
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
|
||||||
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
|
||||||
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
# Reference backward via the original primitives
|
|
||||||
x_ref = x_data.clone().requires_grad_(True)
|
|
||||||
w_ref = weight_init.clone()
|
|
||||||
y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps)
|
|
||||||
y_ref.sum().backward()
|
|
||||||
|
|
||||||
# Fused backward
|
|
||||||
x_fused = x_data.clone().requires_grad_(True)
|
|
||||||
w_fused = weight_init.clone().requires_grad_(True)
|
|
||||||
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
|
|
||||||
y_fused.sum().backward()
|
|
||||||
|
|
||||||
cos_sim_x = torch.nn.functional.cosine_similarity(
|
|
||||||
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
|
|
||||||
)
|
|
||||||
assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"B,S,H,D,n_rot",
|
|
||||||
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
|
|
||||||
ids=["half_64", "quarter_256"],
|
|
||||||
)
|
|
||||||
def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
|
|
||||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
|
|
||||||
|
|
||||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
|
||||||
|
|
||||||
eps = 1e-6
|
|
||||||
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
|
||||||
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
|
|
||||||
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
# Reference: Gemma4RMSNorm whose .weight collects grads, then partial
|
|
||||||
# rotary applied to the rotated slice.
|
|
||||||
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
|
|
||||||
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
|
|
||||||
normed = norm_ref(x_data)
|
|
||||||
from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb
|
|
||||||
|
|
||||||
rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
|
|
||||||
y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1)
|
|
||||||
y_ref.sum().backward()
|
|
||||||
|
|
||||||
w_fused = weight_init.clone().requires_grad_(True)
|
|
||||||
fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()
|
|
||||||
|
|
||||||
cos_sim_w = torch.nn.functional.cosine_similarity(
|
|
||||||
w_fused.grad.flatten().float(),
|
|
||||||
norm_ref.weight.grad.flatten().float(),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
assert cos_sim_w > 0.995, (
|
|
||||||
f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_full_rotary_unchanged_when_n_rot_equals_d(self):
|
|
||||||
"""Regression: passing cos/sin with shape == head_dim must still
|
|
||||||
match the full-rotary reference (the partial-rotary code path must
|
|
||||||
not perturb the existing full-rotary output)."""
|
|
||||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
|
||||||
|
|
||||||
B, S, H, D = 2, 16, 4, 64
|
|
||||||
eps = 1e-6
|
|
||||||
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
|
|
||||||
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
|
|
||||||
cos_sim = torch.nn.functional.cosine_similarity(
|
|
||||||
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
|
|
||||||
)
|
|
||||||
assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"
|
|
||||||
|
|
||||||
def test_validation_errors(self):
|
|
||||||
"""Wrapper rejects misshaped inputs cleanly (instead of a cryptic
|
|
||||||
Triton crash deeper in the kernel)."""
|
|
||||||
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
|
|
||||||
|
|
||||||
B, S, H, D = 1, 4, 2, 64
|
|
||||||
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
w = torch.randn(D, device="cuda", dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
# n_rot > head_dim
|
|
||||||
cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
|
|
||||||
sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
|
|
||||||
with pytest.raises(ValueError, match="cannot exceed head_dim"):
|
|
||||||
fused_rms_norm_rope(x, w, cos_big, sin_big)
|
|
||||||
|
|
||||||
# cos/sin last-dim mismatch
|
|
||||||
cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
|
|
||||||
sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
|
|
||||||
with pytest.raises(ValueError, match="same last dim"):
|
|
||||||
fused_rms_norm_rope(x, w, cos, sin)
|
|
||||||
|
|
||||||
# odd rotary dim
|
|
||||||
cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
|
|
||||||
sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
|
|
||||||
with pytest.raises(ValueError, match="must be even"):
|
|
||||||
fused_rms_norm_rope(x, w, cos_odd, sin_odd)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFusedRMSNormNoScale:
|
class TestFusedRMSNormNoScale:
|
||||||
"""Tests for v_norm (RMSNorm without learnable scale)."""
|
"""Tests for v_norm (RMSNorm without learnable scale)."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,219 +0,0 @@
|
|||||||
"""Tests for the Gemma 4 fused-attention monkey-patch.
|
|
||||||
|
|
||||||
These tests exercise the patched ``Gemma4TextAttention.forward`` against
|
|
||||||
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
|
|
||||||
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
|
|
||||||
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
|
|
||||||
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
|
|
||||||
exercised end-to-end (this is the bug originally documented in
|
|
||||||
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
|
|
||||||
|
|
||||||
The full-model forward also pins that the fused forward keeps accepting
|
|
||||||
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
|
|
||||||
installed transformers version — so any future signature drift on
|
|
||||||
upstream's side trips a clear failure here instead of a confusing
|
|
||||||
TypeError deep in a training run.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
pytestmark = [
|
|
||||||
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
|
|
||||||
]
|
|
||||||
|
|
||||||
pytest.importorskip(
|
|
||||||
"transformers.models.gemma4",
|
|
||||||
reason="fused_attn patch only matters when Gemma 4 is available",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def restore_gemma4_attention():
|
|
||||||
"""Snapshot ``Gemma4TextAttention.forward`` and restore after the test
|
|
||||||
so the monkey-patch does not leak across the suite."""
|
|
||||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
|
||||||
|
|
||||||
saved = Gemma4TextAttention.forward
|
|
||||||
yield Gemma4TextAttention
|
|
||||||
Gemma4TextAttention.forward = saved
|
|
||||||
|
|
||||||
|
|
||||||
def _build_hybrid_config():
|
|
||||||
"""Tiny hybrid Gemma 4 config: one sliding layer + one full-attention
|
|
||||||
layer with proportional rope and partial_rotary_factor=0.25. This is
|
|
||||||
the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small
|
|
||||||
enough to fit on any GPU."""
|
|
||||||
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
|
|
||||||
|
|
||||||
cfg = Gemma4TextConfig(
|
|
||||||
vocab_size=128,
|
|
||||||
hidden_size=64,
|
|
||||||
intermediate_size=128,
|
|
||||||
num_hidden_layers=2,
|
|
||||||
num_attention_heads=2,
|
|
||||||
num_key_value_heads=2,
|
|
||||||
head_dim=32,
|
|
||||||
global_head_dim=64,
|
|
||||||
layer_types=["sliding_attention", "full_attention"],
|
|
||||||
sliding_window=64,
|
|
||||||
max_position_embeddings=2048,
|
|
||||||
hidden_size_per_layer_input=16,
|
|
||||||
vocab_size_per_layer_input=128,
|
|
||||||
rope_parameters={
|
|
||||||
"sliding_attention": {
|
|
||||||
"rope_type": "default",
|
|
||||||
"rope_theta": 10000.0,
|
|
||||||
},
|
|
||||||
"full_attention": {
|
|
||||||
"rope_type": "proportional",
|
|
||||||
"rope_theta": 1000000.0,
|
|
||||||
"partial_rotary_factor": 0.25,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
cfg._attn_implementation = "sdpa"
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def _build_model(seed=0):
|
|
||||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
cfg = _build_hybrid_config()
|
|
||||||
return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval()
|
|
||||||
|
|
||||||
|
|
||||||
class TestFusedAttnSignature:
|
|
||||||
"""The fused forward must accept the same call shape as
|
|
||||||
``Gemma4TextDecoderLayer`` produces in the installed transformers
|
|
||||||
version. Any signature drift surfaces here as a TypeError."""
|
|
||||||
|
|
||||||
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
|
|
||||||
"""Run a model forward that exercises the real
|
|
||||||
``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
|
|
||||||
the fused patch installed."""
|
|
||||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
|
||||||
patch_gemma4_fused_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
m = _build_model()
|
|
||||||
ids = torch.randint(0, 128, (2, 16), device="cuda")
|
|
||||||
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
|
|
||||||
|
|
||||||
patch_gemma4_fused_attn()
|
|
||||||
with torch.no_grad():
|
|
||||||
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
|
|
||||||
|
|
||||||
assert out.shape == (2, 16, 64)
|
|
||||||
assert torch.isfinite(out).all()
|
|
||||||
|
|
||||||
|
|
||||||
class TestFusedAttnPerLayerCorrectness:
|
|
||||||
"""Compare the patched attention layer to the stock implementation
|
|
||||||
on a single forward call. This isolates the fused kernel correctness
|
|
||||||
from cross-layer numerical drift."""
|
|
||||||
|
|
||||||
def _run_attention(self, model, layer_idx, hidden_states, position_ids):
|
|
||||||
"""Call ``Gemma4TextAttention.forward`` (whatever is currently
|
|
||||||
installed) for one layer and return the output."""
|
|
||||||
attn = model.layers[layer_idx].self_attn
|
|
||||||
layer_type = model.config.layer_types[layer_idx]
|
|
||||||
cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type)
|
|
||||||
out, _ = attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
position_embeddings=(cos, sin),
|
|
||||||
attention_mask=None,
|
|
||||||
shared_kv_states={},
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"layer_idx",
|
|
||||||
[0, 1],
|
|
||||||
ids=["sliding_head32", "global_head64_proportional"],
|
|
||||||
)
|
|
||||||
def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx):
|
|
||||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
|
||||||
patch_gemma4_fused_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
m = _build_model(seed=1)
|
|
||||||
hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
|
|
||||||
pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
ref = self._run_attention(m, layer_idx, hs, pos)
|
|
||||||
|
|
||||||
patch_gemma4_fused_attn()
|
|
||||||
with torch.no_grad():
|
|
||||||
got = self._run_attention(m, layer_idx, hs, pos)
|
|
||||||
|
|
||||||
assert got.shape == ref.shape
|
|
||||||
assert torch.isfinite(got).all()
|
|
||||||
cos_sim = torch.nn.functional.cosine_similarity(
|
|
||||||
ref.flatten().float(), got.flatten().float(), dim=0
|
|
||||||
)
|
|
||||||
assert cos_sim > 0.999, (
|
|
||||||
f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}"
|
|
||||||
)
|
|
||||||
# bf16 precision: a few millis of absolute drift per element is
|
|
||||||
# acceptable for a Q/K/V projection pipeline. Anything larger is
|
|
||||||
# a real bug.
|
|
||||||
torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFusedAttnFullModel:
|
|
||||||
"""End-to-end model forward + backward through both layer types."""
|
|
||||||
|
|
||||||
def test_full_forward_matches_stock(self, restore_gemma4_attention):
|
|
||||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
|
||||||
patch_gemma4_fused_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
m = _build_model(seed=2)
|
|
||||||
ids = torch.randint(0, 128, (2, 32), device="cuda")
|
|
||||||
mask = torch.ones(2, 32, dtype=torch.long, device="cuda")
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
|
|
||||||
|
|
||||||
patch_gemma4_fused_attn()
|
|
||||||
with torch.no_grad():
|
|
||||||
got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
|
|
||||||
|
|
||||||
assert got.shape == ref.shape
|
|
||||||
assert torch.isfinite(got).all()
|
|
||||||
cos_sim = torch.nn.functional.cosine_similarity(
|
|
||||||
ref.flatten().float(), got.flatten().float(), dim=0
|
|
||||||
)
|
|
||||||
# End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16
|
|
||||||
# accumulates a small amount of numerical drift; we just want to
|
|
||||||
# pin that the two paths are computing the same function.
|
|
||||||
assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}"
|
|
||||||
|
|
||||||
def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention):
|
|
||||||
"""Gradients must propagate through the fused RMSNorm+RoPE kernels
|
|
||||||
for both the sliding and proportional-rope layers."""
|
|
||||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
|
||||||
patch_gemma4_fused_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
m = _build_model(seed=3).train()
|
|
||||||
patch_gemma4_fused_attn()
|
|
||||||
|
|
||||||
ids = torch.randint(0, 128, (2, 16), device="cuda")
|
|
||||||
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
|
|
||||||
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
|
|
||||||
out.sum().backward()
|
|
||||||
|
|
||||||
# Both layers must accumulate gradients on q_norm.weight and
|
|
||||||
# k_norm.weight — that proves the fused kernel ran the backward.
|
|
||||||
for i, layer in enumerate(m.layers[:2]):
|
|
||||||
attn = layer.self_attn
|
|
||||||
assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad"
|
|
||||||
assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad"
|
|
||||||
assert attn.q_norm.weight.grad.isfinite().all()
|
|
||||||
assert attn.k_norm.weight.grad.isfinite().all()
|
|
||||||
assert attn.q_norm.weight.grad.abs().sum() > 0
|
|
||||||
assert attn.k_norm.weight.grad.abs().sum() > 0
|
|
||||||
@@ -1,343 +0,0 @@
|
|||||||
"""Tests for the Gemma 4 hybrid-attention mask fix.
|
|
||||||
|
|
||||||
These tests pin the single critical behavior: after installing the patch,
|
|
||||||
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
|
|
||||||
the underlying mask builder regardless of what the caller's config says.
|
|
||||||
This is what keeps full-attention (head_dim=512) global layers from
|
|
||||||
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
|
|
||||||
the 2D FA2 mask that would be built from the model-level config.
|
|
||||||
|
|
||||||
The tests use a mocked ``create_causal_mask`` so they don't have to load
|
|
||||||
a real 26B Gemma 4 model or even have access to its weights. What matters
|
|
||||||
for the bug fix is which config is handed to the mask factory, not the
|
|
||||||
factory's actual output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
pytest.importorskip(
|
|
||||||
"transformers.models.gemma4",
|
|
||||||
reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def restore_gemma4_module():
|
|
||||||
"""Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
|
|
||||||
each test so patch state doesn't leak across the suite."""
|
|
||||||
from transformers.models.gemma4 import modeling_gemma4
|
|
||||||
|
|
||||||
saved = modeling_gemma4.create_causal_mask
|
|
||||||
yield modeling_gemma4
|
|
||||||
modeling_gemma4.create_causal_mask = saved
|
|
||||||
# Reset the module-level flag so the next test can re-install cleanly.
|
|
||||||
from axolotl.monkeypatch import gemma4_hybrid_mask
|
|
||||||
|
|
||||||
gemma4_hybrid_mask._PATCH_APPLIED = False
|
|
||||||
|
|
||||||
|
|
||||||
def test_patch_replaces_create_causal_mask(restore_gemma4_module):
|
|
||||||
modeling_gemma4 = restore_gemma4_module
|
|
||||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
|
||||||
|
|
||||||
original = modeling_gemma4.create_causal_mask
|
|
||||||
assert patch_gemma4_hybrid_mask() is True
|
|
||||||
|
|
||||||
assert modeling_gemma4.create_causal_mask is not original
|
|
||||||
assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
|
|
||||||
"patched wrapper must expose the original reference for teardown"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_patch_is_idempotent(restore_gemma4_module):
|
|
||||||
modeling_gemma4 = restore_gemma4_module
|
|
||||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
|
||||||
|
|
||||||
patch_gemma4_hybrid_mask()
|
|
||||||
wrapper_first = modeling_gemma4.create_causal_mask
|
|
||||||
|
|
||||||
# Second call must not re-wrap the already-wrapped function (which
|
|
||||||
# would leak the original reference through a chain of wrappers).
|
|
||||||
patch_gemma4_hybrid_mask()
|
|
||||||
wrapper_second = modeling_gemma4.create_causal_mask
|
|
||||||
|
|
||||||
assert wrapper_first is wrapper_second
|
|
||||||
|
|
||||||
|
|
||||||
def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
|
|
||||||
"""Core invariant: when the patched wrapper is called with a config
|
|
||||||
that says ``flash_attention_2``, the underlying mask factory receives
|
|
||||||
a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.
|
|
||||||
|
|
||||||
Without this, the full-attention global layers get a 2D FA2 mask and
|
|
||||||
crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
|
|
||||||
"""
|
|
||||||
modeling_gemma4 = restore_gemma4_module
|
|
||||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
|
||||||
|
|
||||||
# Swap in a mock BEFORE installing the patch so the wrapper captures
|
|
||||||
# it as the "original". The mock records every call so we can inspect
|
|
||||||
# what config got passed through.
|
|
||||||
mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
|
|
||||||
modeling_gemma4.create_causal_mask = mock_factory
|
|
||||||
patch_gemma4_hybrid_mask()
|
|
||||||
|
|
||||||
# Caller-supplied config says FA2 (that's the model-level setting).
|
|
||||||
caller_config = SimpleNamespace(
|
|
||||||
_attn_implementation="flash_attention_2",
|
|
||||||
head_dim=512,
|
|
||||||
some_other_attr="preserved",
|
|
||||||
)
|
|
||||||
result = modeling_gemma4.create_causal_mask(
|
|
||||||
caller_config,
|
|
||||||
inputs_embeds=None,
|
|
||||||
attention_mask=None,
|
|
||||||
past_key_values=None,
|
|
||||||
position_ids=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrapper returned whatever the mock returned — no transformation of
|
|
||||||
# the result itself.
|
|
||||||
assert result == "mask_4d"
|
|
||||||
|
|
||||||
# The mock was called exactly once with a config whose
|
|
||||||
# ``_attn_implementation`` is sdpa, NOT the caller's fa2.
|
|
||||||
assert mock_factory.call_count == 1
|
|
||||||
(passed_config, *_), passed_kwargs = mock_factory.call_args
|
|
||||||
assert passed_config._attn_implementation == "sdpa"
|
|
||||||
|
|
||||||
# The wrapper must NOT mutate the caller's config in place — other
|
|
||||||
# mask builders (e.g. create_sliding_window_causal_mask) read from
|
|
||||||
# the same config and must still see fa2.
|
|
||||||
assert caller_config._attn_implementation == "flash_attention_2"
|
|
||||||
|
|
||||||
# Other attributes on the config must be preserved so the underlying
|
|
||||||
# factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
|
|
||||||
assert passed_config.head_dim == 512
|
|
||||||
assert passed_config.some_other_attr == "preserved"
|
|
||||||
|
|
||||||
|
|
||||||
def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
|
|
||||||
"""The wrapper must forward positional + keyword args to the original
|
|
||||||
unchanged, so transformers' own call-site in Gemma4TextModel.forward
|
|
||||||
keeps working across minor transformers-version signature drift."""
|
|
||||||
modeling_gemma4 = restore_gemma4_module
|
|
||||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
|
||||||
|
|
||||||
mock_factory = MagicMock(return_value="mask")
|
|
||||||
modeling_gemma4.create_causal_mask = mock_factory
|
|
||||||
patch_gemma4_hybrid_mask()
|
|
||||||
|
|
||||||
caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
|
|
||||||
modeling_gemma4.create_causal_mask(
|
|
||||||
caller_config,
|
|
||||||
"positional_arg",
|
|
||||||
inputs_embeds="embeds",
|
|
||||||
attention_mask="mask_2d",
|
|
||||||
past_key_values="cache",
|
|
||||||
position_ids="positions",
|
|
||||||
or_mask_function="or_fn",
|
|
||||||
)
|
|
||||||
|
|
||||||
args, kwargs = mock_factory.call_args
|
|
||||||
# First positional (after config override) is preserved.
|
|
||||||
assert args[1] == "positional_arg"
|
|
||||||
# All kwargs are forwarded untouched.
|
|
||||||
assert kwargs["inputs_embeds"] == "embeds"
|
|
||||||
assert kwargs["attention_mask"] == "mask_2d"
|
|
||||||
assert kwargs["past_key_values"] == "cache"
|
|
||||||
assert kwargs["position_ids"] == "positions"
|
|
||||||
assert kwargs["or_mask_function"] == "or_fn"
|
|
||||||
|
|
||||||
|
|
||||||
def test_unpatch_restores_original(restore_gemma4_module):
|
|
||||||
modeling_gemma4 = restore_gemma4_module
|
|
||||||
from axolotl.monkeypatch.gemma4_hybrid_mask import (
|
|
||||||
patch_gemma4_hybrid_mask,
|
|
||||||
unpatch_gemma4_hybrid_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
sentinel = MagicMock(name="original")
|
|
||||||
modeling_gemma4.create_causal_mask = sentinel
|
|
||||||
patch_gemma4_hybrid_mask()
|
|
||||||
assert modeling_gemma4.create_causal_mask is not sentinel
|
|
||||||
|
|
||||||
unpatch_gemma4_hybrid_mask()
|
|
||||||
assert modeling_gemma4.create_causal_mask is sentinel
|
|
||||||
|
|
||||||
|
|
||||||
def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
|
|
||||||
from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask
|
|
||||||
|
|
||||||
# Should be a no-op, no exception.
|
|
||||||
unpatch_gemma4_hybrid_mask()
|
|
||||||
|
|
||||||
|
|
||||||
def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
    """Only ``create_causal_mask`` is overridden — the sliding-window
    factory must remain bound to its original to preserve FA2 masks for
    the sliding-attention layers. If we accidentally patch both, the
    sliding layers get SDPA format and lose the FA2 speedup."""
    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    module = restore_gemma4_module
    if not hasattr(module, "create_sliding_window_causal_mask"):
        pytest.skip("transformers version has no create_sliding_window_causal_mask")

    builder_before = module.create_sliding_window_causal_mask
    patch_gemma4_hybrid_mask()
    # Identity check: the attribute must still be the very same object.
    assert module.create_sliding_window_causal_mask is builder_before
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
#
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
# model with 2 layers (one sliding, one full_attention), apply the hybrid
# attention path end-to-end, and run a forward pass with a padded
# attention_mask at a long-ish seq len. The invariant we're pinning is that
# the full_attention layer does not crash with the
#   "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
# tokens in the FSDP2 training run.
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _build_tiny_gemma4_text_model():
    """Build a tiny, randomly-initialized Gemma4TextModel with mixed layer
    types (one sliding, one full_attention).

    Returns:
        Tuple of ``(model, cfg)`` — the eval-mode model and its config.
    """
    import torch
    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel

    config = Gemma4TextConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        head_dim=32,
        layer_types=["sliding_attention", "full_attention"],
        sliding_window=64,
        max_position_embeddings=2048,
        hidden_size_per_layer_input=16,
        vocab_size_per_layer_input=128,
    )
    # Caller-supplied attn impl simulates the pilot config (fa2 at model
    # level). The hybrid patch is what makes this survive long context.
    config._attn_implementation = "sdpa"  # start safe; the test toggles fa2 later

    torch.manual_seed(42)
    tiny_model = Gemma4TextModel(config).eval()
    return tiny_model, config
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_hybrid_attn_inline(model, cfg):
    """Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
    to a model, without needing a full PatchManager / pydantic cfg."""
    import copy

    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    for idx, layer in enumerate(model.layers):
        # Sliding-attention layers keep their original attention path.
        if cfg.layer_types[idx] == "sliding_attention":
            continue
        attn = getattr(layer, "self_attn", None)
        if attn is None or not hasattr(attn, "config"):
            continue
        # Shallow-copy so the shared model config object is not mutated.
        per_layer_cfg = copy.copy(attn.config)
        per_layer_cfg._attn_implementation = "sdpa"
        attn.config = per_layer_cfg

    patch_gemma4_hybrid_mask()
|
|
||||||
|
|
||||||
|
|
||||||
def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
    """End-to-end invariant: with the hybrid attn patch applied, a tiny
    Gemma4TextModel runs a forward at long context (1024 tokens) with
    real padding in the attention mask, producing the expected output
    shape. This exercises the actual code path that crashed the pilot
    without needing a real 26B checkpoint or CUDA."""
    import torch

    model, cfg = _build_tiny_gemma4_text_model()
    _apply_hybrid_attn_inline(model, cfg)

    batch, seq_len = 2, 1024
    input_ids = torch.randint(0, cfg.vocab_size, (batch, seq_len))
    attn_mask = torch.ones(batch, seq_len, dtype=torch.long)
    # Pad positions in the second row. Without padding, SDPA falls back to
    # ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
    # mask to exercise the actual bug site.
    attn_mask[1, seq_len // 2 :] = 0

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn_mask)

    assert out.last_hidden_state.shape == (batch, seq_len, cfg.hidden_size)
    assert torch.isfinite(out.last_hidden_state).all()
|
|
||||||
|
|
||||||
|
|
||||||
def test_patched_create_causal_mask_returns_4d_for_real_config(
    restore_gemma4_module,
):
    """Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
    and verify the returned mask is a 4D tensor — which is the shape the
    SDPA-patched global layers need. Without the patch and with a
    caller-supplied FA2 config this would return a 2D mask and the layer
    would crash at long context."""
    import torch
    from transformers.cache_utils import DynamicCache
    from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig

    from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

    patch_gemma4_hybrid_mask()
    modeling_gemma4 = restore_gemma4_module

    cfg = Gemma4TextConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        head_dim=32,
        layer_types=["sliding_attention", "full_attention"],
        sliding_window=64,
        max_position_embeddings=2048,
        hidden_size_per_layer_input=16,
        vocab_size_per_layer_input=128,
    )
    # Simulate the pilot: caller says flash_attention_2, but global layers
    # were switched to SDPA per-layer. Without the patch, create_causal_mask
    # would return an FA2 2D mask here and the SDPA layer would crash.
    cfg._attn_implementation = "flash_attention_2"

    batch, seq_len = 2, 1024
    inputs_embeds = torch.zeros((batch, seq_len, cfg.hidden_size), dtype=torch.float32)
    attention_mask = torch.ones((batch, seq_len), dtype=torch.long)
    attention_mask[1, seq_len // 2 :] = 0  # force the 4D materialized path
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

    mask = modeling_gemma4.create_causal_mask(
        config=cfg,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        past_key_values=DynamicCache(config=cfg),
        position_ids=position_ids,
    )

    assert mask is not None
    assert isinstance(mask, torch.Tensor)
    assert mask.dim() == 4, (
        f"expected a 4D SDPA-format mask, got {mask.dim()}D "
        f"shape={tuple(mask.shape)}. The full_attention global layers need "
        "this shape or they crash at long context."
    )
    assert mask.shape[0] == batch
    assert mask.shape[-1] == seq_len
    assert mask.shape[-2] == seq_len

    # Caller's config must be untouched — other code paths still read it.
    assert cfg._attn_implementation == "flash_attention_2"
|
|
||||||
@@ -487,70 +487,3 @@ class TestDatasetPreparation:
|
|||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
@enable_hf_offline
|
|
||||||
def test_load_dataset_with_str_json_data(self, tokenizer):
|
|
||||||
"""
|
|
||||||
Test loading datasets where data is stored as str JSON instead of list of dicts.
|
|
||||||
see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details.
|
|
||||||
"""
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
import json
|
|
||||||
|
|
||||||
str_json_ds = Dataset.from_list(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"messages": json.dumps(
|
|
||||||
[
|
|
||||||
{"role": "user", "content": "Hello how are you?"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "I am doing good thanks",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"messages": json.dumps(
|
|
||||||
[
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "2+2 equals 4."},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet"
|
|
||||||
str_json_ds.to_parquet(tmp_ds_path)
|
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"sequence_len": 512,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": str(tmp_ds_path),
|
|
||||||
"name": "test_str_json",
|
|
||||||
"type": "chat_template",
|
|
||||||
"field_messages": "messages",
|
|
||||||
"message_field_role": "role",
|
|
||||||
"message_field_content": "content",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"dataset_num_proc": 4,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
|
||||||
):
|
|
||||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
|
||||||
|
|
||||||
assert len(dataset) == 2
|
|
||||||
assert "input_ids" in dataset.features
|
|
||||||
assert "attention_mask" in dataset.features
|
|
||||||
assert "labels" in dataset.features
|
|
||||||
|
|
||||||
assert len(dataset[0]["input_ids"]) > 0
|
|
||||||
|
|||||||
@@ -133,108 +133,3 @@ class TestRevisionParameter:
|
|||||||
|
|
||||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||||
assert "revision" not in call_kwargs.kwargs
|
assert "revision" not in call_kwargs.kwargs
|
||||||
|
|
||||||
@patch("axolotl.loaders.processor.AutoProcessor")
|
|
||||||
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
|
|
||||||
mock_processor = MagicMock()
|
|
||||||
mock_processor.size = {}
|
|
||||||
mock_auto_processor.from_pretrained.return_value = mock_processor
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"processor_config": "some-model",
|
|
||||||
"trust_remote_code": False,
|
|
||||||
"processor_kwargs": {
|
|
||||||
"image_seq_length": 1120,
|
|
||||||
"max_soft_tokens": 1120,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
|
||||||
|
|
||||||
from axolotl.loaders.processor import load_processor
|
|
||||||
|
|
||||||
load_processor(cfg, tokenizer)
|
|
||||||
|
|
||||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
|
||||||
assert call_kwargs.kwargs.get("image_seq_length") == 1120
|
|
||||||
assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
|
|
||||||
|
|
||||||
@patch("axolotl.loaders.processor.AutoProcessor")
|
|
||||||
def test_load_processor_omits_processor_kwargs_when_unset(
|
|
||||||
self, mock_auto_processor
|
|
||||||
):
|
|
||||||
mock_processor = MagicMock()
|
|
||||||
mock_processor.size = {}
|
|
||||||
mock_auto_processor.from_pretrained.return_value = mock_processor
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"processor_config": "some-model",
|
|
||||||
"trust_remote_code": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
|
||||||
|
|
||||||
from axolotl.loaders.processor import load_processor
|
|
||||||
|
|
||||||
load_processor(cfg, tokenizer)
|
|
||||||
|
|
||||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
|
||||||
assert "image_seq_length" not in call_kwargs.kwargs
|
|
||||||
assert "max_soft_tokens" not in call_kwargs.kwargs
|
|
||||||
|
|
||||||
def test_processor_kwargs_schema_rejects_revision(self):
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.schemas.model import ModelInputConfig
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="revision"):
|
|
||||||
ModelInputConfig(
|
|
||||||
base_model="some-model",
|
|
||||||
processor_kwargs={"revision": "abc123"},
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_processor_kwargs_schema_rejects_trust_remote_code(self):
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.schemas.model import ModelInputConfig
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="trust_remote_code"):
|
|
||||||
ModelInputConfig(
|
|
||||||
base_model="some-model",
|
|
||||||
processor_kwargs={"trust_remote_code": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_processor_kwargs_schema_accepts_valid_keys(self):
|
|
||||||
from axolotl.utils.schemas.model import ModelInputConfig
|
|
||||||
|
|
||||||
cfg = ModelInputConfig(
|
|
||||||
base_model="some-model",
|
|
||||||
processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
|
|
||||||
)
|
|
||||||
assert cfg.processor_kwargs == {
|
|
||||||
"image_seq_length": 1120,
|
|
||||||
"max_soft_tokens": 1120,
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_processor_kwargs_schema_accepts_none_and_empty(self):
|
|
||||||
from axolotl.utils.schemas.model import ModelInputConfig
|
|
||||||
|
|
||||||
assert ModelInputConfig(base_model="x").processor_kwargs is None
|
|
||||||
assert (
|
|
||||||
ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
cfg = min_base_cfg | DictDefault(
|
|
||||||
tokenizer_use_mistral_common=True,
|
|
||||||
processor_kwargs={"image_seq_length": 1120},
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError, match="processor_kwargs"):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ Covers:
|
|||||||
- save_strategy: 'best' requires metric_for_best_model
|
- save_strategy: 'best' requires metric_for_best_model
|
||||||
- streaming=True with val_set_size > 0 is rejected
|
- streaming=True with val_set_size > 0 is rejected
|
||||||
- lora_target_modules with invalid regex patterns is rejected
|
- lora_target_modules with invalid regex patterns is rejected
|
||||||
- GRPO: generation batch size must be divisible by num_generations,
|
|
||||||
num_generations >= 2, and effective_gbs >= num_generations * world_size
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -119,136 +117,3 @@ class TestLoraTargetModulesRegexValidator:
|
|||||||
)
|
)
|
||||||
with pytest.raises(ValueError, match="invalid regex pattern"):
|
with pytest.raises(ValueError, match="invalid regex pattern"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
class TestGRPOBatchSizeValidator:
    """GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.

    These call the @model_validator(mode="before") classmethod directly on a
    plain dict — same input shape it receives during full Pydantic validation,
    just without dragging in unrelated fields (datasets / model loading / etc.)
    that aren't relevant to what's under test. The validator is registered on
    ``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
    same code path ``axolotl train`` exercises.
    """

    @staticmethod
    def _check(data):
        # Invoke the real validator classmethod on a raw dict.
        from axolotl.utils.schemas.validation import RLValidationMixin

        return RLValidationMixin.check_grpo_batch_size_divisibility(data)

    @staticmethod
    def _cfg(gas, **extra):
        """Build the minimal before-validation dict the validator receives."""
        data = {"micro_batch_size": 1, "gradient_accumulation_steps": gas}
        data.update(extra)
        return data

    def test_divisible_passes(self):
        # gbs = 1*4 = 4, divisible by num_generations=4 — returned unchanged.
        data = self._cfg(4, rl="grpo", trl={"num_generations": 4})
        out = self._check(data)
        assert out["trl"]["num_generations"] == 4

    def test_non_divisible_raises(self):
        # gbs = 1*2 = 2 is not a multiple of num_generations=4.
        data = self._cfg(2, rl="grpo", trl={"num_generations": 4})
        with pytest.raises(ValueError, match="num_generations"):
            self._check(data)

    def test_non_divisible_error_includes_fix_hint(self):
        # The error text suggests the nearest workable GA value.
        data = self._cfg(3, rl="grpo", trl={"num_generations": 4})
        with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
            self._check(data)

    def test_num_generations_one_raises(self):
        data = self._cfg(4, rl="grpo", trl={"num_generations": 1})
        with pytest.raises(ValueError, match=r"num_generations >= 2"):
            self._check(data)

    def test_explicit_generation_batch_size_divisible_passes(self):
        data = self._cfg(
            1, rl="grpo", trl={"num_generations": 4, "generation_batch_size": 8}
        )
        out = self._check(data)
        assert out["trl"]["generation_batch_size"] == 8

    def test_explicit_generation_batch_size_non_divisible_raises(self):
        data = self._cfg(
            1, rl="grpo", trl={"num_generations": 4, "generation_batch_size": 6}
        )
        with pytest.raises(ValueError, match="trl.generation_batch_size"):
            self._check(data)

    def test_non_grpo_skips_check(self):
        # Anything other than rl=grpo should pass through untouched, even
        # with non-divisible batch sizes — they're irrelevant to other RL
        # methods that don't use group-relative advantages.
        data = self._cfg(3, rl="dpo", trl={"num_generations": 4})
        assert self._check(data) is data

    def test_no_rl_set_skips_check(self):
        data = self._cfg(3)
        assert self._check(data) is data

    def test_grpo_without_num_generations_skips_check(self):
        # If num_generations isn't set, TRL uses its own default — we don't
        # have enough info to validate, so the validator must short-circuit
        # rather than guess.
        data = self._cfg(3, rl="grpo", trl={})
        out = self._check(data)
        assert out["rl"] == "grpo"

    def test_multi_rank_group_size_check(self):
        # gbs=4 but world_size=2 requires gbs >= num_generations*world_size = 8.
        data = self._cfg(4, rl="grpo", world_size=2, trl={"num_generations": 4})
        with pytest.raises(ValueError, match=r"world_size=2"):
            self._check(data)

    def test_multi_rank_group_size_satisfied(self):
        # gbs=8 >= 4*2 — passes with the config returned intact.
        data = self._cfg(8, rl="grpo", world_size=2, trl={"num_generations": 4})
        out = self._check(data)
        assert out["gradient_accumulation_steps"] == 8
|
|
||||||
|
|||||||
Reference in New Issue
Block a user