Compare commits
12 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 993db05b3a |  |
|  | 1b9520cc8b |  |
|  | f77408a3d0 |  |
|  | 5db4272f69 |  |
|  | 431888c1de |  |
|  | 1bf65c500e |  |
|  | bcbe049c21 |  |
|  | 90090fa9e8 |  |
|  | 7420fd4de6 |  |
|  | 05113bc91a |  |
|  | e562e149ce |  |
|  | 9de5b76336 |  |
5 .github/CONTRIBUTING.md vendored

@@ -31,7 +31,10 @@ PRs are **greatly welcome**!

 Please run below to setup env
 ```bash
-pip3 install -r requirements-dev.txt -r requirements-tests.txt
+# Install axolotl + dev and test dependencies from lockfile
+export UV_TORCH_BACKEND=cu128 # or cu130
+uv sync --extra flash-attn --extra deepspeed --group dev --group test
+source .venv/bin/activate
 pre-commit install

 # test
2 .github/workflows/lint.yml vendored

@@ -6,7 +6,7 @@ on:
 types: [opened, synchronize, reopened, ready_for_review]
 paths:
 - '**.py'
-- 'requirements.txt'
+- 'pyproject.toml'
 - '.github/workflows/*.yml'
 - "*.[q]md"
 - "examples/**/*.y[a]?ml"
35 .github/workflows/multi-gpu-e2e.yml vendored

@@ -3,17 +3,15 @@ name: docker-multigpu-tests-biweekly
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'tests/e2e/multigpu/**.py'
|
||||
- 'requirements.txt'
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'scripts/cutcrossentropy_install.py'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
- "tests/e2e/multigpu/**.py"
|
||||
- "pyproject.toml"
|
||||
- ".github/workflows/multi-gpu-e2e.yml"
|
||||
- "scripts/cutcrossentropy_install.py"
|
||||
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
|
||||
- "src/axolotl/utils/distributed.py"
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
concurrency:
|
||||
@@ -33,19 +31,19 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras: "fbgemm-gpu"
|
||||
# num_gpus: 2
|
||||
# dockerfile: "Dockerfile-uv.jinja"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras: "fbgemm-gpu"
|
||||
# num_gpus: 2
|
||||
# dockerfile: "Dockerfile-uv.jinja"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
@@ -53,7 +51,6 @@ jobs:
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
@@ -75,7 +72,7 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
13 .github/workflows/pypi.yml vendored

@@ -8,6 +8,9 @@ on:
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
UV_SYSTEM_PYTHON: "1"
|
||||
|
||||
jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
@@ -41,11 +44,15 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging==26.0
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
uv pip install wheel packaging
|
||||
uv pip install --no-build-isolation -e .
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
|
||||
55 .github/workflows/tests-nightly.yml vendored

@@ -2,15 +2,18 @@ name: Tests Nightly against upstream main
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '.github/workflows/tests-nightly.yml'
|
||||
- ".github/workflows/tests-nightly.yml"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
UV_SYSTEM_PYTHON: "1"
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
@@ -20,7 +23,7 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
cache: "pip" # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
@@ -43,7 +46,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.9.1", "2.10.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
@@ -61,36 +64,34 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||
|
||||
- name: Update requirements.txt
|
||||
run: |
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
||||
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||
python scripts/cutcrossentropy_install.py --uv | sh
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Override with nightly HF packages
|
||||
run: |
|
||||
uv pip install --no-deps \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||
"datasets @ git+https://github.com/huggingface/datasets.git@main"
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
@@ -102,9 +103,6 @@ jobs:
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
@@ -136,7 +134,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
nightly_build: "true"
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -157,7 +154,7 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
|
||||
85 .github/workflows/tests.yml vendored

@@ -6,21 +6,19 @@ on:
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
- 'cicd/Dockerfile.jinja'
|
||||
- "**.py"
|
||||
- "pyproject.toml"
|
||||
- ".github/workflows/*.yml"
|
||||
- "cicd/cicd.sh"
|
||||
- "cicd/Dockerfile-uv.jinja"
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
- 'cicd/Dockerfile.jinja'
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- "**.py"
|
||||
- "pyproject.toml"
|
||||
- ".github/workflows/*.yml"
|
||||
- "cicd/cicd.sh"
|
||||
- "cicd/Dockerfile-uv.jinja"
|
||||
workflow_dispatch:
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
@@ -33,6 +31,7 @@ permissions:
|
||||
|
||||
env:
|
||||
TRANSFORMERS_IS_CI: "yes"
|
||||
UV_SYSTEM_PYTHON: "1"
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
@@ -44,7 +43,7 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
cache: "pip" # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
@@ -94,32 +93,25 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-cache-dir --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||
python scripts/cutcrossentropy_install.py --uv | sh
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
@@ -188,33 +180,27 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
uv pip install packaging setuptools_scm build wheel psutil
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
|
||||
python scripts/cutcrossentropy_install.py --uv | sh
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
@@ -291,7 +277,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -312,7 +297,7 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
@@ -374,7 +359,7 @@ jobs:
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
 | Method | Config Key | When to Use |
 |--------|-----------|-------------|
 | SFT | *(default)* | Input-output pairs, instruction tuning |
-| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
+| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
 | KTO | `rl: kto` | Unpaired binary preference labels |
 | ORPO | `rl: orpo` | Single-stage alignment, no ref model |
 | GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
@@ -1,7 +1,6 @@
include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include VERSION
include src/axolotl/utils/chat_templates/templates/*.jinja
include AGENTS.md
recursive-include docs/agents *.md
26 README.md

@@ -95,14 +95,11 @@ Features:

### Installation

#### Using uv (recommended)

```bash
# install uv if you don't already have it installed
# install uv if you don't already have it installed (restart shell after)
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env

# CUDA 12.8.1 tends to have better package compatibility
# change depending on system
export UV_TORCH_BACKEND=cu128

# create a new virtual environment
@@ -112,23 +109,6 @@ source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]

# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"

# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"

# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```

#### Using pip

```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
@@ -138,7 +118,7 @@ axolotl fetch deepspeed_configs # OPTIONAL

Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
```

Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
@@ -134,7 +134,6 @@ quartodoc:
|
||||
- monkeypatch.stablelm_attn_hijack_flash
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
- monkeypatch.unsloth_
|
||||
- monkeypatch.data.batch_dataset_fetcher
|
||||
- monkeypatch.mixtral
|
||||
- monkeypatch.gradient_checkpointing.offload_cpu
|
||||
@@ -327,7 +326,6 @@ website:
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
- docs/fsdp_qlora.qmd
|
||||
- docs/unsloth.qmd
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
|
||||
@@ -22,15 +22,6 @@ WORKDIR /workspace/axolotl
|
||||
RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
||||
RUN uv pip install torchvision
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
@@ -40,11 +31,21 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py --uv | sh
|
||||
# Override with nightly HF packages for nightly builds
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
uv pip install --no-deps \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
|
||||
fi
|
||||
|
||||
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
RUN python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
@@ -1,7 +1,7 @@
 #!/bin/bash
 set -e

-python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
+python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"

 set -o pipefail
 for i in 1 2 3; do
@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
 template_env = jinja2.Environment(
     loader=template_loader, autoescape=select_autoescape()
 )
-dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
+dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
 df_template = template_env.get_template(dockerfile)

 df_args = {
@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
 template_env = jinja2.Environment(
     loader=template_loader, autoescape=select_autoescape()
 )
-dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
+dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
 df_template = template_env.get_template(dockerfile)

 df_args = {
@@ -32,7 +32,7 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \ python scripts/unsloth_install.py | sh && \
|
||||
fi && \
|
||||
python scripts/cutcrossentropy_install.py | sh && \
|
||||
pip install pytest && \
|
||||
pip cache purge
|
||||
|
||||
@@ -33,7 +33,6 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py --uv | sh && \
|
||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||
uv pip install pytest && \
|
||||
uv cache clean
|
||||
|
||||
@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da

 1. Paired preference data (chosen + rejected)?
    - Default → `rl: dpo`
-   - Overfitting → `rl: ipo`
+   - Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
    - VRAM-limited → `rl: orpo` (no ref model)
    - Length-sensitive → `rl: simpo` (no ref model)
 2. Only binary labels (good/bad)? → `rl: kto`
@@ -76,8 +76,9 @@ datasets:
 Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:

 ```bash
-pip3 install packaging
-pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
+export UV_TORCH_BACKEND=cu128 # or cu130
+uv sync --extra flash-attn --extra deepspeed --group dev --group test
+source .venv/bin/activate
 ```

 #### Remote Hosts
@@ -208,17 +209,17 @@ cd axolotl
 Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]

 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
 ```

 >[!Tip]
 > To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).

-You will now be in the container. Next, perform an editable install of Axolotl:
+You will now be in the container. Next, install Axolotl with dev dependencies:

 ```bash
-pip3 install packaging
-pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
+uv sync --extra flash-attn --extra deepspeed --group dev --group test
+source .venv/bin/activate
 ```

 ### Attach To Container
@@ -6,23 +6,30 @@ format:
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
This section describes the different Docker images that are released by AxolotlAI at
|
||||
[Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
::: {.callout-tip}
|
||||
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
|
||||
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
|
||||
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
|
||||
:::
|
||||
|
||||
## Base
|
||||
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
|
||||
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-base
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
|
||||
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
|
||||
|
||||
#### Tags format
|
||||
|
||||
@@ -32,8 +39,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.8.0`
|
||||
- `main-base-py3.11-cu128-2.9.1`
|
||||
- `main-base-py3.12-cu128-2.10.0`
|
||||
- `main-base-py3.12-cu130-2.9.1`
|
||||
- `main-base-py3.12-cu130-2.10.0`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -41,11 +50,10 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
|
||||
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
|
||||
|
||||
#### Tags format {#sec-main-tags}
|
||||
|
||||
@@ -53,7 +61,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
# on push to main
|
||||
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
|
||||
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
|
||||
main-latest
|
||||
|
||||
# nightly build
|
||||
@@ -71,11 +79,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.8.0`
|
||||
- `main-py3.11-cu128-2.9.1`
|
||||
- `main-py3.12-cu128-2.10.0`
|
||||
- `main-py3.12-cu130-2.9.1`
|
||||
- `main-py3.12-cu130-2.10.0`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `main-20260315-py3.11-cu128-2.9.1`
|
||||
- `0.12.0`
|
||||
|
||||
## Cloud
|
||||
@@ -90,11 +99,10 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-cloud
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
|
||||
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
|
||||
|
||||
#### Tags format
|
||||
|
||||
|
||||
@@ -15,64 +15,30 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
|
||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- PyTorch ≥2.9.0
|
||||
|
||||
## Installation Methods {#sec-installation-methods}
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
|
||||
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
## Installation {#sec-installation}
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
### Quick Install {#sec-uv}
|
||||
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
|
||||
|
||||
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
|
||||
installed) in order not to clobber it, and so that we set the correct version of
|
||||
dependencies that are specific to the PyTorch version or other installed
|
||||
co-dependencies.
|
||||
|
||||
### uv Installation {#sec-uv}
|
||||
|
||||
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
|
||||
|
||||
Install uv if not already installed
|
||||
Install uv if not already installed:
|
||||
```{.bash}
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
```
|
||||
|
||||
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
|
||||
then create the venv and activate
|
||||
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
|
||||
```{.bash}
|
||||
export UV_TORCH_BACKEND=cu126
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Install PyTorch
|
||||
- PyTorch 2.6.0 recommended
|
||||
```{.bash}
|
||||
uv pip install packaging setuptools wheel
|
||||
uv pip install torch==2.6.0
|
||||
uv pip install awscli pydantic
|
||||
```
|
||||
|
||||
Install axolotl from PyPi
|
||||
```{.bash}
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
|
||||
|
||||
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
|
||||
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
### Edge/Development Build {#sec-edge-build}
|
||||
@@ -82,14 +48,17 @@ For the latest features between releases:
|
||||
```{.bash}
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv sync --extra flash-attn --extra deepspeed
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
|
||||
|
||||
### Docker {#sec-docker}
|
||||
|
||||
```{.bash}
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
|
||||
```
|
||||
|
||||
For development with Docker:
|
||||
@@ -106,12 +75,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
--ulimit memlock=-1 --ulimit stack=67108864 \
|
||||
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
|
||||
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
|
||||
axolotlai/axolotl:main-latest
|
||||
axolotlai/axolotl-uv:main-latest
|
||||
```
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
@@ -122,7 +91,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
|
||||
|
||||
For providers supporting Docker:
|
||||
|
||||
- Use `axolotlai/axolotl-cloud:main-latest`
|
||||
- Use `axolotlai/axolotl-cloud-uv:main-latest`
|
||||
- Available on:
|
||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
||||
@@ -141,7 +110,7 @@ For providers supporting Docker:
|
||||
### macOS {#sec-macos}
|
||||
|
||||
```{.bash}
|
||||
pip3 install --no-build-isolation -e '.'
|
||||
uv pip install --no-build-isolation -e '.'
|
||||
```
|
||||
|
||||
See @sec-troubleshooting for Mac-specific issues.
|
||||
@@ -152,21 +121,44 @@ See @sec-troubleshooting for Mac-specific issues.
|
||||
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
:::
|
||||
|
||||
## Environment Managers {#sec-env-managers}
|
||||
## Migrating from pip to uv {#sec-migrating}
|
||||
|
||||
### Conda/Pip venv {#sec-conda}
|
||||
If you have an existing pip-based Axolotl installation, you can migrate to uv:
|
||||
|
||||
1. Install Python ≥3.11
|
||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||
3. Install Axolotl:
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Hugging Face:
|
||||
```{.bash}
|
||||
hf auth login
|
||||
```
|
||||
```{.bash}
|
||||
# Install uv
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
|
||||
# Create a fresh venv (recommended for a clean start)
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
|
||||
# Reinstall axolotl
|
||||
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
## Using pip (Alternative) {#sec-pip}
|
||||
|
||||
If you are unable to install uv, you can still use pip directly.
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure to have PyTorch installed before installing Axolotl with pip.
|
||||
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
For editable/development installs:
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
|
||||
@@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab
 As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.

 ```yaml
-rl: ipo
+rl: dpo
+dpo_loss_type: ["ipo"]
 ```
+*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.

 ### ORPO
@@ -1,53 +0,0 @@
|
||||
---
|
||||
title: "Unsloth"
|
||||
description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
||||
---
|
||||
|
||||
### Overview
|
||||
|
||||
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
|
||||
standard industry baselines.
|
||||
|
||||
::: {.callout-important}
|
||||
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
|
||||
|
||||
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
|
||||
:::
|
||||
|
||||
|
||||
### Installation
|
||||
|
||||
The following will install the correct unsloth and extras from source.
|
||||
|
||||
```bash
|
||||
python scripts/unsloth_install.py | sh
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
||||
|
||||
Our unsloth integration is currently limited to the following model architectures:
|
||||
- llama
|
||||
|
||||
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
|
||||
```yaml
|
||||
unsloth_lora_mlp: true
|
||||
unsloth_lora_qkv: true
|
||||
unsloth_lora_o: true
|
||||
```
|
||||
|
||||
These options are composable and can be used with multi-gpu finetuning
|
||||
```yaml
|
||||
unsloth_cross_entropy_loss: true
|
||||
unsloth_rms_norm: true
|
||||
unsloth_rope: true
|
||||
```
|
||||
|
||||
### Limitations
|
||||
|
||||
- Single GPU only; e.g. no multi-gpu support
|
||||
- No deepspeed or FSDP support (requires multi-gpu)
|
||||
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
|
||||
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
|
||||
- No MoE support.
|
||||
@@ -15,8 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run one of the finetuning examples below.
|
||||
@@ -35,7 +34,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
|
||||
**LFM2-MoE**
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
|
||||
# LoRA SFT (1x48GB @ 16.2GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
||||
@@ -45,7 +44,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
|
||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||
```bash
|
||||
pip uninstall -y causal-conv1d
|
||||
uv pip uninstall causal-conv1d
|
||||
```
|
||||
|
||||
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
@@ -15,8 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
@@ -31,7 +30,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
# For those using our Docker image, use the below path.
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||
@@ -67,7 +66,7 @@ If those didn't help, please try the below solutions:
|
||||
1. Pass env for CMAKE and try install again:
|
||||
|
||||
```bash
|
||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
2. Git clone the repo and manually hardcode python path:
|
||||
@@ -92,7 +91,7 @@ If those didn't help, please try the below solutions:
|
||||
```
|
||||
|
||||
```bash
|
||||
pip3 install . --no-build-isolation --no-deps
|
||||
uv pip install . --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
@@ -17,8 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -16,8 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
@@ -10,17 +10,16 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||
|
||||
```bash
|
||||
pip3 install timm==1.0.17
|
||||
uv pip install timm==1.0.17
|
||||
|
||||
# for loading audio data
|
||||
pip3 install librosa==0.11.0
|
||||
uv pip install librosa==0.11.0
|
||||
```
|
||||
|
||||
3. Download sample dataset files
|
||||
|
||||
@@ -14,8 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
||||
@@ -87,7 +86,7 @@ for more information about using a special vllm-openai docker image for inferenc
|
||||
Optionally, vLLM can be installed from nightly:
|
||||
|
||||
```bash
|
||||
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||
uv pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||
```
|
||||
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
|
||||
```bash
|
||||
|
||||
@@ -15,8 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -13,8 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl.
|
||||
2. Install `timm` for vision model support:
|
||||
|
||||
```bash
|
||||
pip install timm==1.0.19
|
||||
uv pip install timm==1.0.19
|
||||
```
|
||||
|
||||
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
@@ -14,8 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
uv pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
|
||||
@@ -23,7 +23,7 @@ Note: This is still experimental given it is based on transformers v5 RC.
|
||||
git checkout transformers-v5
|
||||
|
||||
# Install packages for transformers v5
|
||||
pip install -e .
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
4. Run the fine-tuning:
|
||||
|
||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.6'
|
||||
uv pip install 'mistral-common[opencv]==1.8.6'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
|
||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
uv pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
|
||||
@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
3. Install transformers from main
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git
|
||||
uv pip install git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
4. Run one of the example configs:
|
||||
|
||||
@@ -12,7 +12,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
|
||||
```
|
||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||
|
||||
|
||||
@@ -11,8 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
|
||||
# Install Cut Cross Entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -13,14 +13,13 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install an extra dependency:
|
||||
|
||||
```bash
|
||||
pip3 install num2words==0.5.14
|
||||
uv pip install num2words==0.5.14
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
@@ -12,16 +12,15 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Install the additional audio dependencies:

```bash
# audio
pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3'
uv pip install librosa==0.11.0
uv pip install 'mistral_common[audio]==1.8.3'

# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

203
pyproject.toml
203
pyproject.toml
@@ -1,15 +1,165 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "axolotl"
|
||||
dynamic = ["version", "dependencies", "optional-dependencies"]
|
||||
dynamic = ["version"]
|
||||
description = "LLM Trainer"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
# license = "Apache-2.0"
|
||||
|
||||
dependencies = [
|
||||
# Core ML stack
|
||||
"torch>=2.6.0",
|
||||
"packaging==26.0",
|
||||
"huggingface_hub>=1.1.7",
|
||||
"peft>=0.19.1,<0.20.0",
|
||||
"tokenizers>=0.22.1",
|
||||
"transformers==5.5.4",
|
||||
"accelerate==1.13.0",
|
||||
"datasets>=4.8.4,<4.9.0",
|
||||
"trl==1.1.0",
|
||||
"hf_xet==1.4.3",
|
||||
"kernels==0.13.0",
|
||||
"trackio>=0.16.1",
|
||||
"typing-extensions>=4.15.0",
|
||||
"optimum==1.16.2",
|
||||
"hf_transfer",
|
||||
"sentencepiece",
|
||||
"gradio>=6.2.0,<7.0",
|
||||
"modal==1.3.0.post1",
|
||||
"pydantic>=2.10.6",
|
||||
"addict",
|
||||
"fire",
|
||||
"PyYAML>=6.0",
|
||||
"requests",
|
||||
"wandb",
|
||||
"einops",
|
||||
"colorama",
|
||||
"numba>=0.61.2",
|
||||
"numpy>=2.2.6",
|
||||
|
||||
# Evaluation & metrics
|
||||
"evaluate==0.4.1",
|
||||
"scipy",
|
||||
"nvidia-ml-py==12.560.30",
|
||||
"art",
|
||||
"tensorboard",
|
||||
"python-dotenv==1.0.1",
|
||||
|
||||
# Remote filesystems
|
||||
"s3fs>=2024.5.0",
|
||||
"gcsfs>=2025.3.0",
|
||||
"adlfs>=2024.5.0",
|
||||
"ocifs==1.3.2",
|
||||
|
||||
"zstandard==0.22.0",
|
||||
"fastcore",
|
||||
|
||||
# lm eval harness
|
||||
"lm_eval==0.4.11",
|
||||
"langdetect==1.0.9",
|
||||
"immutabledict==4.2.0",
|
||||
"antlr4-python3-runtime==4.13.2",
|
||||
|
||||
"schedulefree==1.4.1",
|
||||
"openenv-core==0.1.0",
|
||||
|
||||
# Axolotl contribs
|
||||
"axolotl-contribs-lgpl==0.0.7",
|
||||
"axolotl-contribs-mit==0.0.6",
|
||||
|
||||
# Telemetry
|
||||
"posthog==6.7.11",
|
||||
|
||||
"mistral-common==1.11.0",
|
||||
|
||||
# Platform-specific (Linux only)
|
||||
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
|
||||
"triton>=3.4.0 ; sys_platform != 'darwin'",
|
||||
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
|
||||
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
|
||||
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
||||
|
||||
# Architecture-specific
|
||||
"fla-core==0.4.1 ; platform_machine != 'aarch64'",
|
||||
"flash-linear-attention==0.4.1 ; platform_machine != 'aarch64'",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
flash-attn = ["flash-attn==2.8.3"]
|
||||
ring-flash-attn = [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
]
|
||||
deepspeed = [
|
||||
"deepspeed>=0.18.6,<0.19.0",
|
||||
"deepspeed-kernels",
|
||||
]
|
||||
mamba-ssm = [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
]
|
||||
auto-gptq = [
|
||||
"auto-gptq==0.5.1",
|
||||
]
|
||||
mlflow = [
|
||||
"mlflow",
|
||||
]
|
||||
galore = [
|
||||
"galore_torch",
|
||||
]
|
||||
apollo = [
|
||||
"apollo-torch",
|
||||
]
|
||||
optimizers = [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
"came_pytorch==0.1.3",
|
||||
]
|
||||
ray = [
|
||||
"ray[train]>=2.52.1",
|
||||
]
|
||||
vllm = [
|
||||
"vllm>=0.15.0",
|
||||
]
|
||||
llmcompressor = [
|
||||
"llmcompressor>=0.10.0",
|
||||
]
|
||||
fbgemm-gpu = ["fbgemm-gpu-genai>=1.3.0"]
|
||||
opentelemetry = [
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-prometheus",
|
||||
"prometheus-client",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"black",
|
||||
"mypy",
|
||||
"pre-commit",
|
||||
"types-requests",
|
||||
"quartodoc",
|
||||
"jupyter",
|
||||
"blobfile",
|
||||
"tiktoken",
|
||||
]
|
||||
test = [
|
||||
"codecov",
|
||||
"codecov-cli",
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"pytest-retry",
|
||||
"pytest-sugar",
|
||||
"pytest-xdist",
|
||||
"tbparse",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
axolotl = "axolotl.cli.main:main"
|
||||
|
||||
@@ -18,18 +168,15 @@ Homepage = "https://axolotl.ai/"
|
||||
Documentation = "https://docs.axolotl.ai/"
|
||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = { file = "VERSION" }
|
||||
|
||||
[tool.setuptools.cmdclass]
|
||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py310"
|
||||
@@ -67,5 +214,43 @@ markers = [
|
||||
"slow: marks tests as slow",
|
||||
]
|
||||
|
||||
# UV specific configuration
|
||||
[tool.uv]
|
||||
prerelease = "allow"
|
||||
conflicts = [
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "vllm" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "flash-attn" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "ring-flash-attn" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "mamba-ssm" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "auto-gptq" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "fbgemm-gpu" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "llmcompressor" },
|
||||
],
|
||||
]
|
||||
|
||||
[tool.uv.extra-build-dependencies]
|
||||
axolotl = ["huggingface_hub"]
|
||||
mamba-ssm = [{ requirement = "torch", match-runtime = true }]
|
||||
causal-conv1d = [{ requirement = "torch", match-runtime = true }]
|
||||
flash-attn = [{ requirement = "torch", match-runtime = true }]
|
||||
deepspeed = [{ requirement = "torch", match-runtime = true }]
|
||||
auto-gptq = [{ requirement = "torch", match-runtime = true }]
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
black
|
||||
mypy
|
||||
pre-commit
|
||||
types-requests
|
||||
quartodoc
|
||||
jupyter
|
||||
blobfile
|
||||
tiktoken
|
||||
@@ -1,8 +0,0 @@
|
||||
codecov
|
||||
codecov-cli
|
||||
pytest
|
||||
pytest-cov
|
||||
pytest-retry
|
||||
pytest-sugar
|
||||
pytest-xdist
|
||||
tbparse
|
||||
@@ -1,78 +0,0 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.49.1
|
||||
triton>=3.4.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
liger-kernel==0.7.0
|
||||
# END section
|
||||
|
||||
packaging==26.0
|
||||
huggingface_hub>=1.1.7
|
||||
peft>=0.19.0,<0.20.0
|
||||
tokenizers>=0.22.1
|
||||
transformers==5.5.4
|
||||
accelerate==1.13.0
|
||||
datasets>=4.8.4,<4.9.0
|
||||
deepspeed>=0.18.6,<0.19.0
|
||||
trl==1.1.0
|
||||
hf_xet==1.4.3
|
||||
kernels==0.13.0
|
||||
|
||||
fla-core==0.4.1
|
||||
flash-linear-attention==0.4.1
|
||||
|
||||
trackio>=0.16.1
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio>=6.2.0,<7.0
|
||||
|
||||
modal==1.3.0.post1
|
||||
pydantic>=2.10.6
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
wandb
|
||||
einops
|
||||
colorama
|
||||
numba>=0.61.2
|
||||
numpy>=2.2.6
|
||||
|
||||
# qlora things
|
||||
evaluate==0.4.1
|
||||
scipy
|
||||
nvidia-ml-py==12.560.30
|
||||
art
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# remote filesystems
|
||||
s3fs>=2024.5.0
|
||||
gcsfs>=2025.3.0
|
||||
adlfs>=2024.5.0
|
||||
ocifs==1.3.2
|
||||
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
# lm eval harness
|
||||
lm_eval==0.4.11
|
||||
langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.17.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.7
|
||||
axolotl-contribs-mit==0.0.6
|
||||
# telemetry
|
||||
posthog==6.7.11
|
||||
|
||||
mistral-common==1.11.0
|
||||
@@ -1,40 +0,0 @@
|
||||
# noqa
|
||||
import sys
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError as error:
|
||||
raise ImportError("Install torch via `pip install torch`") from error
|
||||
from packaging.version import Version as V
|
||||
|
||||
use_uv = "--uv" in sys.argv[1:]
|
||||
|
||||
v = V(torch.__version__)
|
||||
cuda = str(torch.version.cuda)
|
||||
try:
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
except RuntimeError:
|
||||
is_ampere = False
|
||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||
if v <= V("2.1.0"):
|
||||
raise RuntimeError(f"Torch = {v} too old!")
|
||||
elif v <= V("2.1.1"):
|
||||
x = "cu{}{}-torch211"
|
||||
elif v <= V("2.1.2"):
|
||||
x = "cu{}{}-torch212"
|
||||
elif v < V("2.3.0"):
|
||||
x = "cu{}{}-torch220"
|
||||
elif v < V("2.4.0"):
|
||||
x = "cu{}{}-torch230"
|
||||
elif v < V("2.5.0"):
|
||||
x = "cu{}{}-torch240"
|
||||
elif v < V("2.6.0"):
|
||||
x = "cu{}{}-torch250"
|
||||
else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
uv_prefix = "uv " if use_uv else ""
|
||||
print(
|
||||
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
|
||||
)
|
||||
230
setup.py
230
setup.py
@@ -1,230 +0,0 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements(extras_require_map):
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
_dependency_links.append(url)
|
||||
elif not is_extras and line and line[0] != "#":
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
install_xformers = platform.machine() != "aarch64"
|
||||
if platform.machine() == "aarch64":
|
||||
# skip on ARM64
|
||||
skip_packages = [
|
||||
"torchao",
|
||||
"fla-core",
|
||||
"flash-linear-attention",
|
||||
]
|
||||
_install_requires = [
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
if "Darwin" in platform.system():
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
"bitsandbytes",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"xformers",
|
||||
"liger-kernel",
|
||||
]
|
||||
_install_requires = [
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
print(
|
||||
_install_requires, [req in skip_packages for req in _install_requires]
|
||||
)
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.8.0" # default to torch 2.8.0
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
torch_parts = torch_version.split("+")
|
||||
if len(torch_parts) == 2:
|
||||
torch_cuda_version = torch_parts[1]
|
||||
_dependency_links.append(
|
||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||
)
|
||||
|
||||
if (major, minor) >= (2, 10):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.5.0",
|
||||
"fbgemm-gpu-genai==1.5.0",
|
||||
]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm>=0.19.0"]
|
||||
elif (major, minor) >= (2, 9):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.4.0",
|
||||
"fbgemm-gpu-genai==1.4.2",
|
||||
]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
else:
|
||||
extras_require_map["vllm"] = ["vllm==0.14.0"]
|
||||
elif (major, minor) >= (2, 8):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
elif (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.30")
|
||||
# vllm 0.9.x is incompatible with latest transformers
|
||||
extras_require_map.pop("vllm")
|
||||
else:
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.31")
|
||||
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
# since we only support 2.6.0+cu126
|
||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if install_xformers:
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 4):
|
||||
extras_require_map.pop("vllm")
|
||||
if install_xformers:
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
else:
|
||||
raise ValueError("axolotl requires torch>=2.4")
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links, extras_require_map
|
||||
|
||||
|
||||
def get_package_version():
|
||||
with open(
|
||||
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
|
||||
"r",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
version_ = fin.read().strip()
|
||||
return version_
|
||||
|
||||
|
||||
extras_require = {
|
||||
"flash-attn": ["flash-attn==2.8.3"],
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.18.2",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"apollo": [
|
||||
"apollo-torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
"came_pytorch==0.1.3",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]>=2.52.1",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.10.0",
|
||||
],
|
||||
"llmcompressor": [
|
||||
"llmcompressor==0.5.1",
|
||||
],
|
||||
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
|
||||
"opentelemetry": [
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-prometheus",
|
||||
"prometheus-client",
|
||||
],
|
||||
}
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
)
|
||||
|
||||
setup(
|
||||
version=get_package_version(),
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
install_requires=install_requires,
|
||||
dependency_links=dependency_links,
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"axolotl=axolotl.cli.main:main",
|
||||
],
|
||||
},
|
||||
extras_require=extras_require_build,
|
||||
)
|
||||
@@ -339,7 +339,11 @@ def _build_peft_layer_and_get_delta(
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
return layer.get_delta_weight(adapter_name)
|
||||
delta = layer.get_delta_weight(adapter_name)
|
||||
# peft >=0.19.1 may return delta with transposed dims for 3D params
|
||||
if delta.shape != base_tensor.shape and delta.ndim == 3:
|
||||
delta = delta.transpose(1, 2).contiguous()
|
||||
return delta
|
||||
elif (
|
||||
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
|
||||
):
|
||||
|
||||
@@ -20,8 +20,16 @@ class DPOStrategy:
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
training_args_kwargs = {}
|
||||
if cfg.rl is RLType.DPO:
|
||||
if cfg.dpo_loss_type is not None:
|
||||
training_args_kwargs["loss_type"] = cfg.dpo_loss_type
|
||||
|
||||
if cfg.dpo_loss_weights is not None:
|
||||
training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights
|
||||
|
||||
if cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = ["ipo"]
|
||||
|
||||
# Label smoothing is not compatible with IPO
|
||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
|
||||
@@ -242,6 +242,85 @@ class ProducerConfig:
|
||||
)
|
||||
|
||||
|
||||
class _GroupShardedSampler:
|
||||
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
|
||||
|
||||
``RepeatSampler`` yields ``num_generations`` consecutive copies of
|
||||
each prompt, forming a GRPO group. For distributed training each
|
||||
rank must see a disjoint slice of prompts (otherwise every rank
|
||||
dogpiles on the first 1/world_size of the batch) while keeping each
|
||||
group intact on a single rank so advantage normalization sees all
|
||||
peer generations.
|
||||
|
||||
``accelerator.prepare(DataLoader)`` does not handle this correctly
|
||||
for custom samplers with ``split_batches=False`` (the default): it
|
||||
leaves the sampler alone and every rank replays identical indices.
|
||||
This wrapper fixes that by consuming the inner sampler's full
|
||||
output, chunking it into ``num_generations``-sized groups, and
|
||||
round-robining whole groups across ranks.
|
||||
|
||||
Intended to be used ONLY when distributed training is active
|
||||
(``num_replicas > 1``); for single-rank it is a no-op but still
|
||||
correct.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner: Any,
|
||||
num_generations: int,
|
||||
rank: int,
|
||||
num_replicas: int,
|
||||
):
|
||||
if num_generations < 1:
|
||||
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
|
||||
if num_replicas < 1:
|
||||
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
||||
if not (0 <= rank < num_replicas):
|
||||
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
|
||||
self.inner = inner
|
||||
self.num_generations = num_generations
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
|
||||
def __iter__(self):
|
||||
all_indices = list(self.inner)
|
||||
if len(all_indices) % self.num_generations != 0:
|
||||
raise ValueError(
|
||||
f"inner sampler yielded {len(all_indices)} indices, "
|
||||
f"not a multiple of num_generations={self.num_generations}"
|
||||
)
|
||||
# Chunk the flat index sequence into groups of num_generations
|
||||
# consecutive indices. ``RepeatSampler`` guarantees that each
|
||||
# group contains num_generations copies of the same prompt id.
|
||||
groups = [
|
||||
all_indices[i : i + self.num_generations]
|
||||
for i in range(0, len(all_indices), self.num_generations)
|
||||
]
|
||||
# Round-robin whole groups across ranks. Round-robin (vs.
|
||||
# contiguous chunking) preserves approximate shuffled order on
|
||||
# each rank even when the group count is small relative to the
|
||||
# world size.
|
||||
for group in groups[self.rank :: self.num_replicas]:
|
||||
yield from group
|
||||
|
||||
def __len__(self):
|
||||
try:
|
||||
inner_len = len(self.inner)
|
||||
except TypeError:
|
||||
# Non-sized inner sampler — we can't know the per-rank
|
||||
# length without materializing. Return 0 as a hint that the
|
||||
# DataLoader should fall back to iteration.
|
||||
return 0
|
||||
total_groups = inner_len // self.num_generations
|
||||
# Ceiling division for the trailing groups that don't divide
|
||||
# evenly — extra groups go to the first ``total_groups %
|
||||
# num_replicas`` ranks, matching the round-robin above.
|
||||
my_groups = (
|
||||
total_groups + self.num_replicas - self.rank - 1
|
||||
) // self.num_replicas
|
||||
return my_groups * self.num_generations
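# Illustrative sketch (not part of the original change): with a hypothetical
# inner sampler yielding [0, 0, 1, 1, 2, 2, 3, 3] (4 prompts, num_generations=2)
# and num_replicas=2, iteration produces
#   rank 0 -> [0, 0, 2, 2]
#   rank 1 -> [1, 1, 3, 3]
# so each GRPO group of num_generations consecutive copies stays on one rank.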
|
||||
|
||||
|
||||
class DataProducer(ABC):
|
||||
"""Abstract base class for online data producers.
|
||||
|
||||
@@ -556,6 +635,34 @@ class GRPODataProducer(BaseDataProducer):
|
||||
seed=self._seed,
|
||||
)
|
||||
|
||||
# Shard the sampler across distributed ranks so each rank sees
|
||||
# a disjoint slice of prompts. ``RepeatSampler`` groups each
|
||||
# prompt with ``num_generations`` consecutive copies — our
|
||||
# wrapper round-robins WHOLE groups across ranks so all
|
||||
# generations of a given prompt stay on the same rank (needed
|
||||
# for GRPO advantage normalization within a group).
|
||||
#
|
||||
# Without this, ``accelerator.prepare(dl)`` with the default
|
||||
# ``split_batches=False`` leaves the custom sampler alone, so
|
||||
# every rank iterates the identical index sequence and the
|
||||
# cluster dogpiles on the first 1/world_size of the prompts.
|
||||
num_replicas = max(1, trainer.accelerator.num_processes)
|
||||
if num_replicas > 1:
|
||||
sampler = _GroupShardedSampler(
|
||||
inner=sampler,
|
||||
num_generations=self._num_generations,
|
||||
rank=trainer.accelerator.process_index,
|
||||
num_replicas=num_replicas,
|
||||
)
|
||||
logger.info(
|
||||
"[RANK:%d] _GroupShardedSampler active "
|
||||
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
|
||||
trainer.accelerator.process_index,
|
||||
num_replicas,
|
||||
self._num_generations,
|
||||
self._generation_batch_size,
|
||||
)
|
||||
|
||||
# Use identity collator (same as stock GRPOTrainer)
|
||||
def _identity(x):
|
||||
return x
|
||||
@@ -574,12 +681,11 @@ class GRPODataProducer(BaseDataProducer):
|
||||
rank=trainer.args.process_index,
|
||||
),
|
||||
)
|
||||
self._prompt_dl = trainer.accelerator.prepare(dl)
|
||||
|
||||
# Don't let accelerator track this dataloader
|
||||
acc_dls = trainer.accelerator._dataloaders
|
||||
if self._prompt_dl in acc_dls:
|
||||
acc_dls.remove(self._prompt_dl)
|
||||
# Skip accelerator.prepare — we're handling per-rank sharding
|
||||
# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
|
||||
# otherwise try to wrap the DataLoader with its own sharding
|
||||
# logic which does not understand our group structure.
|
||||
self._prompt_dl = dl
|
||||
|
||||
self._prompt_iter = iter(self._prompt_dl)
|
||||
|
||||
@@ -1103,11 +1209,22 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
|
||||
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
|
||||
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
|
||||
|
||||
This is the canonical sync trigger and runs in BOTH async and
|
||||
synchronous modes from ``_prepare_inputs_with_data_producer`` /
|
||||
``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
|
||||
patch is a parallel backup for non-data-producer paths (vanilla
|
||||
GRPO without NeMo Gym), where the data producer is bypassed
|
||||
entirely and TRL's stock generate-then-sync flow is used instead.
|
||||
"""
|
||||
if not (self.use_vllm and self.args.async_prefetch):
|
||||
if not self.use_vllm:
|
||||
return
|
||||
step = self.state.global_step
|
||||
interval = self.args.vllm_sync_interval
|
||||
# Default to syncing every step when no interval is configured —
|
||||
# otherwise ``step % None`` would TypeError, and the previous
|
||||
# behavior of crashing on the first sync was strictly worse than
|
||||
# the standard "sync every optimizer step".
|
||||
interval = self.args.vllm_sync_interval or 1
|
||||
if step != self._last_synced_step and step % interval == 0:
|
||||
if step == 0:
|
||||
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
|
||||
@@ -1202,13 +1319,42 @@ class AsyncGRPOTrainer(GRPOTrainer):
|
||||
|
||||
# Permanently replace vllm_generation.sync_weights with our custom
|
||||
# sync to avoid merge_adapter (fails on FP8 / races with training).
|
||||
# For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
|
||||
# handles the sync with proper interval tracking.
|
||||
#
|
||||
# The design has two modes that have to be threaded carefully:
|
||||
#
|
||||
# - Async prefetch ON: BG generation thread can't safely call
|
||||
# sync_weights mid-rollout (it races with the trainer's optimizer
|
||||
# step and can corrupt weights). We no-op the stock sync hook and
|
||||
# drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
|
||||
# optimizer step on the main thread.
|
||||
#
|
||||
# - Async prefetch OFF (synchronous mode): TRL's stock
|
||||
# ``_generate_single_turn`` calls ``sync_weights`` once per step
|
||||
# boundary. There's no BG thread to race with, and
|
||||
# ``_maybe_sync_vllm_weights`` short-circuits with
|
||||
# ``if not async_prefetch: return``, so we MUST wire the stock
|
||||
# hook directly to our LoRA sync helper — otherwise nothing ever
|
||||
# pushes weights to vLLM and the trainer becomes a no-op (vLLM
|
||||
# keeps serving the base model, every rollout in every group
|
||||
# produces identical outputs, advantages are zero, optimizer
|
||||
# step gets skipped, repeat).
|
||||
if not getattr(self, "_patched_sync_weights", False):
|
||||
if self.use_vllm and hasattr(self, "vllm_generation"):
|
||||
if getattr(self.args, "vllm_lora_sync", False):
|
||||
# No-op: LoRA sync is driven by _maybe_sync_vllm_weights
|
||||
self.vllm_generation.sync_weights = lambda: None
|
||||
if getattr(self.args, "async_prefetch", False):
|
||||
# Async: drive sync from main thread via
|
||||
# _maybe_sync_vllm_weights instead.
|
||||
self.vllm_generation.sync_weights = lambda: None
|
||||
else:
|
||||
# Sync mode: TRL's _generate_single_turn already
|
||||
# calls sync_weights once per step boundary. Wire
|
||||
# it directly to our LoRA filesystem sync helper.
|
||||
sync_helper = self._sync_lora_adapter
|
||||
|
||||
def _lora_filesystem_sync():
|
||||
sync_helper()
|
||||
|
||||
self.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||
self._patched_sync_weights = True
|
||||
else:
|
||||
from accelerate.utils import is_peft_model
|
||||
|
||||
27
src/axolotl/integrations/hatchery/__init__.py
Normal file
27
src/axolotl/integrations/hatchery/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Hatchery/Tinker remote training integration for Axolotl.
|
||||
|
||||
Routes axolotl's preprocessed data to a remote training API (Tinker or
|
||||
Hatchery) instead of running forward/backward locally. The remote
|
||||
service handles model weights, LoRA adapters, and gradient updates.
|
||||
"""
|
||||
|
||||
from .args import HatcheryArgs, HatcheryConfig
|
||||
from .plugin import HatcheryPlugin
|
||||
|
||||
__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"]
|
||||
|
||||
# Usage:
|
||||
# plugins:
|
||||
# - axolotl.integrations.hatchery.HatcheryPlugin
|
||||
#
|
||||
# hatchery:
|
||||
# backend: tinker # or "hatchery"
|
||||
# lora_rank: 32
|
||||
# loss_fn: cross_entropy # SFT
|
||||
# # loss_fn: ppo # RL (auto-selects HatcheryRLTrainer)
|
||||
#
|
||||
# learning_rate: 1e-4 # top-level, not under hatchery:
|
||||
62
src/axolotl/integrations/hatchery/args.py
Normal file
62
src/axolotl/integrations/hatchery/args.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Pydantic config schema for the Hatchery integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class HatcheryConfig(BaseModel):
|
||||
"""Nested config under `hatchery:` in the axolotl YAML.
|
||||
|
||||
Only contains hatchery-specific settings. Standard training params
|
||||
(learning_rate, weight_decay, adam_beta1/2, max_grad_norm,
|
||||
gradient_accumulation_steps) are read from axolotl's top-level config.
|
||||
"""
|
||||
|
||||
# Backend & connection
|
||||
backend: Literal["tinker", "hatchery"] = "tinker"
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
# LoRA config sent to remote
|
||||
lora_rank: int = Field(32, ge=1, le=256)
|
||||
train_attn: bool = True
|
||||
train_mlp: bool = True
|
||||
train_unembed: bool = True
|
||||
|
||||
# Loss function
|
||||
loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = (
|
||||
"cross_entropy"
|
||||
)
|
||||
loss_fn_config: Optional[dict[str, Any]] = None
|
||||
|
||||
# Pipelining: submit next batch before awaiting previous result
|
||||
pipeline: bool = True
|
||||
|
||||
# Sampling params (for RL flows)
|
||||
max_sample_tokens: int = 256
|
||||
sample_temperature: float = 1.0
|
||||
num_samples: int = 4
|
||||
|
||||
# Reward functions (for RL) — list of fully qualified names
|
||||
reward_funcs: Optional[list[str]] = None
|
||||
|
||||
# Checkpointing
|
||||
save_steps: Optional[int] = None
|
||||
save_name_prefix: str = "checkpoint"
|
||||
|
||||
# Timeout per future (seconds)
|
||||
future_timeout: float = 600.0
|
||||
|
||||
|
||||
class HatcheryArgs(BaseModel):
|
||||
"""Top-level mixin that adds the nested `hatchery:` field."""
|
||||
|
||||
hatchery: Optional[HatcheryConfig] = None
|
||||
160
src/axolotl/integrations/hatchery/data.py
Normal file
160
src/axolotl/integrations/hatchery/data.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Convert axolotl batch tensors to Tinker/Hatchery Datum format.
|
||||
|
||||
Both Tinker and Hatchery expect the client to apply the causal LM shift:
|
||||
|
||||
Original tokens: [t0, t1, t2, ..., t_{L-1}]
|
||||
model_input: [t0, t1, ..., t_{L-2}] (last token dropped)
|
||||
target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped)
|
||||
weights: [w1, w2, ..., w_{L-1}] (aligned to targets)
|
||||
|
||||
At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]:
|
||||
"""Serialize a tensor to the TensorData wire dict."""
|
||||
flat = t.detach().cpu().flatten()
|
||||
dtype_map = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.int64: "int64",
|
||||
torch.int32: "int32",
|
||||
}
|
||||
return {
|
||||
"dtype": dtype_map.get(flat.dtype, "float32"),
|
||||
"shape": list(t.shape),
|
||||
"data": flat.tolist(),
|
||||
}
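# Example of the wire format produced above (illustrative, values made up):
#   _tensor_to_wire(torch.tensor([[1, 2], [3, 4]]))
#   -> {"dtype": "int64", "shape": [2, 2], "data": [1, 2, 3, 4]}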
|
||||
|
||||
|
||||
def _make_datum(
|
||||
tokens: list[int],
|
||||
loss_fn_inputs: dict[str, torch.Tensor],
|
||||
) -> dict[str, Any]:
|
||||
"""Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery)."""
|
||||
return {
|
||||
"model_input": {
|
||||
"chunks": [{"type": "encoded_text", "tokens": tokens}],
|
||||
},
|
||||
"loss_fn_inputs": {
|
||||
key: _tensor_to_wire(tensor) for key, tensor in loss_fn_inputs.items()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def datums_to_tinker(datums: list[dict[str, Any]]):
|
||||
"""Wrap plain-dict datums into tinker.types.Datum objects.
|
||||
|
||||
Both the Tinker SDK and updated Hatchery client accept these.
|
||||
"""
|
||||
import tinker.types as tt
|
||||
|
||||
result = []
|
||||
for d in datums:
|
||||
tokens = d["model_input"]["chunks"][0]["tokens"]
|
||||
tinker_inputs = {}
|
||||
for key, wire in d["loss_fn_inputs"].items():
|
||||
tinker_inputs[key] = tt.TensorData(
|
||||
data=wire["data"],
|
||||
dtype=wire["dtype"],
|
||||
shape=wire["shape"],
|
||||
)
|
||||
result.append(
|
||||
tt.Datum(
|
||||
model_input=tt.ModelInput.from_ints(tokens),
|
||||
loss_fn_inputs=tinker_inputs,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def batch_to_datums_sft(
|
||||
input_ids: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert an axolotl SFT batch to Datum dicts with causal shift."""
|
||||
batch_size = input_ids.size(0)
|
||||
datums = []
|
||||
|
||||
for i in range(batch_size):
|
||||
ids = input_ids[i]
|
||||
lbl = labels[i]
|
||||
|
||||
if attention_mask is not None:
|
||||
seq_len = int(attention_mask[i].sum().item())
|
||||
ids = ids[:seq_len]
|
||||
lbl = lbl[:seq_len]
|
||||
|
||||
model_tokens = ids[:-1].tolist()
|
||||
shifted_labels = lbl[1:]
|
||||
|
||||
target_tokens = shifted_labels.clone()
|
||||
weights = (shifted_labels != -100).float()
|
||||
target_tokens[target_tokens == -100] = 0
|
||||
|
||||
datums.append(
|
||||
_make_datum(
|
||||
model_tokens,
|
||||
{
|
||||
"target_tokens": target_tokens,
|
||||
"weights": weights,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return datums
|
||||
|
||||
|
||||
def batch_to_datums_rl(
|
||||
input_ids: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert an RL batch to importance_sampling/ppo Datum dicts with causal shift."""
|
||||
batch_size = input_ids.size(0)
|
||||
datums = []
|
||||
|
||||
for i in range(batch_size):
|
||||
ids = input_ids[i]
|
||||
lbl = labels[i]
|
||||
|
||||
if attention_mask is not None:
|
||||
seq_len = int(attention_mask[i].sum().item())
|
||||
else:
|
||||
seq_len = ids.size(0)
|
||||
ids = ids[:seq_len]
|
||||
lbl = lbl[:seq_len]
|
||||
lp = logprobs[i, :seq_len]
|
||||
adv = advantages[i, :seq_len]
|
||||
|
||||
model_tokens = ids[:-1].tolist()
|
||||
|
||||
target_tokens = lbl[1:].clone()
|
||||
target_tokens[target_tokens == -100] = 0
|
||||
|
||||
datums.append(
|
||||
_make_datum(
|
||||
model_tokens,
|
||||
{
|
||||
"target_tokens": target_tokens,
|
||||
"logprobs": lp[1:],
|
||||
"advantages": adv[1:],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return datums
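if __name__ == "__main__":
    # Minimal sketch (not part of the original change): the causal shift on a
    # toy 4-token example. Token ids and the label mask below are made up.
    demo_ids = torch.tensor([[10, 11, 12, 13]])
    demo_labels = torch.tensor([[-100, -100, 12, 13]])  # prompt positions masked
    demo = batch_to_datums_sft(demo_ids, demo_labels)[0]
    print(demo["model_input"]["chunks"][0]["tokens"])       # [10, 11, 12]
    print(demo["loss_fn_inputs"]["target_tokens"]["data"])  # [0, 12, 13]
    print(demo["loss_fn_inputs"]["weights"]["data"])        # [0.0, 1.0, 1.0]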
|
||||
87
src/axolotl/integrations/hatchery/examples/prep_math_rl.py
Normal file
87
src/axolotl/integrations/hatchery/examples/prep_math_rl.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Prepare hendrycks_math for RL training with Hatchery/Tinker.
|
||||
|
||||
Creates a dataset with chat-formatted prompts that include
|
||||
a hidden gold answer tag for the reward function.
|
||||
|
||||
Run:
|
||||
python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def extract_boxed(text: str) -> str:
|
||||
match = re.search(r"\\boxed\{", text)
|
||||
if not match:
|
||||
return ""
|
||||
start = match.end()
|
||||
depth = 1
|
||||
i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{":
|
||||
depth += 1
|
||||
elif text[i] == "}":
|
||||
depth -= 1
|
||||
i += 1
|
||||
return text[start : i - 1] if depth == 0 else ""
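# Illustrative examples (not part of the original change):
#   extract_boxed(r"so the answer is \boxed{\frac{1}{2}}") -> "\frac{1}{2}"
#   extract_boxed("no boxed answer here")                  -> ""
# The depth counter is what keeps nested braces inside \boxed{...} intact.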
|
||||
|
||||
|
||||
def main():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
|
||||
|
||||
ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test")
|
||||
level = os.environ.get("MATH_LEVEL", "Level 1")
|
||||
filtered_rows = [x for x in ds if x["level"] == level]
|
||||
print(f"{level} algebra: {len(filtered_rows)} problems")
|
||||
|
||||
rows = []
|
||||
for prob in filtered_rows:
|
||||
gold = extract_boxed(prob["solution"])
|
||||
if not gold:
|
||||
continue
|
||||
|
||||
# Format as chat prompt with hidden gold tag
|
||||
prompt = (
|
||||
f"Solve the following math problem. "
|
||||
f"Show your work and put your final answer in \\boxed{{}}.\n\n"
|
||||
f"{prob['problem']}"
|
||||
f"<|gold|>{gold}<|/gold|>"
|
||||
)
|
||||
|
||||
# Tokenize the prompt
|
||||
text = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"input_ids": prompt_ids,
|
||||
"labels": [-100] * len(prompt_ids),
|
||||
"attention_mask": [1] * len(prompt_ids),
|
||||
}
|
||||
)
|
||||
|
||||
out = Dataset.from_list(rows)
|
||||
out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}"
|
||||
out.save_to_disk(out_dir)
|
||||
print(f"Saved {len(out)} examples to {out_dir}")
|
||||
if rows:
|
||||
print(
|
||||
f"Prompt length range: {min(len(r['input_ids']) for r in rows)}"
|
||||
f"-{max(len(r['input_ids']) for r in rows)}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
47
src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
Normal file
47
src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B
|
||||
#
|
||||
# Prep:
|
||||
# python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
|
||||
#
|
||||
# Run:
|
||||
# export TINKER_API_KEY="your-key"
|
||||
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
|
||||
|
||||
base_model: Qwen/Qwen3-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||
|
||||
hatchery:
|
||||
backend: tinker
|
||||
lora_rank: 16
|
||||
loss_fn: importance_sampling
|
||||
max_sample_tokens: 2048
|
||||
sample_temperature: 0.7
|
||||
num_samples: 4
|
||||
pipeline: true
|
||||
save_steps: 5
|
||||
reward_funcs:
|
||||
- axolotl.integrations.hatchery.rewards.math_reward.math_reward
|
||||
|
||||
datasets:
|
||||
- path: ./data/math_rl_level1
|
||||
ds_type: arrow
|
||||
type: completion
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
learning_rate: 5.0e-5
|
||||
optimizer: adamw_torch
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.95
|
||||
weight_decay: 0.01
|
||||
max_grad_norm: 1.0
|
||||
|
||||
max_steps: 10
|
||||
num_epochs: 1
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
logging_steps: 1
|
||||
|
||||
output_dir: ./outputs/tinker-rl-math
|
||||
42
src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
Normal file
42
src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B
|
||||
#
|
||||
# Usage:
|
||||
# export TINKER_API_KEY="your-key"
|
||||
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
|
||||
|
||||
base_model: Qwen/Qwen3-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||
|
||||
hatchery:
|
||||
backend: tinker
|
||||
lora_rank: 16
|
||||
loss_fn: cross_entropy
|
||||
pipeline: true
|
||||
save_steps: 10
|
||||
|
||||
datasets:
|
||||
- path: TeichAI/kimi-k2-thinking-1000x
|
||||
split: train[:50]
|
||||
type: chat_template
|
||||
chat_template: qwen3
|
||||
split_thinking: true
|
||||
|
||||
chat_template: qwen3
|
||||
sequence_len: 2048
|
||||
|
||||
learning_rate: 3.0e-4
|
||||
optimizer: adamw_torch
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.95
|
||||
weight_decay: 0.01
|
||||
max_grad_norm: 1.0
|
||||
|
||||
num_epochs: 1
|
||||
max_steps: 20
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 1
|
||||
logging_steps: 1
|
||||
|
||||
output_dir: ./outputs/tinker-sft
|
||||
147
src/axolotl/integrations/hatchery/plugin.py
Normal file
147
src/axolotl/integrations/hatchery/plugin.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Axolotl plugin that routes training to a remote Hatchery/Tinker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import AutoConfig, PreTrainedModel, Trainer
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class HatcheryPlugin(BasePlugin):
|
||||
"""Plugin that replaces local training with remote API calls.
|
||||
|
||||
Activated by adding to the axolotl YAML:
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||
|
||||
hatchery:
|
||||
backend: tinker # or "hatchery"
|
||||
lora_rank: 32
|
||||
loss_fn: cross_entropy
|
||||
# ... see HatcheryConfig for full options
|
||||
"""
|
||||
|
||||
def get_input_args(self) -> str:
|
||||
return "axolotl.integrations.hatchery.args.HatcheryArgs"
|
||||
|
||||
def register(self, cfg: dict):
|
||||
"""Auto-set config values needed for remote training."""
|
||||
if cfg.get("remove_unused_columns") is None:
|
||||
cfg["remove_unused_columns"] = False
|
||||
|
||||
def pre_model_load(self, cfg: DictDefault):
|
||||
"""Replace model loading with a tiny stub."""
|
||||
hcfg = cfg.hatchery or {}
|
||||
backend = (
|
||||
hcfg.get("backend", "tinker")
|
||||
if isinstance(hcfg, dict)
|
||||
else getattr(hcfg, "backend", "tinker")
|
||||
)
|
||||
LOG.info(
|
||||
f"Hatchery plugin active: training dispatched to remote "
|
||||
f"{backend} API. Skipping local model weight loading."
|
||||
)
|
||||
|
||||
from axolotl.loaders import ModelLoader
|
||||
|
||||
def _stub_build_model(loader_self) -> bool:
|
||||
base_model = loader_self.cfg.base_model
|
||||
LOG.info(f"Skipping model weight loading for: {base_model}")
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
trust_remote_code=loader_self.cfg.get("trust_remote_code", False),
|
||||
)
|
||||
|
||||
class _Stub(PreTrainedModel):
|
||||
config_class = type(config)
|
||||
_no_split_modules: list[str] = []
|
||||
supports_gradient_checkpointing = False
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
vocab_size = getattr(cfg, "vocab_size", 32000)
|
||||
self.embed_tokens = torch.nn.Embedding(vocab_size, 1)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
pass
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return None
|
||||
|
||||
loader_self.model = _Stub(config)
|
||||
return True
|
||||
|
||||
ModelLoader._build_model = _stub_build_model # type: ignore[method-assign,assignment]
|
||||
|
||||
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
|
||||
"""Return the appropriate remote trainer class."""
|
||||
hcfg = cfg.hatchery
|
||||
loss_fn = getattr(hcfg, "loss_fn", "cross_entropy") if hcfg else "cross_entropy"
|
||||
|
||||
if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"):
|
||||
from .rl_trainer import HatcheryRLTrainer
|
||||
|
||||
return HatcheryRLTrainer
|
||||
|
||||
from .trainer import HatcheryTrainer
|
||||
|
||||
return HatcheryTrainer
|
||||
|
||||
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
model._hatchery_remote = True
|
||||
|
||||
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||
LOG.info(
|
||||
"Hatchery: skipping local model save (weights are on remote API). "
|
||||
"Use `tinker checkpoint download` or hatchery CLI to retrieve."
|
||||
)
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||
"""Inject hatchery config + axolotl training params into the trainer."""
|
||||
from .args import HatcheryConfig
|
||||
from .rl_trainer import HatcheryRLTrainer
|
||||
from .trainer import HatcheryTrainer
|
||||
|
||||
if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)):
|
||||
return
|
||||
|
||||
hcfg = cfg.hatchery
|
||||
if isinstance(hcfg, dict):
|
||||
hatchery_config = HatcheryConfig(**hcfg)
|
||||
elif hcfg is None:
|
||||
hatchery_config = HatcheryConfig()
|
||||
else:
|
||||
hatchery_config = hcfg
|
||||
|
||||
trainer.hatchery_args = hatchery_config
|
||||
trainer._base_model_name = cfg.base_model
|
||||
|
||||
# Pull standard training params from axolotl config so they
|
||||
# don't need to be duplicated under hatchery:
|
||||
trainer._optim_params = {
|
||||
"learning_rate": cfg.learning_rate
|
||||
if cfg.learning_rate is not None
|
||||
else 1e-4,
|
||||
"beta1": cfg.adam_beta1 if cfg.adam_beta1 is not None else 0.9,
|
||||
"beta2": cfg.adam_beta2 if cfg.adam_beta2 is not None else 0.95,
|
||||
"eps": cfg.adam_epsilon if cfg.adam_epsilon is not None else 1e-12,
|
||||
"weight_decay": cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
"grad_clip_norm": cfg.max_grad_norm
|
||||
if cfg.max_grad_norm is not None
|
||||
else 0.0,
|
||||
}
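# Illustrative example (not part of the original change): with the
# tinker_rl.yaml values above (learning_rate: 5.0e-5, adam_beta1: 0.9,
# adam_beta2: 0.95, weight_decay: 0.01, max_grad_norm: 1.0, no adam_epsilon)
# this resolves to
#   {"learning_rate": 5e-05, "beta1": 0.9, "beta2": 0.95,
#    "eps": 1e-12, "weight_decay": 0.01, "grad_clip_norm": 1.0}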
|
||||
3
src/axolotl/integrations/hatchery/rewards/__init__.py
Normal file
3
src/axolotl/integrations/hatchery/rewards/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
78
src/axolotl/integrations/hatchery/rewards/math_reward.py
Normal file
78
src/axolotl/integrations/hatchery/rewards/math_reward.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Math reward function for hendrycks_math GRPO training.
|
||||
|
||||
Uses math_verify for robust answer comparison. Falls back to
|
||||
exact string match of \\boxed{} content only when math_verify
|
||||
is unavailable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_boxed(text: str) -> str | None:
|
||||
"""Extract \\boxed{...} answer handling nested braces."""
|
||||
match = re.search(r"\\boxed\{", text)
|
||||
if not match:
|
||||
return None
|
||||
start = match.end()
|
||||
depth = 1
|
||||
i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{":
|
||||
depth += 1
|
||||
elif text[i] == "}":
|
||||
depth -= 1
|
||||
i += 1
|
||||
return text[start : i - 1] if depth == 0 else None
|
||||
|
||||
|
||||
def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
|
||||
"""Score completions by checking if \\boxed{} answer matches the gold answer.
|
||||
|
||||
The gold answer is extracted from the prompt (appended as a hidden
|
||||
tag by the dataset preprocessing). Format:
|
||||
... <|gold|>ANSWER<|/gold|>
|
||||
"""
|
||||
rewards = []
|
||||
for prompt, completion in zip(prompts, completions, strict=True):
|
||||
gold_match = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt)
|
||||
if not gold_match:
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
|
||||
gold_answer = gold_match.group(1).strip()
|
||||
pred_answer = extract_boxed(completion)
|
||||
|
||||
if pred_answer is None:
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
|
||||
verified = None
|
||||
try:
|
||||
from math_verify import parse, verify
|
||||
|
||||
gold_parsed = parse(gold_answer)
|
||||
pred_parsed = parse(pred_answer)
|
||||
verified = verify(gold_parsed, pred_parsed)
|
||||
except Exception:
|
||||
LOG.debug(
|
||||
"math_verify unavailable or failed, using string fallback",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if verified is not None:
|
||||
rewards.append(1.0 if verified else 0.0)
|
||||
elif pred_answer.strip() == gold_answer.strip():
|
||||
rewards.append(1.0)
|
||||
else:
|
||||
rewards.append(0.0)
|
||||
|
||||
return rewards
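if __name__ == "__main__":
    # Minimal smoke test (illustrative only; prompt and completions are made up):
    demo_prompt = "What is 1 + 1?<|gold|>2<|/gold|>"
    demo_completions = ["The answer is \\boxed{2}.", "I believe it is \\boxed{3}."]
    print(math_reward([demo_prompt] * 2, demo_completions))  # expected: [1.0, 0.0]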
|
||||
409
src/axolotl/integrations/hatchery/rl_trainer.py
Normal file
409
src/axolotl/integrations/hatchery/rl_trainer.py
Normal file
@@ -0,0 +1,409 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API.
|
||||
|
||||
Full RL loop per step:
|
||||
1. Extract prompts from dataset batch
|
||||
2. Sample N completions per prompt via remote SamplingClient
|
||||
3. Score completions with local reward functions
|
||||
4. Compute GRPO-style advantages (per-group normalization)
|
||||
5. Send (prompt+completion, logprobs, advantages) as forward_backward
|
||||
6. Optimizer step
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from transformers.trainer_utils import TrainOutput
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .args import HatcheryConfig
|
||||
from .data import batch_to_datums_rl, datums_to_tinker
|
||||
from .trainer import _create_training_client
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _load_reward_func(fqn: str) -> Callable:
|
||||
"""Load a reward function from a fully qualified name like 'module.func'."""
|
||||
module_path = ".".join(fqn.split(".")[:-1])
|
||||
func_name = fqn.split(".")[-1]
|
||||
mod = importlib.import_module(module_path)
|
||||
func = getattr(mod, func_name)
|
||||
if len(inspect.signature(func).parameters) < 2:
|
||||
raise ValueError(f"Reward function {fqn} must accept (prompts, completions)")
|
||||
return func
|
||||
|
||||
|
||||
class HatcheryRLTrainer(AxolotlTrainer):
|
||||
"""Remote RL trainer using Tinker/Hatchery for sampling and training."""
|
||||
|
||||
hatchery_args: Optional[HatcheryConfig]
|
||||
_base_model_name: Optional[str]
|
||||
_training_client: Any
|
||||
_reward_functions: list[Callable]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hatchery_args = None
|
||||
self._base_model_name = None
|
||||
self._training_client = None
|
||||
self._reward_functions = []
|
||||
|
||||
def _ensure_reward_functions(self):
|
||||
if self._reward_functions:
|
||||
return
|
||||
args = self.hatchery_args
|
||||
if not args or not args.reward_funcs:
|
||||
raise ValueError(
|
||||
"No reward functions configured. Set hatchery.reward_funcs "
|
||||
"in YAML, e.g. reward_funcs: ['my_module.my_reward']"
|
||||
)
|
||||
for fqn in args.reward_funcs:
|
||||
self._reward_functions.append(_load_reward_func(fqn))
|
||||
LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)")
|
||||
|
||||
def _get_training_client(self):
|
||||
if self._training_client is not None:
|
||||
return self._training_client
|
||||
|
||||
self._training_client = _create_training_client(
|
||||
self.hatchery_args, self._base_model_name
|
||||
)
|
||||
LOG.info(
|
||||
f"Remote RL session created: backend={self.hatchery_args.backend}, "
|
||||
f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}"
|
||||
)
|
||||
return self._training_client
|
||||
|
||||
def _sample_completions(self, prompt_ids_list: list[list[int]]):
|
||||
"""Sample completions for prompts via remote API."""
|
||||
import tinker.types as tt
|
||||
|
||||
tc = self._get_training_client()
|
||||
args = self.hatchery_args
|
||||
assert args is not None # validated by _get_training_client
|
||||
results = []
|
||||
|
||||
sc = tc.save_weights_and_get_sampling_client()
|
||||
|
||||
for prompt_ids in prompt_ids_list:
|
||||
if hasattr(sc, "sampling_session_id"):
|
||||
sample_result = sc.sample(
|
||||
prompt_ids,
|
||||
max_tokens=args.max_sample_tokens,
|
||||
temperature=args.sample_temperature,
|
||||
n=args.num_samples,
|
||||
).result(timeout=args.future_timeout)
|
||||
else:
|
||||
mi = tt.ModelInput.from_ints(prompt_ids)
|
||||
sp = tt.SamplingParams(
|
||||
max_tokens=args.max_sample_tokens,
|
||||
temperature=args.sample_temperature,
|
||||
top_p=0.95,
|
||||
top_k=-1,
|
||||
)
|
||||
sample_result = sc.sample(
|
||||
prompt=mi,
|
||||
num_samples=args.num_samples,
|
||||
sampling_params=sp,
|
||||
).result(timeout=args.future_timeout)
|
||||
|
||||
sequences = (
|
||||
sample_result.sequences
|
||||
if hasattr(sample_result, "sequences")
|
||||
else sample_result.get("sequences", [])
|
||||
)
|
||||
for seq in sequences:
|
||||
tokens = (
|
||||
list(seq.tokens)
|
||||
if hasattr(seq, "tokens")
|
||||
else seq.get("tokens", [])
|
||||
)
|
||||
logprobs = (
|
||||
list(seq.logprobs)
|
||||
if hasattr(seq, "logprobs") and seq.logprobs
|
||||
else seq.get("logprobs", [])
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"tokens": list(prompt_ids) + tokens,
|
||||
"completion_tokens": tokens,
|
||||
"logprobs": logprobs,
|
||||
"prompt_len": len(prompt_ids),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _compute_rewards(
|
||||
self, prompts: list[str], completions: list[str]
|
||||
) -> list[float]:
|
||||
total_rewards = [0.0] * len(completions)
|
||||
for reward_fn in self._reward_functions:
|
||||
rewards = reward_fn(prompts, completions)
|
||||
for i, r in enumerate(rewards):
|
||||
total_rewards[i] += r
|
||||
return total_rewards
|
||||
|
||||
@staticmethod
|
||||
def _compute_advantages(rewards: list[float], group_size: int) -> list[float]:
|
||||
advantages = []
|
||||
for i in range(0, len(rewards), group_size):
|
||||
group = rewards[i : i + group_size]
|
||||
mean = sum(group) / len(group)
|
||||
var = sum((r - mean) ** 2 for r in group) / max(len(group), 1)
|
||||
std = var**0.5 if var > 1e-8 else 1.0
|
||||
advantages.extend([(r - mean) / std for r in group])
|
||||
return advantages
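# Worked example (illustrative): one group of 4 samples with rewards
# [1.0, 0.0, 1.0, 0.0] has mean 0.5 and std 0.5, so its advantages are
# [1.0, -1.0, 1.0, -1.0]. If every reward in a group is identical, the variance
# is ~0, std falls back to 1.0, and all advantages are 0 (no signal from that prompt).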
|
||||
|
||||
def _do_optim_step(self):
|
||||
import tinker.types as tt
|
||||
|
||||
tc = self._get_training_client()
|
||||
return tc.optim_step(tt.AdamParams(**self._optim_params))
|
||||
|
||||
def train(
|
||||
self,
|
||||
resume_from_checkpoint: Optional[str] = None,
|
||||
trial: Any = None,
|
||||
ignore_keys_for_eval: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
) -> TrainOutput:
|
||||
args = self.hatchery_args
|
||||
if args is None:
|
||||
raise RuntimeError("hatchery_args not configured")
|
||||
|
||||
self._ensure_reward_functions()
|
||||
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
num_train_epochs = int(self.args.num_train_epochs)
|
||||
max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000
|
||||
|
||||
LOG.info(
|
||||
f"Remote RL training: max_steps={max_steps}, "
|
||||
f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}"
|
||||
)
|
||||
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
self.state.is_local_process_zero = True
|
||||
self.state.is_world_process_zero = True
|
||||
|
||||
self.control = self.callback_handler.on_train_begin(
|
||||
self.args,
|
||||
self.state,
|
||||
self.control, # type: ignore[has-type]
|
||||
)
|
||||
|
||||
tokenizer = self.processing_class
|
||||
global_step = 0
|
||||
total_loss = 0.0
|
||||
total_reward = 0.0
|
||||
start_time = time.time()
|
||||
|
||||
for _epoch in range(num_train_epochs):
|
||||
if global_step >= max_steps:
|
||||
break
|
||||
|
||||
for batch in train_dataloader:
|
||||
if global_step >= max_steps:
|
||||
break
|
||||
|
||||
self.control = self.callback_handler.on_step_begin(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
|
||||
prompt_ids_batch = batch["input_ids"]
|
||||
# Full prompt text (with gold tag) for reward scoring
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_ids_batch, skip_special_tokens=False
|
||||
)
|
||||
|
||||
# Strip <|gold|>...<|/gold|> from token ids before
|
||||
# sending to the model for sampling — the gold answer
|
||||
# must only be visible to the local reward function.
|
||||
sampling_prompts = []
|
||||
for prompt_text in prompt_texts:
|
||||
clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text)
|
||||
clean_ids = tokenizer.encode(clean, add_special_tokens=False)
|
||||
sampling_prompts.append(clean_ids)
|
||||
|
||||
# 1. Sample completions (without gold answer)
|
||||
t0 = time.time()
|
||||
samples = self._sample_completions(sampling_prompts)
|
||||
t_sample = time.time() - t0
|
||||
|
||||
if not samples:
|
||||
LOG.warning("No samples generated, skipping step")
|
||||
continue
|
||||
LOG.info(
|
||||
f"Sampled {len(samples)} completions, "
|
||||
f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok"
|
||||
)
|
||||
|
||||
# 2. Decode and score
|
||||
completion_texts = [
|
||||
tokenizer.decode(s["completion_tokens"], skip_special_tokens=False)
|
||||
for s in samples
|
||||
]
|
||||
sample_prompts = []
|
||||
for prompt_text in prompt_texts:
|
||||
sample_prompts.extend([prompt_text] * args.num_samples)
|
||||
|
||||
rewards = self._compute_rewards(sample_prompts, completion_texts)
|
||||
|
||||
# 3. GRPO advantages
|
||||
advantages_list = self._compute_advantages(
|
||||
rewards, group_size=args.num_samples
|
||||
)
|
||||
|
||||
# 4. Build training data
|
||||
all_datums = []
|
||||
for i, sample in enumerate(samples):
|
||||
full_tokens = sample["tokens"]
|
||||
prompt_len = sample["prompt_len"]
|
||||
seq_len = len(full_tokens)
|
||||
|
||||
input_ids = torch.tensor([full_tokens], dtype=torch.long)
|
||||
labels = torch.full((1, seq_len), -100, dtype=torch.long)
|
||||
labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:])
|
||||
|
||||
logprobs_t = torch.zeros(1, seq_len)
|
||||
if sample["logprobs"]:
|
||||
lp = sample["logprobs"][: seq_len - prompt_len]
|
||||
logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor(
|
||||
lp
|
||||
)
|
||||
|
||||
adv_t = torch.zeros(1, seq_len)
|
||||
adv_t[0, prompt_len:] = advantages_list[i]
|
||||
|
||||
all_datums.extend(
|
||||
batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t)
|
||||
)
|
||||
|
||||
# 5. Forward backward (one datum at a time for memory) + optim
|
||||
t0 = time.time()
|
||||
tc = self._get_training_client()
|
||||
step_loss = 0.0
|
||||
for datum in all_datums:
|
||||
fb_future = tc.forward_backward(
|
||||
datums_to_tinker([datum]),
|
||||
loss_fn=args.loss_fn,
|
||||
loss_fn_config=args.loss_fn_config,
|
||||
)
|
||||
fb_result = fb_future.result(timeout=args.future_timeout)
|
||||
if hasattr(fb_result, "metrics"):
|
||||
step_loss += float(
|
||||
(fb_result.metrics or {}).get("loss:sum", 0.0)
|
||||
)
|
||||
elif isinstance(fb_result, dict):
|
||||
step_loss += float(
|
||||
fb_result.get("metrics", {}).get("loss:sum", 0.0)
|
||||
)
|
||||
optim_future = self._do_optim_step()
|
||||
if not args.pipeline:
|
||||
optim_future.result(timeout=args.future_timeout)
|
||||
t_train = time.time() - t0
|
||||
|
||||
mean_reward = sum(rewards) / len(rewards)
|
||||
accuracy = sum(1 for r in rewards if r > 0) / len(rewards)
|
||||
mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list)
|
||||
global_step += 1
|
||||
total_loss += step_loss
|
||||
total_reward += mean_reward
|
||||
self.state.global_step = global_step
|
||||
|
||||
log_interval = self.args.logging_steps or 1
|
||||
if global_step % log_interval == 0:
|
||||
elapsed = time.time() - start_time
|
||||
LOG.info(
|
||||
f"[step {global_step}/{max_steps}] "
|
||||
f"acc={accuracy:.2f} reward={mean_reward:.3f} "
|
||||
f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} "
|
||||
f"sample={t_sample:.1f}s train={t_train:.1f}s "
|
||||
f"{elapsed / global_step:.1f}s/step"
|
||||
)
|
||||
self.log(
|
||||
{
|
||||
"loss": step_loss,
|
||||
"reward": mean_reward,
|
||||
"accuracy": accuracy,
|
||||
"mean_abs_advantage": mean_adv,
|
||||
"learning_rate": self._optim_params["learning_rate"],
|
||||
}
|
||||
)
|
||||
|
||||
if args.save_steps and global_step % args.save_steps == 0:
|
||||
self._save_remote_checkpoint(global_step)
|
||||
|
||||
self.control = self.callback_handler.on_step_end(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if global_step > 0:
|
||||
self._save_remote_checkpoint(global_step, name="final")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
avg_loss = total_loss / max(global_step, 1)
|
||||
avg_reward = total_reward / max(global_step, 1)
|
||||
|
||||
LOG.info(
|
||||
f"RL training complete: {global_step} steps, {elapsed:.1f}s, "
|
||||
f"avg_reward={avg_reward:.4f}"
|
||||
)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
|
||||
return TrainOutput(
|
||||
global_step=global_step,
|
||||
training_loss=avg_loss,
|
||||
metrics={
|
||||
"train_loss": avg_loss,
|
||||
"train_reward": avg_reward,
|
||||
"train_runtime": elapsed,
|
||||
},
|
||||
)
|
||||
|
||||
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
|
||||
tc = self._get_training_client()
|
||||
args = self.hatchery_args
|
||||
assert args is not None # validated by _get_training_client
|
||||
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
|
||||
try:
|
||||
future = tc.save_state(ckpt_name)
|
||||
future.result(timeout=args.future_timeout)
|
||||
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
|
||||
except Exception:
|
||||
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
|
||||
if name == "final":
|
||||
raise
|
||||
|
||||
def save_model(self, output_dir=None, _internal_call=False):
|
||||
self._save_remote_checkpoint(
|
||||
step=self.state.global_step,
|
||||
name=output_dir or "hf-save",
|
||||
)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"HatcheryRLTrainer uses remote API; compute_loss not called locally."
|
||||
)
|
||||
327
src/axolotl/integrations/hatchery/trainer.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""Remote trainer that dispatches to Tinker or Hatchery API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from transformers.trainer_utils import TrainOutput
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .args import HatcheryConfig
|
||||
from .data import batch_to_datums_sft, datums_to_tinker
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_loss(result) -> float:
|
||||
"""Extract loss:sum from a forward_backward result.
|
||||
|
||||
Tinker's cross_entropy (and other losses) return the SUM of per-token
|
||||
losses, not the mean. This is by design — it lets users control
|
||||
normalization via the weights tensor. The trainer logs this raw sum;
|
||||
users who want per-token loss should divide by number of active tokens.
|
||||
"""
|
||||
if hasattr(result, "metrics"):
|
||||
metrics = result.metrics or {}
|
||||
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
||||
if isinstance(result, dict):
|
||||
metrics = result.get("metrics", {})
|
||||
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
||||
return 0.0
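
# Minimal sketch (reuses only values already present in this module): to report
# a mean per-token loss, divide the returned sum by the number of supervised
# tokens, mirroring what `_send_batch`/`train` below do for the `loss/tok` log:
#
#   loss_sum = _extract_loss(fb_result)
#   n_active = int((labels[:, 1:] != -100).sum().item())
#   loss_per_token = loss_sum / max(n_active, 1)
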
def _create_training_client(args: HatcheryConfig, base_model: str):
|
||||
"""Create a training client for either Tinker or Hatchery backend."""
|
||||
if args.backend == "tinker":
|
||||
import tinker
|
||||
|
||||
api_key = args.api_key or os.environ.get("TINKER_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"Tinker API key required. Set `hatchery.api_key` in config "
|
||||
"or TINKER_API_KEY env var."
|
||||
)
|
||||
os.environ["TINKER_API_KEY"] = api_key
|
||||
|
||||
service = tinker.ServiceClient(project_id=args.project_id)
|
||||
return service.create_lora_training_client(
|
||||
base_model=base_model,
|
||||
rank=args.lora_rank,
|
||||
train_mlp=args.train_mlp,
|
||||
train_attn=args.train_attn,
|
||||
train_unembed=args.train_unembed,
|
||||
)
|
||||
|
||||
from hatchery.core.client import HatcheryClient
|
||||
|
||||
base_url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420")
|
||||
token = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev")
|
||||
|
||||
client = HatcheryClient(base_url=base_url, token=token, timeout=args.future_timeout)
|
||||
return client.create_lora_training_client(
|
||||
base_model=base_model,
|
||||
rank=args.lora_rank,
|
||||
train_attn=args.train_attn,
|
||||
train_mlp=args.train_mlp,
|
||||
train_unembed=args.train_unembed,
|
||||
)
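
# Usage sketch (hypothetical values; assumes a populated HatcheryConfig `args`):
# both backends return a client with the same surface used by the trainer below
# (forward_backward / optim_step / save_state), so nothing downstream branches
# on `args.backend` again.
#
#   tc = _create_training_client(args, base_model="meta-llama/Llama-3.1-8B")
#   fut = tc.forward_backward(datums, loss_fn=args.loss_fn,
#                             loss_fn_config=args.loss_fn_config)
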
class HatcheryTrainer(AxolotlTrainer):
|
||||
"""Trainer that sends preprocessed batches to a remote training API.
|
||||
|
||||
Replaces local forward/backward with remote API calls to Tinker or
|
||||
Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization,
|
||||
chat templates, packing, etc.) but offloads compute to remote GPUs.
|
||||
"""
|
||||
|
||||
hatchery_args: Optional[HatcheryConfig]
|
||||
_base_model_name: Optional[str]
|
||||
_training_client: Any
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hatchery_args = None
|
||||
self._base_model_name = None
|
||||
self._training_client = None
|
||||
|
||||
def _get_training_client(self):
|
||||
"""Lazily create the remote training session."""
|
||||
if self._training_client is not None:
|
||||
return self._training_client
|
||||
|
||||
args = self.hatchery_args
|
||||
if args is None:
|
||||
raise RuntimeError(
|
||||
"HatcheryTrainer.hatchery_args not set. "
|
||||
"Ensure the HatcheryPlugin is registered."
|
||||
)
|
||||
|
||||
base_model = self._base_model_name
|
||||
if not base_model:
|
||||
raise RuntimeError("HatcheryTrainer._base_model_name not set.")
|
||||
|
||||
self._training_client = _create_training_client(args, base_model)
|
||||
|
||||
LOG.info(
|
||||
f"Remote training session created: backend={args.backend}, "
|
||||
f"model={base_model}, rank={args.lora_rank}"
|
||||
)
|
||||
return self._training_client
|
||||
|
||||
def _send_batch(self, batch: dict[str, torch.Tensor]):
|
||||
"""Convert batch to datums and send forward_backward to remote.
|
||||
|
||||
Returns (future, n_active_tokens) where n_active_tokens counts
|
||||
the completion tokens in this batch (for loss normalization).
|
||||
"""
|
||||
input_ids = batch["input_ids"]
|
||||
labels = batch["labels"]
|
||||
attention_mask = batch.get("attention_mask")
|
||||
|
||||
n_active = int((labels[:, 1:] != -100).sum().item())
|
||||
datums = batch_to_datums_sft(input_ids, labels, attention_mask)
|
||||
|
||||
tc = self._get_training_client()
|
||||
args = self.hatchery_args
|
||||
assert args is not None # validated by _get_training_client
|
||||
send_datums = datums_to_tinker(datums)
|
||||
|
||||
future = tc.forward_backward(
|
||||
send_datums,
|
||||
loss_fn=args.loss_fn,
|
||||
loss_fn_config=args.loss_fn_config,
|
||||
)
|
||||
return future, n_active
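
# Illustrative count: for a row whose labels are [-100, -100, 5, 9, 2], the
# shifted slice labels[:, 1:] is [-100, 5, 9, 2], so n_active = 3 — only
# completion tokens feed the remote loss:sum, and the same count is the
# denominator for the per-token loss logged in train() below.
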
def _do_optim_step(self):
|
||||
"""Send optimizer step to remote using axolotl's training params."""
|
||||
import tinker.types as tt
|
||||
|
||||
tc = self._get_training_client()
|
||||
return tc.optim_step(tt.AdamParams(**self._optim_params))
|
||||
|
||||
def train(
|
||||
self,
|
||||
resume_from_checkpoint: Optional[str] = None,
|
||||
trial: Any = None,
|
||||
ignore_keys_for_eval: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
) -> TrainOutput:
|
||||
"""Main training loop — sends batches to remote API."""
|
||||
args = self.hatchery_args
|
||||
if args is None:
|
||||
raise RuntimeError("hatchery_args not configured")
|
||||
|
||||
train_dataloader = self.get_train_dataloader()
|
||||
num_batches = len(train_dataloader)
|
||||
|
||||
grad_accum = self.args.gradient_accumulation_steps
|
||||
num_train_epochs = int(self.args.num_train_epochs)
|
||||
steps_per_epoch = max(num_batches // grad_accum, 1)
|
||||
max_steps = (
|
||||
self.args.max_steps
|
||||
if self.args.max_steps > 0
|
||||
else steps_per_epoch * num_train_epochs
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
f"Remote training: {num_batches} batches/epoch, "
|
||||
f"{grad_accum} grad_accum, {max_steps} max steps, "
|
||||
f"{num_train_epochs} epochs"
|
||||
)
|
||||
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
self.state.is_local_process_zero = True
|
||||
self.state.is_world_process_zero = True
|
||||
|
||||
self.control = self.callback_handler.on_train_begin(
|
||||
self.args,
|
||||
self.state,
|
||||
self.control, # type: ignore[has-type]
|
||||
)
|
||||
|
||||
global_step = 0
|
||||
total_loss = 0.0
|
||||
start_time = time.time()
|
||||
|
||||
for _epoch in range(num_train_epochs):
|
||||
if global_step >= max_steps:
|
||||
break
|
||||
|
||||
self.control = self.callback_handler.on_epoch_begin(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
|
||||
pending_fb_futures = []
|
||||
accum_count = 0
|
||||
|
||||
for batch_idx, batch in enumerate(train_dataloader):
|
||||
if global_step >= max_steps:
|
||||
break
|
||||
|
||||
self.control = self.callback_handler.on_step_begin(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
|
||||
fb_future, n_active = self._send_batch(batch)
|
||||
pending_fb_futures.append((fb_future, n_active))
|
||||
accum_count += 1
|
||||
|
||||
if accum_count >= grad_accum:
|
||||
step_loss_sum = 0.0
|
||||
step_active = 0
|
||||
for fut, n_act in pending_fb_futures:
|
||||
result = fut.result(timeout=args.future_timeout)
|
||||
step_loss_sum += _extract_loss(result)
|
||||
step_active += n_act
|
||||
|
||||
optim_future = self._do_optim_step()
|
||||
if not args.pipeline:
|
||||
optim_future.result(timeout=args.future_timeout)
|
||||
|
||||
step_loss = (
|
||||
step_loss_sum / step_active
|
||||
if step_active > 0
|
||||
else step_loss_sum
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
total_loss += step_loss
|
||||
self.state.global_step = global_step
|
||||
self.state.epoch = _epoch + (batch_idx + 1) / num_batches
|
||||
|
||||
log_interval = self.args.logging_steps or 1
|
||||
if global_step % log_interval == 0:
|
||||
elapsed = time.time() - start_time
|
||||
avg_loss = total_loss / global_step
|
||||
LOG.info(
|
||||
f"[step {global_step}/{max_steps}] "
|
||||
f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} "
|
||||
f"active={step_active} "
|
||||
f"{elapsed / global_step:.2f}s/step"
|
||||
)
|
||||
self.log(
|
||||
{
|
||||
"loss": step_loss,
|
||||
"learning_rate": self._optim_params["learning_rate"],
|
||||
"epoch": self.state.epoch,
|
||||
}
|
||||
)
|
||||
|
||||
if args.save_steps and global_step % args.save_steps == 0:
|
||||
self._save_remote_checkpoint(global_step)
|
||||
|
||||
self.control = self.callback_handler.on_step_end(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
|
||||
pending_fb_futures = []
|
||||
accum_count = 0
|
||||
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if global_step > 0:
|
||||
self._save_remote_checkpoint(global_step, name="final")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
avg_loss = total_loss / max(global_step, 1)
|
||||
|
||||
LOG.info(
|
||||
f"Training complete: {global_step} steps, {elapsed:.1f}s total, "
|
||||
f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}"
|
||||
)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(
|
||||
self.args, self.state, self.control
|
||||
)
|
||||
|
||||
return TrainOutput(
|
||||
global_step=global_step,
|
||||
training_loss=avg_loss,
|
||||
metrics={"train_loss": avg_loss, "train_runtime": elapsed},
|
||||
)
|
||||
|
||||
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
|
||||
"""Save a checkpoint on the remote service."""
|
||||
tc = self._get_training_client()
|
||||
args = self.hatchery_args
|
||||
assert args is not None # validated by _get_training_client
|
||||
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
|
||||
try:
|
||||
future = tc.save_state(ckpt_name)
|
||||
future.result(timeout=args.future_timeout)
|
||||
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
|
||||
except Exception:
|
||||
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
|
||||
if name == "final":
|
||||
raise
|
||||
|
||||
def save_model(self, output_dir=None, _internal_call=False):
|
||||
"""Delegate to remote checkpoint save so HF callbacks create checkpoints."""
|
||||
self._save_remote_checkpoint(
|
||||
step=self.state.global_step,
|
||||
name=output_dir or "hf-save",
|
||||
)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"HatcheryTrainer uses remote API; compute_loss should not be called."
|
||||
)
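
# Hedged config sketch (keys taken from HatcheryConfig usages above; the exact
# nesting under a top-level `hatchery:` block is an assumption based on the
# "Set `hatchery.api_key` in config" error message, and every value shown is
# hypothetical):
#
#   hatchery:
#     backend: tinker            # or "hatchery"
#     api_key: ${TINKER_API_KEY}
#     lora_rank: 16
#     train_attn: true
#     train_mlp: true
#     train_unembed: false
#     loss_fn: cross_entropy
#     future_timeout: 600
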
@@ -60,49 +60,14 @@ def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
|
||||
"""Convert peft LoRA weights to scattermoe layout.
|
||||
|
||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
||||
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
|
||||
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
|
||||
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
|
||||
are swapped relative to scattermoe's convention.
|
||||
|
||||
peft gives:
|
||||
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
|
||||
|
||||
scattermoe needs:
|
||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
||||
|
||||
This function swaps A<->B and converts B from rank-major to expert-major.
|
||||
Uses vectorized tensor operations (no Python loop over experts).
|
||||
|
||||
Works for **both** gate_up_proj and down_proj since the transposition
|
||||
issue is the same for any parameter.
|
||||
peft >=0.19.1 assigns in/out features for 3D params such that
|
||||
A and B already align with scattermoe's convention (no A<->B swap).
|
||||
Only B needs rank-major → expert-major layout conversion.
|
||||
"""
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
|
||||
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
|
||||
|
||||
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
|
||||
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
|
||||
smoe_A = (
|
||||
peft_B_em.reshape(dim2, num_experts, rank)
|
||||
.permute(1, 2, 0)
|
||||
.contiguous()
|
||||
.reshape(rank * num_experts, dim2)
|
||||
)
|
||||
|
||||
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
|
||||
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
|
||||
smoe_B = (
|
||||
peft_A.reshape(num_experts, rank, dim1)
|
||||
.permute(2, 0, 1)
|
||||
.contiguous()
|
||||
.reshape(dim1, num_experts * rank)
|
||||
)
|
||||
|
||||
smoe_A = peft_A
|
||||
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
return smoe_A, smoe_B
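
# Minimal reference for the rank-major -> expert-major reorder, assuming peft
# lays out lora_B's columns as rank-major blocks (column index = r * E + e);
# the real conversion lives in peft_lora_B_to_scattermoe above — this is only
# a sketch for intuition:
#
#   # peft_B: [dim2, rank * num_experts] with rank-major columns
#   B_em = (
#       peft_B.reshape(dim2, rank, num_experts)
#       .permute(0, 2, 1)                      # -> [dim2, num_experts, rank]
#       .reshape(dim2, num_experts * rank)     # expert-major columns
#   )
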
@@ -110,11 +110,36 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
item["agent_ref"] = full_item["agent_ref"]
|
||||
dataset_items.append(item)
|
||||
|
||||
# Expand by num_generations (agent produces one rollout per call)
|
||||
expanded_items = []
|
||||
for item in dataset_items:
|
||||
for _ in range(self._num_generations):
|
||||
expanded_items.append(item)
|
||||
# NOTE: do NOT re-expand by num_generations here.
|
||||
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
|
||||
# yields ``num_generations`` consecutive copies of each unique
|
||||
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
|
||||
# num_generations)`` items — one entry per rollout. Expanding
|
||||
# again here would fire ``num_generations^2`` rollouts per
|
||||
# prompt per rank and make every step dogpile on a handful of
|
||||
# tasks.
|
||||
expanded_items = dataset_items
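
# Quick arithmetic check (illustrative): with 4 unique prompts on this rank and
# num_generations=8, RepeatSampler already delivers 4 * 8 = 32 items here;
# re-expanding would fire 8 * 32 = 256 agent /run calls per step against the
# same 4 prompts — the num_generations^2 dogpile the note above describes.
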
# Diagnostic: log what this rank is about to fire.
|
||||
try:
|
||||
import collections
|
||||
|
||||
iid_counts: collections.Counter[str | None] = collections.Counter()
|
||||
for it in dataset_items:
|
||||
iid_counts[
|
||||
(it.get("responses_create_params", {}).get("metadata") or {}).get(
|
||||
"instance_id"
|
||||
)
|
||||
] += 1
|
||||
LOG.info(
|
||||
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
|
||||
trainer.accelerator.process_index,
|
||||
len(dataset_items),
|
||||
len(iid_counts),
|
||||
list(iid_counts.most_common(5)),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Call NeMo Gym agents
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -140,6 +165,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
logprobs_list = []
|
||||
rewards_list = []
|
||||
|
||||
num_turns_list: list[int] = []
|
||||
for resp in responses:
|
||||
parsed = _parse_agent_response(resp, eos_token_id)
|
||||
prompt_ids_list.append(parsed["prompt_ids"])
|
||||
@@ -147,6 +173,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
env_mask_list.append(parsed["env_mask"])
|
||||
logprobs_list.append(parsed["logprobs"])
|
||||
rewards_list.append(parsed["reward"])
|
||||
num_turns_list.append(parsed.get("num_turns", 0))
|
||||
|
||||
# Pad to tensors
|
||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
||||
@@ -179,22 +206,48 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
|
||||
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
|
||||
|
||||
# Inject rewards into inputs so _compute_deferred_scores can use them
|
||||
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
|
||||
# Our passthrough reward_fn reads "env_reward" from kwargs.
|
||||
# Inject per-rollout reward + num_turns into each input. Since
|
||||
# ``RepeatSampler`` already yields ``num_generations`` copies of
|
||||
# each prompt, ``inputs`` has ONE entry per rollout (matching
|
||||
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
|
||||
# GRPO advantage normalization is the trainer's job downstream.
|
||||
assert len(inputs) == len(rewards_list), (
|
||||
f"rewards/inputs length mismatch: "
|
||||
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
|
||||
)
|
||||
for i, inp in enumerate(inputs):
|
||||
# Each input gets rewards for its num_generations rollouts
|
||||
start = i * self._num_generations
|
||||
end = start + self._num_generations
|
||||
inp["env_reward"] = rewards_list[start:end]
|
||||
inp["env_reward"] = rewards_list[i]
|
||||
inp["num_turns"] = num_turns_list[i]
|
||||
|
||||
# Expand inputs to match expanded rollouts (num_generations copies)
|
||||
expanded_inputs = []
|
||||
for inp in inputs:
|
||||
for g in range(self._num_generations):
|
||||
expanded_inp = dict(inp)
|
||||
expanded_inp["env_reward"] = inp["env_reward"][g]
|
||||
expanded_inputs.append(expanded_inp)
|
||||
# One expanded_input per rollout (already correct count because
|
||||
# inputs has num_generations copies baked in by the sampler).
|
||||
expanded_inputs = [dict(inp) for inp in inputs]
|
||||
|
||||
# Log rollout-level stats to wandb from rank 0. These are the
|
||||
# true agent-side metrics (not the tokenized TRL view) — so
|
||||
# num_turns reflects how many /run iterations each rollout
|
||||
# actually took before finishing or hitting max_turns.
|
||||
if is_main and num_turns_list:
|
||||
try:
|
||||
import wandb
|
||||
|
||||
if wandb.run is not None:
|
||||
import statistics as _stats
|
||||
|
||||
nonzero = sum(1 for r in rewards_list if r > 0)
|
||||
log_payload = {
|
||||
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
|
||||
"rollout/num_turns/min": float(min(num_turns_list)),
|
||||
"rollout/num_turns/max": float(max(num_turns_list)),
|
||||
"rollout/reward/mean": float(_stats.mean(rewards_list)),
|
||||
"rollout/reward/nonzero_frac": (
|
||||
nonzero / len(rewards_list) if rewards_list else 0.0
|
||||
),
|
||||
"rollout/n_samples": float(len(rewards_list)),
|
||||
}
|
||||
wandb.log(log_payload, commit=False)
|
||||
except Exception as exc: # never let metric logging break training
|
||||
LOG.warning("rollout wandb log failed: %s", exc)
|
||||
|
||||
# Decode completions for reward functions
|
||||
completions = trainer.processing_class.batch_decode(
|
||||
|
||||
@@ -19,6 +19,7 @@ Supports two modes:
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
@@ -30,6 +31,107 @@ if TYPE_CHECKING:
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
# ---- vLLM weight-sync transport probe ------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class VLLMWeightSyncCapabilities:
|
||||
"""What weight-sync routes a vLLM server actually exposes.
|
||||
|
||||
Discovered once at ``pre_model_load`` time by fetching the server's
|
||||
``/openapi.json``. Drives the transport-selection table below.
|
||||
"""
|
||||
|
||||
nccl: bool = False # /init_communicator/ + /update_named_param/
|
||||
lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native)
|
||||
lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension)
|
||||
http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension)
|
||||
probed: bool = False
|
||||
probe_error: str | None = None
|
||||
routes: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def any_full_param_sync(self) -> bool:
|
||||
"""True if at least one transport can push full-model weights."""
|
||||
return self.nccl or self.http_full
|
||||
|
||||
@property
|
||||
def any_lora_sync(self) -> bool:
|
||||
"""True if at least one transport can push LoRA adapters."""
|
||||
return self.lora_filesystem or self.lora_axolotl or self.nccl
|
||||
|
||||
|
||||
def probe_vllm_weight_sync(
|
||||
base_url: str, timeout: float = 5.0
|
||||
) -> VLLMWeightSyncCapabilities:
|
||||
"""Detect which weight-sync routes the configured vLLM server exposes.
|
||||
|
||||
Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
|
||||
we care about is mounted as a POST route there. Falls back to all-False
|
||||
on any error so the caller can still decide what to do (typically: raise
|
||||
a clear error rather than silently no-op).
|
||||
"""
|
||||
import requests
|
||||
|
||||
caps = VLLMWeightSyncCapabilities()
|
||||
try:
|
||||
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
|
||||
r.raise_for_status()
|
||||
spec = r.json()
|
||||
routes = sorted((spec.get("paths") or {}).keys())
|
||||
caps.routes = routes
|
||||
caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
|
||||
caps.lora_filesystem = "/v1/load_lora_adapter" in routes
|
||||
caps.lora_axolotl = "/set_lora_adapter/" in routes
|
||||
caps.http_full = "/http_update_weights/" in routes
|
||||
caps.probed = True
|
||||
except Exception as exc:
|
||||
caps.probe_error = f"{type(exc).__name__}: {exc}"
|
||||
LOG.warning(
|
||||
"NeMo Gym: failed to probe vLLM /openapi.json at %s — %s. "
|
||||
"Will fall back to LoRA-only behavior.",
|
||||
base_url,
|
||||
caps.probe_error,
|
||||
)
|
||||
return caps
|
||||
|
||||
|
||||
def select_weight_sync_transport(
|
||||
caps: VLLMWeightSyncCapabilities,
|
||||
*,
|
||||
has_lora: bool,
|
||||
vllm_lora_sync_pref: bool,
|
||||
) -> str:
|
||||
"""Pick the right transport for a (server caps, model type) combo.
|
||||
|
||||
Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
|
||||
``"none"``. The caller decides what to do with ``"none"`` (typically:
|
||||
raise an error explaining the misconfiguration).
|
||||
|
||||
Selection table:
|
||||
LoRA model + lora endpoint + lora-sync pref → lora_filesystem
|
||||
LoRA model + lora endpoint → lora_filesystem
|
||||
LoRA model + nccl endpoint → nccl (broadcast merged adapter)
|
||||
Full model + nccl endpoint → nccl
|
||||
Full model + http endpoint → http_full
|
||||
anything else → none
|
||||
"""
|
||||
if has_lora:
|
||||
if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
|
||||
return "lora_filesystem"
|
||||
if caps.lora_filesystem or caps.lora_axolotl:
|
||||
return "lora_filesystem"
|
||||
if caps.nccl:
|
||||
return "nccl"
|
||||
return "none"
|
||||
# Full-parameter model
|
||||
if caps.nccl:
|
||||
return "nccl"
|
||||
if caps.http_full:
|
||||
return "http_full"
|
||||
return "none"
class NemoGymPlugin(BasePlugin):
|
||||
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
|
||||
|
||||
@@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin):
|
||||
self._reward_fn = None
|
||||
self._dataset_lookup = None
|
||||
self._agent_servers = {}
|
||||
self._vllm_caps: VLLMWeightSyncCapabilities | None = None
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.nemo_gym.NemoGymArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply monkeypatches before trainer creation."""
|
||||
"""Probe vLLM weight-sync routes and conditionally bypass NCCL init.
|
||||
|
||||
Replaces the previous unconditional ``init_communicator`` monkey-patch
|
||||
with a probe of the configured vLLM server's ``/openapi.json``. We only
|
||||
bypass NCCL init when the server we're talking to actually lacks the
|
||||
``/init_communicator/`` route (i.e. stock ``vllm serve``); against
|
||||
TRL/axolotl serve modules that DO expose NCCL routes, we leave the
|
||||
standard TRL flow alone so full-finetune training can sync weights.
|
||||
"""
|
||||
if not cfg.nemo_gym_enabled:
|
||||
return
|
||||
|
||||
# Always skip NCCL communicator init in NeMo Gym mode.
|
||||
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
|
||||
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
|
||||
trl_cfg = getattr(cfg, "trl", None)
|
||||
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
|
||||
if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
|
||||
return
|
||||
|
||||
host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
|
||||
port = getattr(trl_cfg, "vllm_server_port", None) or 8000
|
||||
base_url = f"http://{host}:{port}"
|
||||
self._vllm_caps = probe_vllm_weight_sync(base_url)
|
||||
|
||||
if self._vllm_caps.probed:
|
||||
LOG.info(
|
||||
"NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
|
||||
"lora_axolotl=%s http_full=%s",
|
||||
base_url,
|
||||
self._vllm_caps.nccl,
|
||||
self._vllm_caps.lora_filesystem,
|
||||
self._vllm_caps.lora_axolotl,
|
||||
self._vllm_caps.http_full,
|
||||
)
|
||||
|
||||
# Only bypass NCCL init when the server doesn't speak it. If NCCL is
|
||||
# available we leave VLLMClient.init_communicator alone so the
|
||||
# standard TRL sync flow can run for full-parameter training.
|
||||
if not self._vllm_caps.nccl:
|
||||
self._patch_skip_nccl_init()
|
||||
|
||||
def _patch_skip_nccl_init(self):
|
||||
"""Monkeypatch VLLMClient.init_communicator to no-op.
|
||||
|
||||
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
|
||||
serve script). The NCCL communicator is not needed and fails with both
|
||||
vLLM V1 engine and standard OpenAI server mode.
|
||||
Only called when the configured vLLM server doesn't expose
|
||||
``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
|
||||
TRL's standard ``init_communicator`` would 404 inside trainer
|
||||
construction; we no-op it so the LoRA filesystem path can install
|
||||
its own sync in ``post_trainer_create``.
|
||||
"""
|
||||
try:
|
||||
from trl.generation.vllm_client import VLLMClient
|
||||
|
||||
VLLMClient._original_init_communicator = VLLMClient.init_communicator
|
||||
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
|
||||
"Skipping NCCL init_communicator (LoRA sync mode)"
|
||||
"Skipping NCCL init_communicator (server has no /init_communicator/)"
|
||||
)
|
||||
LOG.info(
|
||||
"Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
|
||||
)
|
||||
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
|
||||
except Exception as exc:
|
||||
LOG.warning(f"Failed to patch VLLMClient: {exc}")
|
||||
|
||||
@@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin):
|
||||
verify_timeout = cfg.nemo_gym_verify_timeout or 30
|
||||
multi_turn = cfg.nemo_gym_multi_turn or False
|
||||
|
||||
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
|
||||
# - Install LoRA sync (when vllm_lora_sync=True)
|
||||
# - Or no-op sync_weights (when using standard vLLM server)
|
||||
# Pick a weight-sync transport based on what the configured vLLM
|
||||
# server actually exposes (see ``pre_model_load`` probe) and what
|
||||
# kind of model we're training. The selection table is documented
|
||||
# in ``select_weight_sync_transport``.
|
||||
trl_cfg = getattr(cfg, "trl", None)
|
||||
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
|
||||
vllm_gen = trainer.vllm_generation
|
||||
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
|
||||
adapter = getattr(cfg, "adapter", None)
|
||||
has_lora = adapter in ("lora", "qlora")
|
||||
vllm_lora_sync_pref = bool(
|
||||
trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
|
||||
)
|
||||
caps = self._vllm_caps or VLLMWeightSyncCapabilities()
|
||||
transport = select_weight_sync_transport(
|
||||
caps,
|
||||
has_lora=has_lora,
|
||||
vllm_lora_sync_pref=vllm_lora_sync_pref,
|
||||
)
|
||||
|
||||
if transport == "lora_filesystem":
|
||||
self._setup_lora_sync(trainer)
|
||||
# Verify the vLLM server supports runtime LoRA loading
|
||||
self._check_lora_endpoint(vllm_gen)
|
||||
else:
|
||||
# No NCCL, no LoRA sync — skip all weight sync paths
|
||||
vllm_gen.sync_weights = lambda: LOG.debug(
|
||||
"Weight sync skipped (NeMo Gym mode)"
|
||||
LOG.info("NeMo Gym weight sync: LoRA filesystem")
|
||||
elif transport == "nccl":
|
||||
# Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
|
||||
# alone (pre_model_load only patched it when the probe found no
|
||||
# NCCL route) so the trainer's normal weight-sync flow runs.
|
||||
LOG.info(
|
||||
"NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
|
||||
)
|
||||
type(vllm_gen).sync_weights = lambda self: LOG.debug(
|
||||
"Weight sync skipped (NeMo Gym mode)"
|
||||
elif transport == "http_full":
|
||||
# Full-parameter HTTP sync — implementation lands in step 3.
|
||||
# For now, fail loudly so users know the path is detected but
|
||||
# not yet wired up, instead of silently no-oping like before.
|
||||
raise NotImplementedError(
|
||||
"NeMo Gym + full fine-tune + HTTP weight sync is detected "
|
||||
"but the client-side sync helper is not yet implemented "
|
||||
"(planned). Use `adapter: lora|qlora` for now, or use a "
|
||||
"vLLM serve module that exposes /init_communicator/ for "
|
||||
"NCCL sync."
|
||||
)
|
||||
# Also patch the async trainer's internal sync method
|
||||
if hasattr(trainer, "_maybe_sync_vllm_weights"):
|
||||
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
|
||||
"Async weight sync skipped (NeMo Gym mode)"
|
||||
else: # transport == "none"
|
||||
# No viable sync path. Build a precise error so the user knows
|
||||
# exactly what's missing and how to fix it.
|
||||
if not caps.probed:
|
||||
msg = (
|
||||
"could not probe the vLLM server's "
|
||||
f"/openapi.json: {caps.probe_error}. "
|
||||
"Verify that vLLM is reachable at "
|
||||
f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
|
||||
f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
|
||||
)
|
||||
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
|
||||
elif has_lora:
|
||||
msg = (
|
||||
"the vLLM server has neither NCCL routes "
|
||||
"(/init_communicator/) nor a LoRA-loading route "
|
||||
"(/v1/load_lora_adapter or /set_lora_adapter/). "
|
||||
"Restart vLLM with `--enable-lora --max-lora-rank N "
|
||||
"VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
|
||||
"server, or use `axolotl vllm-serve` for the "
|
||||
"NCCL-capable serve module."
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
"the vLLM server exposes no full-parameter sync route "
|
||||
"(/init_communicator/ for NCCL or /http_update_weights/ "
|
||||
"for HTTP). Use `axolotl vllm-serve` (which has both) "
|
||||
"or set `adapter: lora|qlora`."
|
||||
)
|
||||
raise ValueError(
|
||||
f"NeMo Gym: no usable weight-sync transport — {msg} Without "
|
||||
"weight sync the trainer's gradient updates never reach the "
|
||||
"rollout policy (functionally a no-op trainer)."
|
||||
)
|
||||
|
||||
if multi_turn:
|
||||
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)
|
||||
|
||||
@@ -130,21 +130,41 @@ def start_servers(
|
||||
)
|
||||
|
||||
|
||||
def get_server_configs(head_port: int = 11000) -> dict:
|
||||
def get_server_configs(head_port: int = 11000, timeout: float = 30.0) -> dict:
|
||||
"""Fetch the global config from the NeMo Gym head server.
|
||||
|
||||
Retries up to 3 times with a linearly increasing backoff. The default per-attempt
|
||||
timeout is 30s (raised from the original 5s) because head servers can
|
||||
be slow to respond when they're concurrently serving rollouts from a
|
||||
prior training run. A 5s timeout was empirically too tight to survive
|
||||
a kill-and-relaunch cycle.
|
||||
|
||||
Returns:
|
||||
Dict mapping server_name -> server config.
|
||||
"""
|
||||
response = requests.get(
|
||||
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
|
||||
url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"
|
||||
last_exc: Exception | None = None
|
||||
for attempt in (1, 2, 3):
|
||||
try:
|
||||
response = requests.get(url, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
result = yaml.safe_load(response.text)
|
||||
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
||||
if isinstance(result, str):
|
||||
result = yaml.safe_load(result)
|
||||
return result
|
||||
except (requests.exceptions.RequestException, OSError) as exc:
|
||||
last_exc = exc
|
||||
LOG.warning(
|
||||
"NeMo Gym head probe attempt %d/3 failed: %s. Retrying...",
|
||||
attempt,
|
||||
type(exc).__name__,
|
||||
)
|
||||
if attempt < 3:
|
||||
time.sleep(2.0 * attempt)
|
||||
raise RuntimeError(
|
||||
f"NeMo Gym head server at {url} did not respond after 3 attempts: {last_exc}"
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = yaml.safe_load(response.text)
|
||||
# NeMo Gym head server double-encodes: YAML string inside a YAML string
|
||||
if isinstance(result, str):
|
||||
result = yaml.safe_load(result)
|
||||
return result
|
||||
|
||||
|
||||
def get_agent_servers(
|
||||
|
||||
@@ -53,6 +53,7 @@ def _rms_norm_rope_forward_kernel(
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
@@ -60,28 +61,35 @@ def _rms_norm_rope_forward_kernel(
|
||||
):
|
||||
"""
|
||||
Fused forward:
|
||||
x_norm = x / rms(x) [* weight] (RMSNorm)
|
||||
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
|
||||
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
|
||||
y[..., :n_rot] = rope(x_norm[..., :n_rot])
|
||||
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
|
||||
|
||||
rotate_half swaps first/second halves and negates the first:
|
||||
rotate_half([a, b]) = [-b, a]
|
||||
rotate_half swaps first/second halves and negates the first, restricted
|
||||
to the rotary span [0, n_rot):
|
||||
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
|
||||
|
||||
For the partial-rotary pass-through region we load cos with default 1.0
|
||||
and sin with default 0.0 outside [0, n_rot), so the same formula
|
||||
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
|
||||
|
||||
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
|
||||
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
|
||||
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
|
||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
|
||||
cs_row_idx = row_idx // n_heads
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
rot_mask_col = col_offsets < n_rot
|
||||
half_rot = n_rot // 2
|
||||
|
||||
# Load input row
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
X_dtype = X_row.dtype
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
|
||||
# RMSNorm: compute 1/rms
|
||||
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
@@ -94,33 +102,38 @@ def _rms_norm_rope_forward_kernel(
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_norm = X_norm * W_row
|
||||
|
||||
# RoPE: load cos/sin (broadcast across heads)
|
||||
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
|
||||
# cos=1, sin=0 so the formula leaves X_norm untouched.
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
sin_row = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
|
||||
# for col >= half_dim, take X_norm[col - half_dim]
|
||||
# rotate_half within [0, n_rot):
|
||||
# for col < half_rot: take -X_norm[col + half_rot]
|
||||
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
|
||||
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||
X_rot = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
|
||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
|
||||
).to(tl.float32)
|
||||
# Re-normalize the rotated values
|
||||
X_rot_norm = X_rot * rstd
|
||||
if HAS_WEIGHT:
|
||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
|
||||
tl.float32
|
||||
)
|
||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
|
||||
X_rot_norm = X_rot_norm * W_rot
|
||||
|
||||
# Negate the first half (rotate_half negates x2, which becomes the first half)
|
||||
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
|
||||
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
|
||||
X_rot_norm = X_rot_norm * sign
|
||||
|
||||
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||
@@ -153,13 +166,21 @@ def _rms_norm_rope_backward_kernel(
|
||||
dW_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Backward for Y = RoPE(RMSNorm(X, W))
|
||||
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
|
||||
(`n_rot <= n_cols`).
|
||||
|
||||
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
|
||||
output is just the normalized row, so dN[col] = dY[col] (achieved by
|
||||
loading cos with default 1.0 and forcing the rotate-half contribution
|
||||
to zero outside the rotary span).
|
||||
|
||||
cos/sin indexed by row_idx // n_heads for per-head broadcast.
|
||||
"""
|
||||
row_block_id = tl.program_id(0).to(tl.int64)
|
||||
@@ -167,7 +188,8 @@ def _rms_norm_rope_backward_kernel(
|
||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
rot_mask_col = col_offsets < n_rot
|
||||
half_rot = n_rot // 2
|
||||
|
||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
@@ -186,33 +208,37 @@ def _rms_norm_rope_backward_kernel(
|
||||
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# dN = dY * cos + rotate_half^T(dY * sin)
|
||||
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
|
||||
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
|
||||
#
|
||||
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
|
||||
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
|
||||
# This is equivalent to rotating (dY * sin) because the rotation
|
||||
# just permutes which elements are multiplied.
|
||||
# For col >= n_rot the formula must collapse to dN = dY (since the
|
||||
# forward is just a pass-through). cos defaults to 1.0 above; the
|
||||
# rotate-half contribution is masked to zero below.
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||
dY_rot = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
mask=rot_load_mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
sin_rot = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
mask=rot_load_mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
|
||||
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
|
||||
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
|
||||
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
|
||||
rotate_term = dY_rot * sin_rot * adj_sign
|
||||
# Zero out rotate-half contribution outside the rotary span.
|
||||
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
|
||||
dN = dY_row * cos_row + rotate_term
|
||||
|
||||
# Pre-weight normalized: n = rstd * x
|
||||
n = X_row * rstd
|
||||
@@ -241,15 +267,17 @@ def _rms_norm_rope_backward_kernel(
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
|
||||
"""
|
||||
Args:
|
||||
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
|
||||
W: (head_dim,) or None — RMSNorm weight
|
||||
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||
eps: float
|
||||
n_heads: int — number of attention heads (for cos/sin indexing)
|
||||
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
|
||||
partial rotary). Must be even and ``<= head_dim``.
|
||||
Returns:
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
|
||||
"""
|
||||
@@ -273,6 +301,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT=has_weight,
|
||||
@@ -282,7 +311,9 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
|
||||
def rms_norm_rope_backward(
|
||||
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
|
||||
):
|
||||
n_rows, n_cols = dY.shape
|
||||
has_weight = W is not None
|
||||
|
||||
@@ -315,6 +346,7 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
|
||||
_dW.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT=has_weight,
|
||||
@@ -329,13 +361,14 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
|
||||
class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, W, cos, sin, eps, n_heads):
|
||||
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
|
||||
"""
|
||||
X: (B*S*H, head_dim)
|
||||
W: (head_dim,) or None
|
||||
cos: (B*S, head_dim) — broadcast across heads
|
||||
sin: (B*S, head_dim) — broadcast across heads
|
||||
X: (B*S*H, head_dim)
|
||||
W: (head_dim,) or None
|
||||
cos: (B*S, n_rot) — broadcast across heads
|
||||
sin: (B*S, n_rot) — broadcast across heads
|
||||
n_heads: int
|
||||
n_rot: int — rotary dim (<= head_dim)
|
||||
"""
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
|
||||
X,
|
||||
@@ -344,11 +377,13 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
sin,
|
||||
eps,
|
||||
n_heads,
|
||||
n_rot,
|
||||
)
|
||||
ctx.eps = eps
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.n_heads = n_heads
|
||||
ctx.n_rot = n_rot
|
||||
ctx.has_weight = W is not None
|
||||
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
|
||||
return Y
|
||||
@@ -365,21 +400,26 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
sin,
|
||||
RSTD,
|
||||
ctx.n_heads,
|
||||
ctx.n_rot,
|
||||
ctx.BLOCK_SIZE,
|
||||
ctx.num_warps,
|
||||
)
|
||||
return dX, dW, None, None, None, None
|
||||
return dX, dW, None, None, None, None, None
|
||||
|
||||
|
||||
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||
"""
|
||||
Apply fused RMSNorm + RoPE.
|
||||
Apply fused RMSNorm + (partial) RoPE.
|
||||
|
||||
Args:
|
||||
x: (batch, seq_len, num_heads, head_dim) — after projection + view
|
||||
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
|
||||
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
|
||||
must be even and ``<= head_dim``. When ``n_rot < head_dim``
|
||||
the trailing ``head_dim - n_rot`` columns are RMSNorm-only
|
||||
(partial-rotary pass-through), matching stock Gemma 4 with
|
||||
``partial_rotary_factor < 1.0``.
|
||||
sin: (batch, seq_len, n_rot) — same shape as ``cos``
|
||||
eps: float — RMSNorm epsilon
|
||||
|
||||
Returns:
|
||||
@@ -387,14 +427,38 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||
"""
|
||||
shape = x.shape # (B, S, H, D)
|
||||
B, S, H, D = shape
|
||||
n_rot = cos.shape[-1]
|
||||
if sin.shape[-1] != n_rot:
|
||||
raise ValueError(
|
||||
f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
|
||||
f"sin={sin.shape[-1]}"
|
||||
)
|
||||
if n_rot > D:
|
||||
raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
|
||||
if n_rot % 2 != 0:
|
||||
raise ValueError(f"rotary dim must be even, got {n_rot}")
|
||||
|
||||
# Flatten to 2D: (B*S*H, D)
|
||||
x_flat = x.reshape(-1, D).contiguous()
|
||||
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
|
||||
# by dividing the row_idx by H to get the cos/sin row
|
||||
cos_flat = cos.reshape(B * S, D).contiguous()
|
||||
sin_flat = sin.reshape(B * S, D).contiguous()
|
||||
# cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
|
||||
# all sequences share the same rotary positions). The kernel needs a
|
||||
# dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
|
||||
# onto a single (b, s) pair, so expand-then-contiguous to materialize
|
||||
# the per-batch broadcast. Expand is a no-op when B == cos.shape[0].
|
||||
if cos.shape[0] != B:
|
||||
if cos.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
|
||||
f"to x batch dim ({B})"
|
||||
)
|
||||
cos = cos.expand(B, S, n_rot)
|
||||
sin = sin.expand(B, S, n_rot)
|
||||
cos_flat = cos.reshape(B * S, n_rot).contiguous()
|
||||
sin_flat = sin.reshape(B * S, n_rot).contiguous()
|
||||
|
||||
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
|
||||
y_flat = FusedRMSNormRoPEFunction.apply(
|
||||
x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
|
||||
)
|
||||
return y_flat.view(shape)
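
# Unfused reference (a sketch for sanity-checking the fused kernel, assuming
# the standard HF rotate_half convention restricted to the first n_rot columns;
# it computes in the input dtype rather than fp32 like the Triton kernel does):
#
#   def reference_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
#       n_rot = cos.shape[-1]
#       norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
#       if weight is not None:
#           norm = norm * weight
#       rot, rest = norm[..., :n_rot], norm[..., n_rot:]
#       c, s = cos.unsqueeze(2), sin.unsqueeze(2)   # broadcast over heads
#       h1, h2 = rot[..., : n_rot // 2], rot[..., n_rot // 2 :]
#       rotated = torch.cat((-h2, h1), dim=-1)
#       return torch.cat((rot * c + rotated * s, rest), dim=-1)
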
@@ -156,13 +156,20 @@ class PatchManager:
|
||||
# which would clobber any earlier fix.
|
||||
self._fix_nemotron_h_conversion_mapping()
|
||||
|
||||
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
|
||||
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
|
||||
# override needs to walk the raw model tree, which is broken by
|
||||
# the post-load PEFT wrapping. The accompanying
|
||||
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and
|
||||
# installation-time-independent, so both halves of the fix live
|
||||
# cleanly in the same call even though one is instance-scoped
|
||||
# and the other is module-scoped.
|
||||
self._apply_gemma_hybrid_attention(model)
|
||||
self._finalize_moe_expert_quantization(model)
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
self._apply_unsloth_patches(model)
|
||||
self._apply_lora_kernel_patch(model)
|
||||
self._apply_scaling_softmax_patch(model)
|
||||
|
||||
@@ -173,12 +180,23 @@ class PatchManager:
|
||||
which exceeds flash attention's supported size. This patch loads the model
|
||||
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
||||
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
||||
|
||||
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
|
||||
which fixes the corresponding mask construction inside
|
||||
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
|
||||
override is not enough — the forward still builds a 2D FA2-format mask
|
||||
at the model level and the SDPA layers crash at long context lengths
|
||||
with ``RuntimeError: The expanded size of the tensor ... must match``.
|
||||
"""
|
||||
if not self.cfg.gemma4_hybrid_attn_impl:
|
||||
return
|
||||
|
||||
import copy
|
||||
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
patch_gemma4_hybrid_mask()
|
||||
|
||||
# Navigate to the module that has 'layers' - varies by model structure:
|
||||
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
||||
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
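
The override the docstring describes boils down to walking those layers and handing only the global ones a shallow-copied config. A rough sketch of the pattern follows; the `attention_type` check is a stand-in for however the model marks its full-attention layers and is an assumption, not code from this PR:

```python
import copy

def _force_sdpa_on_global_layers(text_model):
    # Sliding-window layers keep flash_attention_2; only full-attention layers
    # get a shallow-copied config so the SDPA override cannot leak between layers.
    for layer in text_model.layers:
        if getattr(layer, "attention_type", None) == "full_attention":
            sdpa_config = copy.copy(layer.self_attn.config)
            sdpa_config._attn_implementation = "sdpa"
            layer.self_attn.config = sdpa_config
```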
|
||||
@@ -392,11 +410,28 @@ class PatchManager:
|
||||
patch_qwen3_5_vlm_flash_attention()
|
||||
|
||||
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||
# The fused attn path is now compatible with
|
||||
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
|
||||
# rotary (cos.shape[-1] < head_dim) and the fused forward
|
||||
# mirrors the current ``Gemma4TextAttention.forward`` API
|
||||
# for shared kv (read from / write to
|
||||
# ``past_key_values.shared_layers``). See
|
||||
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
|
||||
# for the history.
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
patch_gemma4_fused_attn()
|
||||
# Shared-KV side channel when activation checkpointing is enabled (PR #3611).
|
||||
fsdp_cfg = self.cfg.fsdp_config
|
||||
needs_shared_kv_workaround = (not self.inference) and bool(
|
||||
self.cfg.gradient_checkpointing
|
||||
or self.cfg.activation_offloading
|
||||
or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
|
||||
)
|
||||
patch_gemma4_fused_attn(
|
||||
install_shared_kv_workaround=needs_shared_kv_workaround
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fix_nemotron_h_conversion_mapping():
|
||||
@@ -674,24 +709,10 @@ class PatchManager:
|
||||
)
|
||||
|
||||
patch_fa_llama_cross_entropy()
|
||||
elif self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
|
||||
|
||||
patch_llama_rms_norm()
|
||||
elif self.cfg.unsloth_rms_norm:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||
|
||||
patch_unsloth_layernorm()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def _patch_llama_flash_attention(self):
|
||||
"""Apply Flash Attention patches for LLaMA models."""
|
||||
@@ -758,23 +779,6 @@ class PatchManager:
|
||||
LOG.info("Patching with SwiGLU...")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
def _apply_unsloth_patches(self, model):
|
||||
"""Apply unsloth optimization patches."""
|
||||
if self.cfg.unsloth_lora_mlp:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
|
||||
|
||||
integrate_lora_mlp_patch(peft_model=model)
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
|
||||
|
||||
integrate_lora_patch(peft_model=model, cfg=self.cfg)
|
||||
|
||||
if self.cfg.unsloth_rope:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
|
||||
|
||||
integrate_rope_embeddings()
|
||||
|
||||
def _apply_lora_kernel_patch(self, model):
|
||||
"""Apply LoRA kernel patches."""
|
||||
if (
|
||||
|
||||
@@ -23,6 +23,8 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
processor_kwargs = {}
|
||||
if cfg.revision_of_model:
|
||||
processor_kwargs["revision"] = cfg.revision_of_model
|
||||
if cfg.processor_kwargs:
|
||||
processor_kwargs.update(cfg.processor_kwargs)
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
|
||||
|
||||
115 src/axolotl/monkeypatch/gemma4_hybrid_mask.py Normal file
@@ -0,0 +1,115 @@
|
||||
"""Hybrid attention mask fix for Gemma 4.
|
||||
|
||||
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
|
||||
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
|
||||
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
|
||||
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
|
||||
``self_attn.config``, leaving sliding-window layers on FA2.
|
||||
|
||||
The per-layer config override alone is insufficient, however:
|
||||
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
|
||||
using the **model-level** config and passes the mapped mask to each
|
||||
decoder layer. With FA2 still set at the model level, the ``full_attention``
|
||||
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
|
||||
The global layers then fail with::
|
||||
|
||||
RuntimeError: The expanded size of the tensor (S) must match the existing
|
||||
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
|
||||
sizes: [B, S]
|
||||
|
||||
...when the sequence length grows past roughly 7k tokens.
|
||||
|
||||
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
|
||||
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
|
||||
the original in ``masking_utils``. The wrapper forces
|
||||
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
|
||||
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
|
||||
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
|
||||
alone, so sliding-window layers continue to receive FA2-format masks.
|
||||
|
||||
The patch is idempotent. Install once per process, before any Gemma 4
|
||||
forward pass runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
_PATCH_APPLIED = False
|
||||
|
||||
|
||||
def patch_gemma4_hybrid_mask() -> bool:
|
||||
"""Install the Gemma 4 hybrid-attention mask fix.
|
||||
|
||||
Returns ``True`` if the patch was installed (or was already installed),
|
||||
``False`` if the target module could not be imported (e.g. transformers
|
||||
version predates Gemma 4) — in which case nothing is done and the
|
||||
caller can continue unaffected.
|
||||
"""
|
||||
global _PATCH_APPLIED
|
||||
if _PATCH_APPLIED:
|
||||
return True
|
||||
|
||||
try:
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
except ImportError:
|
||||
LOG.debug(
|
||||
"gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
|
||||
"skipping. This is fine for non-Gemma4 training."
|
||||
)
|
||||
return False
|
||||
|
||||
if not hasattr(modeling_gemma4, "create_causal_mask"):
|
||||
LOG.warning(
|
||||
"gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
|
||||
"binding, skipping. Transformers API may have changed."
|
||||
)
|
||||
return False
|
||||
|
||||
original = modeling_gemma4.create_causal_mask
|
||||
|
||||
def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
|
||||
"""Wrapper that forces SDPA format for the full-attention mask.
|
||||
|
||||
The global layers were patched to SDPA by
|
||||
``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
|
||||
original ``create_causal_mask`` dispatches on
|
||||
``config._attn_implementation``; we shadow that with a local
|
||||
override.
|
||||
"""
|
||||
sdpa_config = copy.copy(config)
|
||||
sdpa_config._attn_implementation = "sdpa"
|
||||
return original(sdpa_config, *args, **kwargs)
|
||||
|
||||
# Preserve the original reference on the wrapper for tests / teardown.
|
||||
hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined]
|
||||
|
||||
modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
|
||||
_PATCH_APPLIED = True
|
||||
LOG.info(
|
||||
"gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
|
||||
"force SDPA-format masks for full-attention layers"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def unpatch_gemma4_hybrid_mask() -> None:
|
||||
"""Restore the original ``create_causal_mask``. Useful for tests."""
|
||||
global _PATCH_APPLIED
|
||||
if not _PATCH_APPLIED:
|
||||
return
|
||||
try:
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
except ImportError:
|
||||
_PATCH_APPLIED = False
|
||||
return
|
||||
current = modeling_gemma4.create_causal_mask
|
||||
original = getattr(current, "_axolotl_original", None)
|
||||
if original is not None:
|
||||
modeling_gemma4.create_causal_mask = original
|
||||
_PATCH_APPLIED = False
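
Because the wrapper keeps a reference to the original callable, tests can install and tear down the patch around a single case. A minimal sketch using pytest (the fixture name is illustrative):

```python
import pytest

from axolotl.monkeypatch.gemma4_hybrid_mask import (
    patch_gemma4_hybrid_mask,
    unpatch_gemma4_hybrid_mask,
)


@pytest.fixture
def gemma4_mask_patch():
    # Returns False (and installs nothing) when transformers has no gemma4 module.
    installed = patch_gemma4_hybrid_mask()
    yield installed
    unpatch_gemma4_hybrid_mask()
```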
|
||||
@@ -6,15 +6,29 @@ kernels, eliminating intermediate tensor allocations from rotate_half / apply_ro
|
||||
|
||||
Usage:
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
||||
patch_gemma4_fused_attn()
|
||||
# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
|
||||
patch_gemma4_fused_attn(install_shared_kv_workaround=True)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Module-level dict used as a side channel for shared KV states, avoiding kwargs and TLS,
|
||||
# to prevent a memory leak when training with gradient checkpointing enabled (PR #3611)
|
||||
_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
|
||||
|
||||
|
||||
def _set_shared_kv_states(store):
|
||||
_GEMMA4_SHARED_KV_STORE["store"] = store
|
||||
|
||||
|
||||
def _get_shared_kv_states():
|
||||
return _GEMMA4_SHARED_KV_STORE["store"]
|
||||
|
||||
|
||||
def _make_fused_forward(original_forward):
|
||||
@@ -30,7 +44,7 @@ def _make_fused_forward(original_forward):
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
|
||||
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
@@ -39,6 +53,10 @@ def _make_fused_forward(original_forward):
|
||||
eager_attention_forward,
|
||||
)
|
||||
|
||||
store = _get_shared_kv_states()
|
||||
if store is not None:
|
||||
shared_kv_states = store
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
eps = self.config.rms_norm_eps
|
||||
@@ -133,15 +151,44 @@ def _make_fused_forward(original_forward):
|
||||
return fused_forward
|
||||
|
||||
|
||||
def patch_gemma4_fused_attn():
|
||||
def _patch_decoder_layer_call():
|
||||
"""Strip `shared_kv_states` from decoder-layer kwargs and route via the
|
||||
module-level side channel so the checkpoint partial cannot pin it (PR #3611).
|
||||
"""
|
||||
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer
|
||||
|
||||
if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
|
||||
return
|
||||
|
||||
original_call = Gemma4TextDecoderLayer.__call__
|
||||
|
||||
def patched_call(self, *args, **kwargs):
|
||||
shared_kv = kwargs.pop("shared_kv_states", None)
|
||||
# Overwrite unconditionally (including with None) so a previous step's
|
||||
# dict cannot leak into a later call without shared_kv_states (PR #3611).
|
||||
_set_shared_kv_states(shared_kv)
|
||||
return original_call(self, *args, **kwargs)
|
||||
|
||||
Gemma4TextDecoderLayer.__call__ = patched_call
|
||||
Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True
|
||||
|
||||
|
||||
def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
|
||||
"""
|
||||
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
|
||||
and optionally route `shared_kv_states` via a module-level side channel to
|
||||
avoid a VRAM leak under activation checkpointing (PR #3611).
|
||||
"""
|
||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
||||
|
||||
original_forward = Gemma4TextAttention.forward
|
||||
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
|
||||
|
||||
if install_shared_kv_workaround:
|
||||
_patch_decoder_layer_call()
|
||||
|
||||
logger.info(
|
||||
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
|
||||
)
|
||||
if install_shared_kv_workaround:
|
||||
logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")
|
||||
|
||||
@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
||||
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
|
||||
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
|
||||
# language-side module is separated from the vision tower. Try
|
||||
# both names before giving up.
|
||||
mlp_cls = getattr(
|
||||
module,
|
||||
f"{model_cls_prefix}MLP",
|
||||
None,
|
||||
) or getattr(module, f"{model_cls_prefix}TextMLP")
|
||||
|
||||
if use_original_mlp:
|
||||
mlp_forward = mlp_cls.forward
|
||||
|
||||
@@ -1,252 +0,0 @@
|
||||
"""module for patching with unsloth optimizations"""
|
||||
|
||||
import inspect
|
||||
import types
|
||||
|
||||
import torch
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
""".lstrip("\n")
|
||||
|
||||
PATCHED_QKV_CODE = """
|
||||
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
|
||||
""".lstrip("\n")
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
attn_output = self.o_proj(attn_output)
|
||||
""".lstrip("\n")
|
||||
|
||||
PATCHED_O_CODE = """
|
||||
attn_output = self.apply_o(self, attn_output)
|
||||
""".lstrip("\n")
|
||||
|
||||
|
||||
def original_apply_qkv(self, hidden_states):
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
return query_states, key_states, value_states
|
||||
|
||||
|
||||
def original_apply_o(self, hidden_states):
|
||||
attn_output = self.o_proj(hidden_states)
|
||||
return attn_output
|
||||
|
||||
|
||||
def get_self_attn_code() -> str:
|
||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||
return forward
|
||||
|
||||
|
||||
def check_self_attn_is_patchable() -> bool:
|
||||
qkv = get_self_attn_code()
|
||||
qkv, _ = detab_code(qkv)
|
||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
||||
|
||||
|
||||
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
|
||||
|
||||
def UnslothForCausalLMLoss(
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int,
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
loss = fast_cross_entropy_loss(
|
||||
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
|
||||
)
|
||||
return loss
|
||||
|
||||
if model_type == "llama":
|
||||
from transformers.loss import loss_utils
|
||||
|
||||
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
|
||||
else:
|
||||
raise ValueError("Unsupported model type")
|
||||
|
||||
|
||||
self_attn_lora_patched = False
|
||||
|
||||
|
||||
def patch_self_attn_lora():
|
||||
global self_attn_lora_patched
|
||||
if self_attn_lora_patched:
|
||||
# prevent patching multiple times
|
||||
return
|
||||
self_attn_forward = get_self_attn_code()
|
||||
LlamaFlashAttention2._original_forward = self_attn_forward
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
|
||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
|
||||
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def unsloth_attn_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.models.llama.modeling_llama):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec(
|
||||
"from transformers.models.llama.modeling_llama import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals())
|
||||
self_attn_lora_patched = True
|
||||
LOG.info("patching unsloth attn lora")
|
||||
LlamaFlashAttention2.forward = unsloth_attn_forward
|
||||
|
||||
|
||||
def integrate_rope_embeddings():
|
||||
import transformers.models.llama.modeling_llama
|
||||
from unsloth.kernels.rope_embedding import fast_rope_embedding
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
q,
|
||||
k,
|
||||
cos,
|
||||
sin,
|
||||
position_ids=None,
|
||||
unsqueeze_dim=1,
|
||||
):
|
||||
return fast_rope_embedding(q, k, cos, sin)
|
||||
|
||||
LOG.info("patching unsloth RoPE embeddings")
|
||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
|
||||
|
||||
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
||||
from unsloth.kernels import apply_lora_mlp_swiglu
|
||||
|
||||
apply_lora_mlp = apply_lora_mlp_swiglu
|
||||
elif peft_model.base_model.config.model_type == "gemma":
|
||||
from unsloth.kernels import apply_lora_mlp_geglu_approx
|
||||
|
||||
apply_lora_mlp = apply_lora_mlp_geglu_approx
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Model type {peft_model.base_model.config.model_type} not supported"
|
||||
)
|
||||
|
||||
for idx, layer in enumerate(peft_model.model.model.layers):
|
||||
layer_modules = [
|
||||
getattr(layer.mlp, linear_proj)
|
||||
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
|
||||
]
|
||||
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||
mlp_no_bias = all(
|
||||
getattr(module, "base_layer", module).bias is None
|
||||
for module in layer_modules
|
||||
)
|
||||
mlp_not_dora = all(
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
||||
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
||||
else:
|
||||
LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")
|
||||
|
||||
|
||||
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
from unsloth.kernels import apply_lora_o, apply_lora_qkv
|
||||
|
||||
for idx, layer in enumerate(peft_model.model.model.layers):
|
||||
if cfg.unsloth_lora_qkv:
|
||||
layer_modules = [
|
||||
getattr(layer.self_attn, linear_proj)
|
||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||
]
|
||||
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||
qkv_no_bias = all(
|
||||
getattr(module, "base_layer", module).bias is None
|
||||
for module in layer_modules
|
||||
)
|
||||
qkv_not_dora = all(
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
|
||||
layer.self_attn.apply_qkv = apply_lora_qkv
|
||||
else:
|
||||
layer.self_attn.apply_qkv = original_apply_qkv
|
||||
LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
|
||||
if cfg.unsloth_lora_o:
|
||||
layer_modules = [
|
||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||
]
|
||||
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||
o_no_bias = all(
|
||||
getattr(module, "base_layer", module).bias is None
|
||||
for module in layer_modules
|
||||
)
|
||||
o_not_dora = all(
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
if is_o_lora and o_no_bias and o_not_dora:
|
||||
layer.self_attn.apply_o = apply_lora_o
|
||||
else:
|
||||
layer.self_attn.apply_o = original_apply_o
|
||||
LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")
|
||||
|
||||
|
||||
def patch_unsloth_layernorm():
|
||||
try:
|
||||
import transformers.models.llama.modeling_llama
|
||||
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
"""LlamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return Fast_RMS_Layernorm.apply(
|
||||
hidden_states, self.weight, self.variance_epsilon, False
|
||||
)
|
||||
|
||||
LOG.info("patching with unsloth.kernels.rms_layernorm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.warning("missing unsloth library")
|
||||
@@ -394,8 +394,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||
try:
|
||||
return all(isinstance(v, list) for v in prompt.values()) and all(
|
||||
isinstance(v, list) for v in prompt[self.prompter.field_messages]
|
||||
return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
|
||||
isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
|
||||
)
|
||||
except KeyError:
|
||||
return False
|
||||
@@ -1004,6 +1004,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if tools is None:
|
||||
return None
|
||||
|
||||
# Some datasets store tools as a JSON string
|
||||
if isinstance(tools, str):
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError as e:
|
||||
LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
|
||||
raise
|
||||
if isinstance(tools, list):
|
||||
# Process each tool to handle JSON string parameters
|
||||
for tool in tools:
|
||||
@@ -1034,6 +1041,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if messages is None:
|
||||
raise ValueError("Messages is null. Please check `field_messages`.")
|
||||
|
||||
if isinstance(messages, str):
|
||||
try:
|
||||
messages = json.loads(messages)
|
||||
except json.JSONDecodeError as e:
|
||||
LOG.error(f"Error parsing messages as JSON. Error: {e}")
|
||||
raise
|
||||
assert isinstance(messages, list), (
|
||||
f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
|
||||
)
|
||||
|
||||
# Extra check here to make sure decoded json is a list of dicts.
|
||||
for i, message in enumerate(messages):
|
||||
assert isinstance(message, dict), (
|
||||
f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
|
||||
)
|
||||
|
||||
if isinstance(messages, list):
|
||||
return messages
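
For context, the string-decoding branch above covers dataset rows where the messages column is stored as JSON text. A purely illustrative row:

```python
import json

row = {
    "messages": '[{"role": "user", "content": "hi"}, '
    '{"role": "assistant", "content": "hello"}]'
}
messages = json.loads(row["messages"])  # -> list[dict], as the asserts above require
```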
|
||||
|
||||
|
||||
@@ -320,6 +320,15 @@ def main(script_args: ScriptArguments):
|
||||
# --- Active LoRA state (shared across endpoints via closure) ---
|
||||
active_lora: dict = {"request": None}
|
||||
|
||||
# Serializes access to the worker pipe. The underlying
|
||||
# multiprocessing.Connection is a single full-duplex stream shared
|
||||
# across all HTTP handlers; concurrent requests interleave bytes on
|
||||
# the wire and corrupt the pickle framing (seen as
|
||||
# ``UnpicklingError: pickle data was truncated``). Any endpoint that
|
||||
# does ``conn.send(...); conn.recv()`` MUST hold this lock across
|
||||
# the round-trip so that only one call is in flight per pipe at a time.
|
||||
worker_pipe_lock = asyncio.Lock()
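
The rule in the comment (hold the lock across the full send/recv round-trip) can be packaged as a small helper so every endpoint follows the same discipline. A sketch under the assumption that `asyncio` and the `safe_recv` helper used further down are already in scope:

```python
async def pipe_call(conn, payload: dict):
    # One in-flight round-trip per pipe: send and recv under the same lock so
    # concurrent handlers cannot interleave pickle frames on the Connection.
    loop = asyncio.get_running_loop()
    async with worker_pipe_lock:
        conn.send(payload)
        return await loop.run_in_executor(None, safe_recv, conn)
```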
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LoRA-specific endpoints
|
||||
# ------------------------------------------------------------------
|
||||
@@ -631,6 +640,150 @@ def main(script_args: ScriptArguments):
|
||||
},
|
||||
}
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def openai_completions(request_body: dict):
|
||||
"""OpenAI-compatible text-completions endpoint.
|
||||
|
||||
Accepts either a string ``prompt`` or a list-of-int
|
||||
``prompt_token_ids`` (as the text-completions spec allows). Routes
|
||||
to the internal vLLM generate method with the active LoRA adapter
|
||||
and returns an OpenAI /v1/completions-shaped response including
|
||||
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
|
||||
``generation_log_probs`` for NeMo Gym agents that need raw
|
||||
tokens + logprobs.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
prompt_raw = request_body.get("prompt")
|
||||
temperature = request_body.get("temperature", 1.0)
|
||||
max_tokens = request_body.get("max_tokens", 512)
|
||||
top_p = request_body.get("top_p", 1.0)
|
||||
n = request_body.get("n", 1)
|
||||
logprobs = request_body.get("logprobs") or 0
|
||||
stop_token_ids = request_body.get("stop_token_ids") or None
|
||||
|
||||
# Accept either a string or a list[int] token id prompt. Lists
|
||||
# must contain ints only (raise on lists of strings so callers get
|
||||
# a clear error). Also accept [[int, int, ...]] nesting for the
|
||||
# rare case callers pass a single-prompt batch.
|
||||
if (
|
||||
isinstance(prompt_raw, list)
|
||||
and prompt_raw
|
||||
and isinstance(prompt_raw[0], list)
|
||||
):
|
||||
prompt_raw = prompt_raw[0]
|
||||
|
||||
prompt_dict: dict[str, Any] = {}
|
||||
if isinstance(prompt_raw, list):
|
||||
prompt_dict = {"prompt_token_ids": prompt_raw}
|
||||
elif isinstance(prompt_raw, str):
|
||||
prompt_dict = {"prompt": prompt_raw}
|
||||
else:
|
||||
return {
|
||||
"error": {
|
||||
"message": ("prompt must be a string or a list of token ids"),
|
||||
"type": "invalid_request",
|
||||
}
|
||||
}
|
||||
|
||||
generation_kwargs: dict[str, Any] = {
|
||||
"n": n,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"logprobs": logprobs,
|
||||
}
|
||||
if stop_token_ids:
|
||||
generation_kwargs["stop_token_ids"] = stop_token_ids
|
||||
sampling_params = SamplingParams(
|
||||
**{k: v for k, v in generation_kwargs.items() if v is not None}
|
||||
)
|
||||
|
||||
chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
|
||||
|
||||
# Hold the pipe lock across send+recv — concurrent requests would
|
||||
# otherwise interleave pickle frames on the worker connection.
|
||||
async with worker_pipe_lock:
|
||||
for conn, chunk in zip(connections, chunked, strict=True):
|
||||
if not chunk:
|
||||
chunk = [{"prompt": "<placeholder>"}]
|
||||
kwargs = {
|
||||
"prompts": chunk,
|
||||
"sampling_params": sampling_params,
|
||||
"lora_request": active_lora["request"],
|
||||
}
|
||||
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
all_outputs = await asyncio.gather(
|
||||
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
||||
)
|
||||
|
||||
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
||||
for o in all_outputs:
|
||||
if isinstance(o, dict) and "error" in o:
|
||||
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
||||
all_outputs = list(chain.from_iterable(all_outputs))
|
||||
|
||||
if not all_outputs:
|
||||
return {"choices": [], "model": script_args.model}
|
||||
|
||||
choices = []
|
||||
for i, output in enumerate(all_outputs):
|
||||
for j, out in enumerate(output.outputs):
|
||||
text = out.text
|
||||
# OpenAI-style `logprobs` block for text-completions:
|
||||
# { "tokens": [...], "token_logprobs": [...] }
|
||||
lp_block = None
|
||||
if out.logprobs:
|
||||
tokens_str: list[str] = []
|
||||
token_lps: list[float] = []
|
||||
for step in out.logprobs:
|
||||
chosen = next(iter(step.values()))
|
||||
tokens_str.append(getattr(chosen, "decoded_token", "") or "")
|
||||
token_lps.append(float(chosen.logprob))
|
||||
lp_block = {
|
||||
"tokens": tokens_str,
|
||||
"token_logprobs": token_lps,
|
||||
}
|
||||
|
||||
choice = {
|
||||
"index": i * n + j,
|
||||
"text": text,
|
||||
"finish_reason": "stop"
|
||||
if out.finish_reason == "stop"
|
||||
else "length",
|
||||
"logprobs": lp_block,
|
||||
# NeMo-Gym / retrace agent extras — preserved on the
|
||||
# choice so callers with raw-token pipelines don't
|
||||
# have to re-tokenize.
|
||||
"prompt_token_ids": output.prompt_token_ids,
|
||||
"generation_token_ids": list(out.token_ids),
|
||||
"generation_log_probs": (
|
||||
[float(next(iter(lp.values())).logprob) for lp in out.logprobs]
|
||||
if out.logprobs
|
||||
else []
|
||||
),
|
||||
}
|
||||
choices.append(choice)
|
||||
|
||||
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
|
||||
completion_tokens = sum(
|
||||
len(out.token_ids) for o in all_outputs for out in o.outputs
|
||||
)
|
||||
|
||||
return {
|
||||
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "text_completion",
|
||||
"model": script_args.model,
|
||||
"choices": choices,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
||||
|
||||
@app.post("/init_communicator/")
|
||||
|
||||
@@ -309,6 +309,16 @@ class AxolotlInputConfig(
|
||||
|
||||
dpo_padding_free: bool | None = None
|
||||
|
||||
dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "List of DPO losses to use."},
|
||||
)
|
||||
|
||||
dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Weights for each DPO loss."},
|
||||
)
|
||||
|
||||
datasets: (
|
||||
Annotated[
|
||||
list[
|
||||
@@ -823,13 +833,6 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
unsloth_cross_entropy_loss: bool | None = None
|
||||
unsloth_lora_mlp: bool | None = None
|
||||
unsloth_lora_qkv: bool | None = None
|
||||
unsloth_lora_o: bool | None = None
|
||||
unsloth_rms_norm: bool | None = None
|
||||
unsloth_rope: bool | None = None
|
||||
|
||||
lora_mlp_kernel: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -1469,21 +1472,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_multigpu_unsloth(cls, data):
|
||||
if (
|
||||
data.get("unsloth_lora_mlp")
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||
raise ValueError(
|
||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_multigpu_lora_kernels(cls, data):
|
||||
@@ -1537,8 +1525,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
# RL trainers not tested so don't enable kernels by default
|
||||
return data
|
||||
if data.get("adapter") in ["lora", "qlora"]:
|
||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||
# Skip if already set or using 8-bit
|
||||
kernel_fields = [
|
||||
"lora_mlp_kernel",
|
||||
"lora_qkv_kernel",
|
||||
@@ -1547,7 +1534,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
]
|
||||
if (
|
||||
any(data.get(k) is not None for k in kernel_fields)
|
||||
or any(data.get(k) for k in unsloth_fields)
|
||||
or data.get("adapter") == "lora"
|
||||
and data.get("load_in_8bit")
|
||||
):
|
||||
|
||||
@@ -64,6 +64,12 @@ class ModelInputConfig(BaseModel):
|
||||
processor_type: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||
)
|
||||
processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
|
||||
},
|
||||
)
|
||||
tokenizer_save_jinja_files: bool | None = Field(
|
||||
default=True, # match the default behavior from transformers
|
||||
json_schema_extra={
|
||||
@@ -107,6 +113,22 @@ class ModelInputConfig(BaseModel):
|
||||
)
|
||||
return trust_remote_code
|
||||
|
||||
@field_validator("processor_kwargs")
|
||||
@classmethod
|
||||
def reject_reserved_processor_kwargs(cls, processor_kwargs):
|
||||
if not processor_kwargs:
|
||||
return processor_kwargs
|
||||
reserved = {"revision", "trust_remote_code"}
|
||||
conflicts = reserved.intersection(processor_kwargs)
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
"Do not set reserved keys "
|
||||
f"{sorted(conflicts)} inside `processor_kwargs`; "
|
||||
"use the top-level `revision_of_model` / `trust_remote_code` "
|
||||
"config keys instead."
|
||||
)
|
||||
return processor_kwargs
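
In config terms, the validator above allows processor-specific keys but rejects the reserved ones. Illustrative values only:

```python
ok = {"processor_kwargs": {"min_pixels": 256 * 28 * 28, "image_seq_length": 64}}
bad = {"processor_kwargs": {"revision": "main"}}  # raises: use `revision_of_model` instead
```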
|
||||
|
||||
|
||||
class ModelOutputConfig(BaseModel):
|
||||
"""model save configuration subset"""
|
||||
|
||||
@@ -52,6 +52,26 @@ class DatasetValidationMixin:
|
||||
|
||||
return datasets
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_unsloth_fields(cls, data):
|
||||
deprecated_fields = [
|
||||
"unsloth_cross_entropy_loss",
|
||||
"unsloth_lora_mlp",
|
||||
"unsloth_lora_qkv",
|
||||
"unsloth_lora_o",
|
||||
"unsloth_rms_norm",
|
||||
"unsloth_rope",
|
||||
]
|
||||
found = [f for f in deprecated_fields if data.get(f)]
|
||||
if found:
|
||||
raise ValueError(
|
||||
f"`{'`, `'.join(found)}` {'has' if len(found) == 1 else 'have'} been removed. "
|
||||
"Please use `lora_mlp_kernel`, `lora_qkv_kernel`, `lora_o_kernel` instead. "
|
||||
"See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||
)
|
||||
return data
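
The migration the error message asks for, expressed as before/after config fields (values illustrative):

```python
old_cfg = {"unsloth_lora_qkv": True}   # now raises ValueError at validation time
new_cfg = {"lora_qkv_kernel": True}    # supported replacement per the LoRA optims docs
```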
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_dataset_or_pretraining_dataset(cls, data):
|
||||
@@ -558,6 +578,11 @@ class TrainingValidationMixin:
|
||||
"Setting chat_template is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
if data.get("processor_kwargs"):
|
||||
raise ValueError(
|
||||
"processor_kwargs is not supported with mistral-common tokenizer"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -607,36 +632,6 @@ class LoRAValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_qlora_unsloth(cls, data):
|
||||
if (
|
||||
data.get("unsloth_lora_mlp")
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||
raise ValueError(
|
||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_axolotl_unsloth(cls, data):
|
||||
is_lora_kernel = any(
|
||||
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||
)
|
||||
is_unsloth_lora = any(
|
||||
data.get(k)
|
||||
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||
)
|
||||
if is_lora_kernel and is_unsloth_lora:
|
||||
raise ValueError(
|
||||
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fused_lora(self):
|
||||
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
|
||||
@@ -770,6 +765,122 @@ class RLValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_dpo(cls, data):
|
||||
dpo_loss_type = data.get("dpo_loss_type")
|
||||
dpo_loss_weights = data.get("dpo_loss_weights")
|
||||
rl = data.get("rl")
|
||||
|
||||
if rl == "ipo":
|
||||
LOG.warning(
|
||||
"rl: ipo will soon be deprecated. Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead."
|
||||
)
|
||||
|
||||
if rl == "dpo":
|
||||
if dpo_loss_weights is not None and dpo_loss_type is None:
|
||||
raise ValueError(
|
||||
"`dpo_loss_weights` requires `dpo_loss_type` to be set"
|
||||
)
|
||||
if (
|
||||
dpo_loss_type is not None
|
||||
and dpo_loss_weights is not None
|
||||
and len(dpo_loss_type) != len(dpo_loss_weights)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
|
||||
f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
|
||||
)
|
||||
elif dpo_loss_type is not None or dpo_loss_weights is not None:
|
||||
raise ValueError(
|
||||
f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only,"
|
||||
f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}"
|
||||
)
|
||||
|
||||
return data
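
A passing payload for the DPO branch of this validator, mirroring the values used in the test fixture later in this diff:

```python
data = {
    "rl": "dpo",
    "dpo_loss_type": ["sigmoid", "sft"],  # one entry per loss
    "dpo_loss_weights": [1.0, 0.5],       # must match dpo_loss_type in length
}
```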
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_batch_size_divisibility(cls, data):
|
||||
"""Surface GRPO batch-shape mismatches at config-parse time.
|
||||
|
||||
TRL's GRPOTrainer requires that the per-step generation batch size be
|
||||
evenly divisible by ``num_generations`` so that every prompt can be
|
||||
replicated exactly ``num_generations`` times. The runtime check inside
|
||||
``GRPOTrainer.__init__`` only fires after the model has been loaded —
|
||||
too late and too cryptic for the user. We replicate the check here so
|
||||
the failure is immediate and actionable.
|
||||
|
||||
Also enforces:
|
||||
- ``num_generations >= 2`` (group-relative advantage needs variance)
|
||||
- ``effective_gbs >= num_generations * world_size`` when capabilities
|
||||
indicate multiple ranks (each rank needs at least one full group)
|
||||
"""
|
||||
if data.get("rl") != "grpo":
|
||||
return data
|
||||
|
||||
trl_cfg = data.get("trl") or {}
|
||||
num_gen = trl_cfg.get("num_generations")
|
||||
if num_gen is None:
|
||||
# TRL's own default is 8 — but if the user didn't set it, we
|
||||
# don't have enough info to validate anything. Let TRL's own
|
||||
# init handle the default-vs-batch interaction.
|
||||
return data
|
||||
if num_gen < 2:
|
||||
raise ValueError(
|
||||
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
|
||||
"With num_generations=1, every group has zero advantage and "
|
||||
"the policy never updates."
|
||||
)
|
||||
|
||||
explicit_gbs = trl_cfg.get("generation_batch_size")
|
||||
if explicit_gbs is not None:
|
||||
effective_gbs = int(explicit_gbs)
|
||||
gbs_source = "trl.generation_batch_size"
|
||||
else:
|
||||
mb = data.get("micro_batch_size") or 1
|
||||
ga = data.get("gradient_accumulation_steps") or 1
|
||||
effective_gbs = int(mb) * int(ga)
|
||||
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
|
||||
|
||||
if effective_gbs % num_gen != 0:
|
||||
# Suggest the smallest GA bump that fixes it for the common case
|
||||
# where the user hasn't set generation_batch_size explicitly.
|
||||
hint = ""
|
||||
if explicit_gbs is None:
|
||||
from math import gcd
|
||||
|
||||
mb_val = int(data.get("micro_batch_size") or 1)
|
||||
# smallest GA such that mb*GA is a multiple of num_gen
|
||||
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
|
||||
suggested_ga = lcm // mb_val
|
||||
hint = (
|
||||
f" Smallest fix: set `gradient_accumulation_steps: "
|
||||
f"{suggested_ga}` (so micro_batch_size * GA = "
|
||||
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
|
||||
)
|
||||
raise ValueError(
|
||||
f"GRPO: generation batch size must be divisible by "
|
||||
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
|
||||
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
|
||||
)
|
||||
|
||||
# Multi-rank check: each rank must receive at least one full group
|
||||
# per step. Without `capabilities` populated yet (mode='before'), we
|
||||
# fall back to user-set distributed fields.
|
||||
world_size = (
|
||||
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
|
||||
)
|
||||
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
|
||||
raise ValueError(
|
||||
f"GRPO with world_size={world_size} requires effective_gbs "
|
||||
f">= num_generations * world_size = {num_gen * world_size}, "
|
||||
f"got {effective_gbs}. Increase gradient_accumulation_steps "
|
||||
f"or micro_batch_size."
|
||||
)
|
||||
|
||||
return data
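
A worked instance of the divisibility arithmetic above, with illustrative numbers:

```python
from math import gcd

micro_batch_size, grad_accum, num_generations = 2, 3, 4

effective_gbs = micro_batch_size * grad_accum   # 6, and 6 % 4 != 0 -> validator raises
lcm = num_generations * micro_batch_size // gcd(num_generations, micro_batch_size)
suggested_ga = lcm // micro_batch_size          # 2, since 2 * 2 = 4 is divisible by 4
assert (micro_batch_size * suggested_ga) % num_generations == 0
```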
|
||||
|
||||
|
||||
class OptimizationValidationMixin:
|
||||
"""Validation methods related to optimization and performance."""
|
||||
@@ -860,17 +971,6 @@ class OptimizationValidationMixin:
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_xentropy_patch_conflicts(cls, data):
|
||||
if data.get("flash_attn_cross_entropy") and data.get(
|
||||
"unsloth_cross_entropy_loss"
|
||||
):
|
||||
raise ValueError(
|
||||
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_cross_entropy_conflicts(cls, data):
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""
|
||||
dynamic requirements for axolotl
|
||||
"""
|
||||
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools.command.build_py import build_py as _build_py
|
||||
|
||||
|
||||
def parse_requirements():
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
is_extras = (
|
||||
"flash-attn" in line
|
||||
or "flash-attention" in line
|
||||
or "deepspeed" in line
|
||||
or "mamba-ssm" in line
|
||||
or "lion-pytorch" in line
|
||||
)
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
_dependency_links.append(url)
|
||||
elif not is_extras and line and line[0] != "#":
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||
|
||||
if "Darwin" in platform.system():
|
||||
# don't install xformers on MacOS
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.5.1"
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.28.post3")
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
elif (major, minor) >= (2, 3):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.26.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
elif (major, minor) >= (2, 2):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.25.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.23.post1")
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
class BuildPyCommand(_build_py):
|
||||
"""
|
||||
custom build_py command to parse dynamic requirements
|
||||
"""
|
||||
|
||||
def finalize_options(self):
|
||||
super().finalize_options()
|
||||
install_requires, _ = parse_requirements()
|
||||
self.distribution.install_requires = install_requires
|
||||
@@ -119,15 +119,49 @@ def download_smollm2_135m_gptq_model():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_qwen_2_5_half_billion_model():
|
||||
# download the model
|
||||
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
||||
def download_qwen3_half_billion_model():
|
||||
# download the model (still used as the KD teacher in tests/e2e/integrations/test_kd.py)
|
||||
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_qwen3_half_billion_model():
|
||||
# download the model
|
||||
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
|
||||
def download_tiny_llama_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-llama-50m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_mistral_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-mistral-25m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_mixtral_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-mixtral-30m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_phi_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-phi-64m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_falcon_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-falcon-42m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_qwen2_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-qwen2-129m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_qwen3_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-qwen3-129m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_gemma2_model():
|
||||
snapshot_download_w_retry("axolotl-ai-co/tiny-gemma2-137m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -325,10 +359,10 @@ def download_phi_4_reasoning_model_fixture():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_phi_3_medium_model_fixture():
|
||||
def download_phi_3_mini_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"microsoft/Phi-3-medium-128k-instruct",
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
@@ -620,7 +654,15 @@ def fixture_min_base_cfg():
|
||||
)
|
||||
def test_load_fixtures(
|
||||
download_smollm2_135m_model,
|
||||
download_qwen_2_5_half_billion_model,
|
||||
download_qwen3_half_billion_model,
|
||||
download_tiny_llama_model,
|
||||
download_tiny_mistral_model,
|
||||
download_tiny_mixtral_model,
|
||||
download_tiny_phi_model,
|
||||
download_tiny_falcon_model,
|
||||
download_tiny_qwen2_model,
|
||||
download_tiny_qwen3_model,
|
||||
download_tiny_gemma2_model,
|
||||
download_tatsu_lab_alpaca_dataset,
|
||||
download_mhenrichsen_alpaca_2k_dataset,
|
||||
download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
||||
|
||||
@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
|
||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||
|
||||
|
||||
class TestVllmLoraSyncPatch(unittest.TestCase):
|
||||
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
|
||||
|
||||
These tests exercise the patch-installation branch in isolation. They build
|
||||
a stub trainer with just enough attributes to look like
|
||||
``AsyncGRPOTrainer`` for the duration of the relevant code path.
|
||||
|
||||
Background — there are two correct behaviors and we historically had a bug
|
||||
where both modes used the same one:
|
||||
|
||||
- Async prefetch ON: the BG generation thread can't safely call
|
||||
sync_weights mid-rollout. We no-op the stock hook and drive sync from
|
||||
the main thread via ``_maybe_sync_vllm_weights``.
|
||||
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
|
||||
calls ``sync_weights`` once per step boundary on the main thread. We
|
||||
wire that hook directly to ``_sync_lora_adapter`` because
|
||||
``_maybe_sync_vllm_weights`` short-circuits when async is off.
|
||||
|
||||
Before the fix, both modes installed ``lambda: None``, so sync mode never
|
||||
pushed any LoRA adapter to vLLM and the trainer was a no-op.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
class FakeArgs:
|
||||
pass
|
||||
|
||||
args = FakeArgs()
|
||||
args.vllm_lora_sync = vllm_lora_sync
|
||||
args.async_prefetch = async_prefetch
|
||||
|
||||
class FakeVllmGen:
|
||||
sync_weights = staticmethod(lambda: None)
|
||||
model = MagicMock()
|
||||
|
||||
# Use object.__new__ so we don't run __init__ (which needs a real
|
||||
# model, dataset, etc.). We only need the `_generate_single_turn`
|
||||
# method's patch branch to run, so we set up the minimum state.
|
||||
trainer = object.__new__(AsyncGRPOTrainer)
|
||||
trainer.args = args
|
||||
trainer.use_vllm = True
|
||||
trainer.vllm_generation = FakeVllmGen()
|
||||
trainer._patched_sync_weights = False
|
||||
# Spy on _sync_lora_adapter so we can assert it's the function the
|
||||
# hook delegates to in sync mode.
|
||||
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
|
||||
trainer._sync_peft_weights_no_merge = MagicMock(
|
||||
name="_sync_peft_weights_no_merge_spy"
|
||||
)
|
||||
return trainer
|
||||
|
||||
@staticmethod
|
||||
def _run_patch_branch(trainer):
|
||||
"""Execute just the sync_weights-patching branch in isolation.
|
||||
|
||||
We can't easily call the real ``_generate_single_turn`` because it
|
||||
does a full vLLM generate. Instead we copy the exact branch out of
|
||||
the source so the test verifies the same logic the trainer runs.
|
||||
"""
|
||||
if not getattr(trainer, "_patched_sync_weights", False):
|
||||
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
|
||||
if getattr(trainer.args, "vllm_lora_sync", False):
|
||||
if getattr(trainer.args, "async_prefetch", False):
|
||||
trainer.vllm_generation.sync_weights = lambda: None
|
||||
else:
|
||||
sync_helper = trainer._sync_lora_adapter
|
||||
|
||||
def _lora_filesystem_sync():
|
||||
sync_helper()
|
||||
|
||||
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
|
||||
trainer._patched_sync_weights = True
|
||||
|
||||
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
|
||||
assert trainer._patched_sync_weights is True
|
||||
# Trigger the patched hook — it must call _sync_lora_adapter.
|
||||
trainer.vllm_generation.sync_weights()
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_async_mode_with_lora_sync_installs_noop_hook(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
|
||||
self._run_patch_branch(trainer)
|
||||
|
||||
assert trainer._patched_sync_weights is True
|
||||
# Hook must be a no-op so BG-thread generation doesn't fight the
|
||||
# main-thread optimizer step over the model weights.
|
||||
trainer.vllm_generation.sync_weights()
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
|
||||
"""Installing the patch should not pre-emptively sync."""
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
# _sync_lora_adapter should only be called when the patched hook
|
||||
# itself is invoked (e.g., from TRL's _generate_single_turn).
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
def test_patch_is_idempotent(self):
|
||||
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
|
||||
self._run_patch_branch(trainer)
|
||||
first_hook = trainer.vllm_generation.sync_weights
|
||||
# Second call must not re-patch (otherwise we'd lose the original).
|
||||
self._run_patch_branch(trainer)
|
||||
assert trainer.vllm_generation.sync_weights is first_hook
|
||||
|
||||
|
||||
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
|
||||
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
|
||||
|
||||
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
|
||||
on the very first call when ``vllm_sync_interval`` was ``None`` (which
|
||||
is the default for any config that doesn't explicitly set it). We now
|
||||
fall back to interval=1 so unset means "sync every step", matching the
|
||||
behavior of TRL's own ``_generate_single_turn``.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_stub_trainer(interval, async_prefetch):
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
class FakeArgs:
|
||||
pass
|
||||
|
||||
args = FakeArgs()
|
||||
args.async_prefetch = async_prefetch
|
||||
args.vllm_sync_interval = interval
|
||||
args.vllm_lora_sync = True
|
||||
|
||||
class FakeState:
|
||||
global_step = 1
|
||||
|
||||
trainer = object.__new__(AsyncGRPOTrainer)
|
||||
trainer.args = args
|
||||
trainer.use_vllm = True
|
||||
trainer.state = FakeState()
|
||||
trainer._last_synced_step = 0
|
||||
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
|
||||
return trainer
|
||||
|
||||
def test_interval_none_in_async_mode_does_not_crash(self):
|
||||
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
# Should not raise TypeError — defaults to every-step sync
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_sync_mode_drives_sync(self):
|
||||
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
|
||||
|
||||
The previous behavior (early return when ``not async_prefetch``)
|
||||
assumed TRL's stock ``_generate_single_turn`` would handle sync.
|
||||
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
|
||||
where the data producer bypasses ``_generate_single_turn``
|
||||
entirely. Without this trigger no sync ever happens and the
|
||||
trainer becomes a no-op.
|
||||
"""
|
||||
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
def test_async_mode_with_explicit_interval_respects_modulo(self):
|
||||
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOTrainer,
|
||||
)
|
||||
|
||||
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_not_called()
|
||||
|
||||
# global_step=4 → 4 % 4 == 0 → sync
|
||||
trainer.state.global_step = 4
|
||||
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
|
||||
trainer._sync_lora_adapter.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg):
|
||||
"dpo_use_weighting": True,
|
||||
"dpo_label_smoothing": 0.1,
|
||||
"beta": 0.1, # DPO beta
|
||||
"dpo_loss_type": ["sigmoid", "sft"],
|
||||
"dpo_loss_weights": [1.0, 0.5],
|
||||
}
|
||||
)
|
||||
return cfg
|
||||
@@ -164,7 +166,8 @@ def fixture_ipo_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": RLType.IPO,
|
||||
"rl": RLType.DPO,
|
||||
"dpo_loss_type": ["ipo"],
|
||||
"dpo_label_smoothing": 0,
|
||||
"beta": 0.1,
|
||||
}
|
||||
@@ -300,6 +303,8 @@ class TestHFRLTrainerBuilder:
|
||||
assert training_arguments.use_weighting is True
|
||||
assert training_arguments.label_smoothing == 0.1
|
||||
assert training_arguments.precompute_ref_log_probs is True
|
||||
assert training_arguments.loss_type == ["sigmoid", "sft"]
|
||||
assert training_arguments.loss_weights == [1.0, 0.5]
|
||||
|
||||
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
||||
|
||||
@@ -10,7 +10,10 @@ from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists
|
||||
from tests.e2e.utils import (
|
||||
check_model_output_exists,
|
||||
check_tensorboard_loss_decreased,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -35,13 +38,16 @@ def min_cfg(temp_dir):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 5e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"output_dir": temp_dir,
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 10,
|
||||
"max_steps": 40,
|
||||
"warmup_steps": 5,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
|
||||
|
||||
@@ -64,11 +70,18 @@ class TestCutCrossEntropyIntegration:
|
||||
else:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=2.2,
|
||||
max_final=2.0,
|
||||
)
|
||||
|
||||
def test_qwen2_w_cce(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"plugins": [
|
||||
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
|
||||
],
|
||||
@@ -87,13 +100,15 @@ class TestCutCrossEntropyIntegration:
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"output_dir": temp_dir,
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 10,
|
||||
"max_steps": 50,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -108,6 +123,13 @@ class TestCutCrossEntropyIntegration:
|
||||
else:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=5.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attention_type",
|
||||
@@ -136,3 +158,10 @@ class TestCutCrossEntropyIntegration:
|
||||
else:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=2.2,
|
||||
max_final=2.0,
|
||||
)
|
||||
|
||||
@@ -54,24 +54,8 @@ except (ImportError, ModuleNotFoundError):
|
||||
)
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
|
||||
smoe_A = torch.zeros(
|
||||
rank * num_experts,
|
||||
K_inter,
|
||||
device=peft_A.device,
|
||||
dtype=peft_A.dtype,
|
||||
)
|
||||
smoe_B = torch.zeros(
|
||||
N_hidden,
|
||||
rank * num_experts,
|
||||
device=peft_A.device,
|
||||
dtype=peft_A.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
s = e * rank
|
||||
smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T
|
||||
smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T
|
||||
smoe_A = peft_A
|
||||
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
return smoe_A, smoe_B
|
||||
|
||||
def _unwrap_experts_lora(experts_module):
|
||||
@@ -322,12 +306,14 @@ class TestLoRABLayoutConversion:
|
||||
hidden, inter = 32, 16
|
||||
scaling = 2.0
|
||||
|
||||
peft_A = torch.randn(E * r, hidden)
|
||||
peft_B = torch.randn(inter, E * r)
|
||||
# peft >=0.19.1 for down_proj [E, hidden, inter]:
|
||||
# swaps in/out, lora_A [r*E, inter], lora_B [hidden, r*E]
|
||||
peft_A = torch.randn(E * r, inter)
|
||||
peft_B = torch.randn(hidden, E * r)
|
||||
|
||||
A_r = peft_A.reshape(E, r, hidden)
|
||||
B_r = peft_B.reshape(inter, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
||||
A_r = peft_A.reshape(E, r, inter)
|
||||
B_r = peft_B.reshape(hidden, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
|
||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
for e in range(E):
|
||||
@@ -339,31 +325,22 @@ class TestLoRABLayoutConversion:
|
||||
)
|
||||
|
||||
def test_gate_up_proj_conversion(self):
|
||||
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
|
||||
"""Verify gate_up_proj LoRA conversion with non-square dims.
|
||||
|
||||
gate_up_proj param: [E, 2*inter, hidden].
|
||||
peft: in_features=2*inter, out_features=hidden.
|
||||
peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].
|
||||
|
||||
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
|
||||
peft swaps in/out for 3D: lora_A [r*E, hidden], lora_B [2*inter, r*E].
|
||||
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
|
||||
|
||||
Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
|
||||
"""
|
||||
E, r = 4, 2
|
||||
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
|
||||
scaling = 2.0
|
||||
|
||||
# peft assigns: in_features=2*inter, out_features=hidden
|
||||
peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter]
|
||||
peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E]
|
||||
peft_A = torch.randn(E * r, hidden) # [r*E, in=hidden]
|
||||
peft_B = torch.randn(2 * inter, E * r) # [out=2*inter, r*E]
|
||||
|
||||
# peft delta via einsum: "o r e, e r i -> e i o"
|
||||
A_r = peft_A.reshape(E, r, 2 * inter)
|
||||
B_r = peft_B.reshape(hidden, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
||||
# delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]
|
||||
# = param[e] shape [2*inter, hidden]
|
||||
A_r = peft_A.reshape(E, r, hidden)
|
||||
B_r = peft_B.reshape(2 * inter, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
|
||||
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
# smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E]
|
||||
@@ -421,23 +398,21 @@ class TestPeftLoRAWeightExtraction:
|
||||
r,
|
||||
)
|
||||
|
||||
# gate_up_proj [E, 2*inter, hidden]
|
||||
# peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)
|
||||
# gate_up_proj [E, 2*inter, hidden] — peft swaps in/out for 3D
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
|
||||
].shape == (E * r, 2 * config.intermediate_size)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
||||
].shape == (config.hidden_size, E * r)
|
||||
|
||||
# down_proj [E, hidden, inter]
|
||||
# peft: in_features=hidden (dim 1), out_features=inter (dim 2)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.lora_A.default.weight"
|
||||
].shape == (E * r, config.hidden_size)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
||||
].shape == (2 * config.intermediate_size, E * r)
|
||||
|
||||
# down_proj [E, hidden, inter] — peft swaps in/out for 3D
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.lora_A.default.weight"
|
||||
].shape == (E * r, config.intermediate_size)
|
||||
assert trainable[
|
||||
"base_model.model.moe.experts.lora_B.default.weight"
|
||||
].shape == (config.intermediate_size, E * r)
|
||||
].shape == (config.hidden_size, E * r)
|
||||
|
||||
@requires_cuda
|
||||
def test_peft_forward_runs(self):
|
||||
@@ -488,8 +463,7 @@ class TestPeftLoRAWeightExtraction:
|
||||
assert gup_lora is not None, "gate_up_proj LoRA not detected"
|
||||
assert down_lora is not None, "down_proj LoRA not detected"
|
||||
|
||||
# Check shapes (after peft->scattermoe conversion with A<->B swap)
|
||||
# gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter
|
||||
# gate_up_proj: K=hidden, N=2*inter
|
||||
E, r = config.num_experts, 4
|
||||
gup_A, gup_B, gup_s = gup_lora
|
||||
assert gup_A.shape == (E * r, config.hidden_size), (
|
||||
@@ -501,7 +475,7 @@ class TestPeftLoRAWeightExtraction:
|
||||
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
|
||||
)
|
||||
|
||||
# down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden
|
||||
# down_proj: K=inter, N=hidden
|
||||
down_A, down_B, down_s = down_lora
|
||||
assert down_A.shape == (E * r, config.intermediate_size), (
|
||||
f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
|
||||
|
||||
@@ -24,7 +24,7 @@ from axolotl.monkeypatch.lora_kernels import (
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
|
||||
"""Verify that training completed successfully by checking artifacts and loss."""
|
||||
"""Verify that training completed successfully — artifacts, no-NaN, loss
|
||||
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
|
||||
"""
|
||||
output_path = Path(temp_dir)
|
||||
|
||||
model_files = list(output_path.glob("*.bin")) + list(
|
||||
@@ -30,19 +29,13 @@ def verify_training_success(temp_dir):
|
||||
"No checkpoint files found - training may have failed"
|
||||
)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
event_files = sorted(os.listdir(tb_log_path))
|
||||
if event_files:
|
||||
event_file = os.path.join(tb_log_path, event_files[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||
f"Training loss is NaN: {final_loss}"
|
||||
)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=10,
|
||||
final_window=10,
|
||||
max_initial=5.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
|
||||
class TestDistMuon:
|
||||
@@ -52,7 +45,7 @@ class TestDistMuon:
|
||||
def test_fft_sft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -63,11 +56,12 @@ class TestDistMuon:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.02,
|
||||
"learning_rate": 2e-3,
|
||||
"optimizer": "muon",
|
||||
"weight_decay": 0.01,
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -82,6 +76,9 @@ class TestDistMuon:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
@@ -109,7 +106,7 @@ class TestDistMuon:
|
||||
def test_lora_sft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -122,14 +119,15 @@ class TestDistMuon:
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.02,
|
||||
"learning_rate": 2e-3,
|
||||
"optimizer": "muon",
|
||||
"weight_decay": 0.01,
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -144,6 +142,9 @@ class TestDistMuon:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
"""Test module for FSDP1 multi-GPU functionality."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir
|
||||
from tests.e2e.utils import check_tensorboard_loss_decreased
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
|
||||
"""Verify that training completed successfully by checking artifacts and loss."""
|
||||
"""Verify that training completed successfully — artifacts, no-NaN, loss
|
||||
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
|
||||
"""
|
||||
output_path = Path(temp_dir)
|
||||
|
||||
model_files = list(output_path.glob("*.bin")) + list(
|
||||
@@ -31,19 +30,13 @@ def verify_training_success(temp_dir):
|
||||
"No checkpoint files found - training may have failed"
|
||||
)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
event_files = sorted(os.listdir(tb_log_path))
|
||||
if event_files:
|
||||
event_file = os.path.join(tb_log_path, event_files[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||
f"Training loss is NaN: {final_loss}"
|
||||
)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=10,
|
||||
final_window=10,
|
||||
max_initial=5.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
|
||||
class TestFSDP1:
|
||||
@@ -56,7 +49,7 @@ class TestFSDP1:
|
||||
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -67,11 +60,12 @@ class TestFSDP1:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -87,6 +81,9 @@ class TestFSDP1:
|
||||
"fsdp_use_orig_params": False,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
@@ -126,7 +123,7 @@ class TestFSDP1:
|
||||
def test_lora_sft(self, temp_dir, adapter_config):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -140,14 +137,15 @@ class TestFSDP1:
|
||||
"load_in_4bit": adapter_config["load_in_4bit"],
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 1e-3,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -163,6 +161,9 @@ class TestFSDP1:
|
||||
"fsdp_use_orig_params": False,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
@@ -190,7 +191,7 @@ class TestFSDP1:
|
||||
def test_dpo_fft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"rl": "dpo",
|
||||
@@ -203,11 +204,11 @@ class TestFSDP1:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -223,6 +224,9 @@ class TestFSDP1:
|
||||
"fsdp_use_orig_params": False,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -262,7 +266,7 @@ class TestFSDP1:
|
||||
def test_dpo_lora(self, temp_dir, adapter_config):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"load_in_4bit": adapter_config["load_in_4bit"],
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
@@ -281,11 +285,11 @@ class TestFSDP1:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 1e-3,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -301,6 +305,9 @@ class TestFSDP1:
|
||||
"fsdp_use_orig_params": False,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": "auto",
|
||||
"tf32": True,
|
||||
}
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
"""Test module for FSDP2 multi-GPU functionality."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
|
||||
"""Verify that training completed successfully by checking artifacts and loss."""
|
||||
"""Verify that training completed successfully — artifacts, no-NaN, loss
|
||||
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
|
||||
"""
|
||||
output_path = Path(temp_dir)
|
||||
|
||||
model_files = list(output_path.glob("*.bin")) + list(
|
||||
@@ -31,19 +30,13 @@ def verify_training_success(temp_dir):
|
||||
"No checkpoint files found - training may have failed"
|
||||
)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
event_files = sorted(os.listdir(tb_log_path))
|
||||
if event_files:
|
||||
event_file = os.path.join(tb_log_path, event_files[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||
f"Training loss is NaN: {final_loss}"
|
||||
)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=10,
|
||||
final_window=10,
|
||||
max_initial=5.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
|
||||
class TestFSDP2:
|
||||
@@ -57,7 +50,7 @@ class TestFSDP2:
|
||||
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -68,11 +61,12 @@ class TestFSDP2:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -86,6 +80,9 @@ class TestFSDP2:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
@@ -114,7 +111,7 @@ class TestFSDP2:
|
||||
def test_lora_sft(self, temp_dir, peft_use_dora):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -128,14 +125,15 @@ class TestFSDP2:
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 1e-3,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -149,6 +147,9 @@ class TestFSDP2:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
# explicitly disable LORA kernels, as they may be auto-enabled
|
||||
"lora_mlp_kernel": False,
|
||||
@@ -180,7 +181,7 @@ class TestFSDP2:
|
||||
def test_lora_sft_kernels(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -195,11 +196,12 @@ class TestFSDP2:
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 1e-3,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -213,6 +215,9 @@ class TestFSDP2:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
@@ -243,7 +248,7 @@ class TestFSDP2:
|
||||
def test_qlora_sft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -257,14 +262,15 @@ class TestFSDP2:
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 1e-3,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -278,6 +284,9 @@ class TestFSDP2:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
@@ -305,7 +314,7 @@ class TestFSDP2:
|
||||
def test_qlora_sft_kernels(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -321,11 +330,12 @@ class TestFSDP2:
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 1e-3,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -339,6 +349,9 @@ class TestFSDP2:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"bf16": True,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
@@ -370,7 +383,7 @@ class TestFSDP2:
|
||||
def test_dpo_fft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"rl": "dpo",
|
||||
@@ -383,11 +396,11 @@ class TestFSDP2:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -401,6 +414,9 @@ class TestFSDP2:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -428,7 +444,7 @@ class TestFSDP2:
|
||||
def test_dpo_lora(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
@@ -445,11 +461,11 @@ class TestFSDP2:
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 1e-3,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
@@ -463,6 +479,9 @@ class TestFSDP2:
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def _run_training(temp_dir, cfg):
|
||||
def _base_lora_fsdp2_config(temp_dir, **overrides):
|
||||
"""Base config for LoRA + FSDP2 + kernel tests."""
|
||||
cfg = {
|
||||
"base_model": "Qwen/Qwen3-0.6B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen3-129m",
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.0,
|
||||
"datasets": [
|
||||
|
||||
@@ -8,7 +8,7 @@ from accelerate.test_utils import execute_subprocess_async, get_torch_dist_uniqu
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
|
||||
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
|
||||
|
||||
|
||||
class TestTensorParallel:
|
||||
@@ -21,7 +21,7 @@ class TestTensorParallel:
|
||||
def test_fft_sft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
@@ -63,6 +63,6 @@ class TestTensorParallel:
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs", max_initial=5.0, max_final=4.7
|
||||
)
|
||||
|
||||
@@ -32,12 +32,12 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
{
|
||||
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"name": "axolotl-ai-co/tiny-mistral-25m",
|
||||
"expected_activation": apply_lora_mlp_swiglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
{
|
||||
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
"name": "axolotl-ai-co/tiny-qwen2-129m",
|
||||
"expected_activation": apply_lora_mlp_swiglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
@@ -47,7 +47,7 @@ MODEL_CONFIGS = [
|
||||
"dtype": torch.float32,
|
||||
},
|
||||
{
|
||||
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
|
||||
"name": "axolotl-ai-co/tiny-gemma2-137m",
|
||||
"expected_activation": apply_lora_mlp_geglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
@@ -159,7 +159,7 @@ def test_swiglu_mlp_integration(small_llama_model):
|
||||
def test_geglu_model_integration():
|
||||
"""Test GeGLU activation with Gemma model."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"trl-internal-testing/tiny-Gemma2ForCausalLM",
|
||||
"axolotl-ai-co/tiny-gemma2-137m",
|
||||
dtype=torch.float16,
|
||||
device_map="cuda:0",
|
||||
)
|
||||
|
||||
@@ -4,14 +4,16 @@ E2E tests for falcon
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
from ..utils import (
|
||||
check_model_output_exists,
|
||||
check_tensorboard_loss_decreased,
|
||||
with_temp_dir,
|
||||
)
|
||||
|
||||
|
||||
class TestFalconPatched(unittest.TestCase):
|
||||
@@ -19,13 +21,12 @@ class TestFalconPatched(unittest.TestCase):
|
||||
Test case for Falcon models
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_qlora(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
"flash_attention": True,
|
||||
"base_model": "axolotl-ai-co/tiny-falcon-42m",
|
||||
"flash_attention": False,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
"load_in_4bit": True,
|
||||
@@ -47,17 +48,20 @@ class TestFalconPatched(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 50,
|
||||
"logging_steps": 1,
|
||||
"save_steps": 50,
|
||||
"eval_steps": 50,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -66,14 +70,20 @@ class TestFalconPatched(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=6.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
"flash_attention": True,
|
||||
"base_model": "axolotl-ai-co/tiny-falcon-42m",
|
||||
"flash_attention": False,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
@@ -88,17 +98,20 @@ class TestFalconPatched(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 50,
|
||||
"logging_steps": 1,
|
||||
"save_steps": 50,
|
||||
"eval_steps": 50,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -107,3 +120,10 @@ class TestFalconPatched(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=6.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,12 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir
|
||||
from ..utils import (
|
||||
check_model_output_exists,
|
||||
check_tensorboard_loss_decreased,
|
||||
require_torch_2_6_0,
|
||||
with_temp_dir,
|
||||
)
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
@@ -22,7 +27,7 @@ class TestMistral(unittest.TestCase):
|
||||
def test_lora_packing(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"base_model": "axolotl-ai-co/tiny-mistral-25m",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -45,17 +50,20 @@ class TestMistral(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"max_steps": 50,
|
||||
"logging_steps": 1,
|
||||
"save_steps": 50,
|
||||
"eval_steps": 50,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -64,12 +72,19 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=5.5,
|
||||
max_final=4.3,
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packing(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"base_model": "axolotl-ai-co/tiny-mistral-25m",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -86,17 +101,20 @@ class TestMistral(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"max_steps": 50,
|
||||
"logging_steps": 1,
|
||||
"save_steps": 50,
|
||||
"eval_steps": 50,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -105,3 +123,10 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=5.5,
|
||||
max_final=4.3,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,11 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
from ..utils import (
|
||||
check_model_output_exists,
|
||||
check_tensorboard_loss_decreased,
|
||||
with_temp_dir,
|
||||
)
|
||||
|
||||
|
||||
class TestMixtral(unittest.TestCase):
|
||||
@@ -21,8 +25,7 @@ class TestMixtral(unittest.TestCase):
|
||||
def test_qlora(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
@@ -30,7 +33,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {},
|
||||
@@ -41,17 +44,21 @@ class TestMixtral(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 3e-3,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"logging_steps": 1,
|
||||
"save_steps": 80,
|
||||
"eval_steps": 80,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -60,13 +67,19 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=10,
|
||||
final_window=10,
|
||||
max_initial=6.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
@@ -79,17 +92,21 @@ class TestMixtral(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"learning_rate": 5e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"max_steps": 80,
|
||||
"warmup_steps": 5,
|
||||
"logging_steps": 1,
|
||||
"save_steps": 80,
|
||||
"eval_steps": 80,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -98,3 +115,10 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=6.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
@@ -22,8 +22,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
def test_mixtral_multipack(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
@@ -57,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
def test_mistral_multipack(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"base_model": "axolotl-ai-co/tiny-mistral-25m",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
|
||||
@@ -9,7 +9,11 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
from ..utils import (
|
||||
check_model_output_exists,
|
||||
check_tensorboard_loss_decreased,
|
||||
with_temp_dir,
|
||||
)
|
||||
|
||||
|
||||
class TestPhiMultipack(unittest.TestCase):
|
||||
@@ -21,7 +25,7 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
def test_ft_packed(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model": "axolotl-ai-co/tiny-phi-64m",
|
||||
"model_type": "PhiForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
@@ -43,17 +47,20 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"dataset_shard_num": 10,
|
||||
"dataset_shard_idx": 0,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"max_steps": 50,
|
||||
"logging_steps": 1,
|
||||
"eval_steps": 50,
|
||||
"save_steps": 50,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -63,12 +70,19 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=6.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_packed(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model": "axolotl-ai-co/tiny-phi-64m",
|
||||
"model_type": "PhiForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
@@ -94,17 +108,20 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"dataset_shard_num": 10,
|
||||
"dataset_shard_idx": 0,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"learning_rate": 2e-4,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"max_steps": 50,
|
||||
"logging_steps": 1,
|
||||
"eval_steps": 50,
|
||||
"save_steps": 50,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -114,3 +131,10 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
check_tensorboard_loss_decreased(
|
||||
temp_dir + "/runs",
|
||||
initial_window=5,
|
||||
final_window=5,
|
||||
max_initial=6.0,
|
||||
max_final=4.7,
|
||||
)
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
class TestUnslothIntegration(unittest.TestCase):
|
||||
"""Unsloth monkeypatch integration tests."""
|
||||
|
||||
def test_is_self_attn_patchable(self):
|
||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
check_self_attn_is_patchable(),
|
||||
"HF transformers self attention code has changed and isn't patchable",
|
||||
)
|
||||
@@ -1,184 +0,0 @@
|
||||
"""
|
||||
e2e tests for unsloth qlora
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
class TestUnslothQLoRA:
|
||||
"""
|
||||
Test class for Unsloth QLoRA Llama models
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": sample_packing,
|
||||
"flash_attention": True,
|
||||
"unsloth_lora_mlp": True,
|
||||
"unsloth_lora_qkv": True,
|
||||
"unsloth_lora_o": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"unsloth_lora_mlp": True,
|
||||
"unsloth_lora_qkv": True,
|
||||
"unsloth_lora_o": True,
|
||||
"sample_packing": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sdp_attention",
|
||||
[True, False],
|
||||
)
|
||||
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"unsloth_lora_mlp": True,
|
||||
"unsloth_lora_qkv": True,
|
||||
"unsloth_lora_o": True,
|
||||
"sample_packing": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"sdp_attention": sdp_attention,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"fp16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
@@ -18,7 +18,7 @@ from transformers import AutoModelForCausalLM
|
||||
# Import the actual trainer methods we want to test
|
||||
from axolotl.core.trainers.grpo.async_trainer import AsyncGRPOTrainer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"
|
||||
|
||||
|
||||
def _fix_patched_attention(model):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user