Compare commits

..

15 Commits

Author SHA1 Message Date
Wing Lian
d17ed89a3c add missing file 2026-04-21 08:44:01 -04:00
Wing Lian
02e4f2350d fixes for scattermoe from latest peft upgrade 2026-04-21 08:00:16 -04:00
Wing Lian
4195605ab2 fix test dims 2026-04-21 00:44:26 +00:00
Wing Lian
37acb28d02 fix einsum dims 2026-04-20 23:09:47 +00:00
Wing Lian
4a5281e61a Fix shape 2026-04-19 01:53:05 +00:00
Wing Lian
a892d8cce1 chore: lint 2026-04-17 17:48:26 +00:00
Wing Lian
78de2919a6 tiled mlp fix for gemma4 2026-04-16 13:24:41 +00:00
Wing Lian
28283ff373 revert shared_kv_states workaround with transformers 5.5.4 2026-04-15 13:32:59 +00:00
Wing Lian
dc16859983 [gemma4] fix fused RMSNorm+RoPE on hybrid attention models
- Kernel: fused_rms_norm_rope crashed when cos.shape[-1] < x.shape[-1].
  Triton forward/backward take an n_rot runtime arg that restricts
  rotate_half to [0, n_rot) and treats trailing cols as RMSNorm-only
  pass-through (cos=1, sin=0 defaults). Wrapper also expands cos/sin
  that broadcast over batch.

- Forward: _make_fused_forward used a stale shared_kv_states kwarg the
  current decoder layer no longer passes. Now mirrors stock attention,
  reading/writing past_key_values.shared_layers.
2026-04-15 13:27:31 +00:00
Wing Lian
d4e9cf2eec lint 2026-04-15 13:27:30 +00:00
Wing Lian
53391a10d7 vllm-serve-lora add /v1/completions route + worker pipe lock
The LoRA vllm-serve wrapper only exposed /v1/chat/completions, but
retrace's SWE agent server uses the token-id-aware /v1/completions
endpoint so it can feed raw prompt_token_ids + track per-token
logprobs across multi-turn rollouts. Add the route, mirroring the
shape of /v1/chat/completions but routing to the vLLM worker's
generate() method so prompt_token_ids are passed through as-is.

Also add a worker_pipe_lock around conn.send/conn.recv. The
multiprocessing.Connection to the vLLM worker is a single shared
full-duplex pipe; concurrent HTTP requests interleave pickle frames
on the wire and corrupt the stream (observed as
UnpicklingError: pickle data was truncated, surfacing as 500s).
The agent server fires ~8 concurrent rollout requests at once, so
this was a hard blocker for any multi-concurrent workload. Serialize
access to the pipe per-request round-trip.
2026-04-15 13:27:30 +00:00
Wing Lian
7617b951a8 make _maybe_sync_vllm_weights actually fire in sync mode
Two bugs in ``AsyncGRPOTrainer._maybe_sync_vllm_weights`` plus a
companion bug in the sync-hook patch site that together neutralized
LoRA weight sync entirely whenever ``async_prefetch=False`` was
combined with NeMo Gym's data-producer path:

1. ``_maybe_sync_vllm_weights`` had ``if not async_prefetch: return``
   at the top. The original design assumed sync mode would fall back
   to TRL's stock per-step ``sync_weights`` call inside
   ``_generate_single_turn`` — true for vanilla GRPO but FALSE in
   NeMo Gym multi-turn, where ``NemoGymDataProducer`` calls the agent
   server directly and ``_generate_single_turn`` is never invoked.
   Result: no sync ever happened in NeMo Gym sync mode.

2. ``step % vllm_sync_interval`` would TypeError on the first call if
   ``vllm_sync_interval`` was unset (the default for any config that
   doesn't explicitly set it).

3. The ``_generate_single_turn`` patch installed
   ``vllm_generation.sync_weights = lambda: None`` unconditionally
   for vllm_lora_sync runs. That's correct in async-prefetch mode
   (BG thread can't safely sync) but wrong in sync mode: TRL's
   per-step auto-sync inside ``_generate_single_turn`` was the
   fallback that the early return in (1) was assuming, and the
   no-op patch was killing it.

Fix:
  - Drop the ``not async_prefetch`` early return; ``_maybe_sync_vllm_weights``
    is now the canonical sync trigger and runs in both modes from
    ``_prepare_inputs_with_data_producer`` / ``_prepare_inputs_legacy_async``.
  - Default ``vllm_sync_interval`` to 1 when unset.
  - In the ``_generate_single_turn`` patch, route sync_weights to
    ``_sync_lora_adapter`` in sync mode (and keep the lambda no-op
    in async mode for the BG-thread safety reason).
2026-04-15 13:27:30 +00:00
Wing Lian
e993ed5208 retry head-server probe with longer timeout
``get_server_configs`` was hardcoded to a 5s timeout with no retry.
That's empirically too tight to survive a kill-and-relaunch cycle:
when the agent server is finishing in-flight rollouts from a prior
run, it can take 10-30s to respond to /global_config_dict_yaml, and
the trainer would crash at startup with a ReadTimeoutError.

Bump the per-attempt timeout to 30s and retry up to 3 times with a
2s/4s backoff. The retry intentionally raises a RuntimeError after
the third failure rather than returning empty config — silent
failure here would let training proceed with no agent servers
discovered, which is also a no-op trainer.
2026-04-15 13:27:30 +00:00
Wing Lian
69f165b39b probe vLLM weight-sync routes and select transport per server
The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op
sync_weights when vllm_lora_sync was off. Combined, this turned the
trainer into a functional no-op whenever (a) the user ran NeMo Gym
+ LoRA without remembering to set vllm_lora_sync=true or (b) the
user ran NeMo Gym + full fine-tune (which had no working sync path
under the old code).

Replace both patches with:

1. A probe of the configured vLLM server's /openapi.json at
   pre_model_load. Three transports are recognized:
     - NCCL (/init_communicator/ + /update_named_param/) — TRL serve
       and axolotl vllm-serve both expose this
     - LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
     - HTTP base64 full-weight (/http_update_weights/) — axolotl
       vllm-serve only

2. A pure-logic ``select_weight_sync_transport`` that picks the
   right one for (server caps × adapter type).

3. ``init_communicator`` is only patched out when the server has no
   NCCL routes; against TRL/axolotl serve modules it stays live so
   full-finetune NCCL sync works.

4. ``post_trainer_create`` uses the selection table to install LoRA
   filesystem sync OR leave the standard NCCL flow alone OR raise
   NotImplementedError (HTTP — pending) OR raise a precise diagnosis
   when no transport is viable. No more silent no-op trainers.
2026-04-15 13:27:30 +00:00
Wing Lian
80a97f192b validate batch shape against num_generations at config time
Surfaces a class of GRPO config errors at axolotl-train startup instead
of letting them bubble out of GRPOTrainer.__init__ after the model loads.
Three checks under RLValidationMixin.check_grpo_batch_size_divisibility:

  - effective generation_batch_size (or mb*GA fallback) must be divisible
    by trl.num_generations, with a hint pointing at the smallest GA bump
    that fixes the violation
  - num_generations >= 2 (group-relative advantage needs variance; with
    num_gen=1 the policy never updates)
  - When world_size > 1, effective gbs >= num_generations * world_size

11 unit tests cover the table: divisible/non-divisible, explicit and
implicit gbs, multi-rank constraint, GRPO-disabled passthrough, and
unset num_generations.
2026-04-15 13:27:30 +00:00
110 changed files with 1929 additions and 3227 deletions

View File

@@ -31,10 +31,7 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
# Install axolotl + dev and test dependencies from lockfile
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test

View File

@@ -6,7 +6,7 @@ on:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"

View File

@@ -3,15 +3,17 @@ name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- "tests/e2e/multigpu/**.py"
- "pyproject.toml"
- ".github/workflows/multi-gpu-e2e.yml"
- "scripts/cutcrossentropy_install.py"
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
- "src/axolotl/utils/distributed.py"
- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'scripts/cutcrossentropy_install.py'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
schedule:
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
# Cancel jobs on the same ref if a new one is triggered
concurrency:
@@ -31,19 +33,19 @@ jobs:
fail-fast: false
matrix:
include:
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
# axolotl_extras: fbgemm-gpu
# axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1
@@ -51,6 +53,7 @@ jobs:
pytorch: 2.10.0
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -72,7 +75,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -8,9 +8,6 @@ on:
permissions: {}
env:
UV_SYSTEM_PYTHON: "1"
jobs:
setup_release:
name: Create Release
@@ -44,15 +41,11 @@ jobs:
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install dependencies
run: |
uv pip install wheel packaging
uv pip install --no-build-isolation -e .
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
pip3 install wheel packaging==26.0
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name
id: tag

View File

@@ -2,18 +2,15 @@ name: Tests Nightly against upstream main
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- ".github/workflows/tests-nightly.yml"
- '.github/workflows/tests-nightly.yml'
permissions:
contents: read
env:
UV_SYSTEM_PYTHON: "1"
jobs:
pre-commit:
name: pre-commit
@@ -23,7 +20,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip" # caching pip dependencies
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -46,7 +43,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.9.1", "2.10.0"]
timeout-minutes: 20
@@ -64,34 +61,36 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
- name: Install PyTorch
run: |
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Update requirements.txt
run: |
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
- name: Install dependencies
run: |
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Override with nightly HF packages
run: |
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
@@ -103,6 +102,9 @@ jobs:
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
@@ -134,6 +136,7 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
steps:
- name: Checkout
@@ -154,7 +157,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:

View File

@@ -6,19 +6,21 @@ on:
branches:
- "main"
paths:
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
@@ -31,7 +33,6 @@ permissions:
env:
TRANSFORMERS_IS_CI: "yes"
UV_SYSTEM_PYTHON: "1"
jobs:
pre-commit:
@@ -43,7 +44,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip" # caching pip dependencies
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -93,25 +94,32 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
pip3 show torch
pip3 install --no-cache-dir --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
@@ -180,27 +188,33 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install PyTorch
run: |
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
uv pip install packaging setuptools_scm build wheel psutil
pip3 show torch
python -m build --no-isolation --sdist
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
@@ -277,6 +291,7 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -297,7 +312,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
@@ -359,7 +374,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
| Method | Config Key | When to Use |
|--------|-----------|-------------|
| SFT | *(default)* | Input-output pairs, instruction tuning |
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
| KTO | `rl: kto` | Unpaired binary preference labels |
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |

View File

@@ -1,6 +1,7 @@
include requirements.txt
include README.md
include LICENSE
include VERSION
include src/setuptools_axolotl_dynamic_dependencies.py
include src/axolotl/utils/chat_templates/templates/*.jinja
include AGENTS.md
recursive-include docs/agents *.md

View File

@@ -95,11 +95,14 @@ Features:
### Installation
```bash
# install uv if you don't already have it installed (restart shell after)
curl -LsSf https://astral.sh/uv/install.sh | sh
#### Using uv (recommended)
# change depending on system
```bash
# install uv if you don't already have it installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
# CUDA 12.8.1 tends to have better package compatibility
export UV_TORCH_BACKEND=cu128
# create a new virtual environment
@@ -109,6 +112,23 @@ source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]
# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
#### Using pip
```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
@@ -118,7 +138,7 @@ axolotl fetch deepspeed_configs # OPTIONAL
Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).

View File

@@ -134,6 +134,7 @@ quartodoc:
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu
@@ -326,6 +327,7 @@ website:
- section: "Advanced Features"
contents:
- docs/fsdp_qlora.qmd
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd

View File

@@ -22,6 +22,15 @@ WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==26.0 setuptools==78.1.1
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
@@ -31,21 +40,11 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
# Override with nightly HF packages for nightly builds
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
fi
RUN python scripts/unsloth_install.py --uv | sh
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

54
cicd/Dockerfile.jinja Normal file
View File

@@ -0,0 +1,54 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
RUN pip uninstall -y causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -1,7 +1,7 @@
#!/bin/bash
set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
set -o pipefail
for i in 1 2 3; do

View File

@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {

View File

@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {

View File

@@ -32,7 +32,7 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -33,6 +33,7 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean

View File

@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
1. Paired preference data (chosen + rejected)?
- Default → `rl: dpo`
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
- Overfitting → `rl: ipo`
- VRAM-limited → `rl: orpo` (no ref model)
- Length-sensitive → `rl: simpo` (no ref model)
2. Only binary labels (good/bad)? → `rl: kto`

View File

@@ -76,9 +76,8 @@ datasets:
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
```bash
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
#### Remote Hosts
@@ -209,17 +208,17 @@ cd axolotl
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
```
>[!Tip]
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
You will now be in the container. Next, install Axolotl with dev dependencies:
You will now be in the container. Next, perform an editable install of Axolotl:
```bash
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Attach To Container

View File

@@ -6,30 +6,23 @@ format:
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at
[Docker Hub](https://hub.docker.com/u/axolotlai).
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
:::
::: {.callout-tip}
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
:::
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
@@ -39,10 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.8.0`
- `main-base-py3.11-cu128-2.9.1`
- `main-base-py3.12-cu128-2.10.0`
- `main-base-py3.12-cu130-2.9.1`
- `main-base-py3.12-cu130-2.10.0`
## Main
@@ -50,10 +41,11 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
#### Image
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
@@ -61,7 +53,7 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
main-latest
# nightly build
@@ -79,12 +71,11 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu128-2.8.0`
- `main-py3.11-cu128-2.9.1`
- `main-py3.12-cu128-2.10.0`
- `main-py3.12-cu130-2.9.1`
- `main-py3.12-cu130-2.10.0`
- `main-latest`
- `main-20260315-py3.11-cu128-2.9.1`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu126-2.6.0`
- `0.12.0`
## Cloud
@@ -99,10 +90,11 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
#### Image
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format

View File

@@ -15,30 +15,64 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.9.0
- PyTorch ≥2.6.0
## Installation {#sec-installation}
## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
:::
### Quick Install {#sec-uv}
### PyPI Installation (Recommended) {#sec-pypi}
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
Install uv if not already installed:
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### uv Installation {#sec-uv}
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
Install uv if not already installed
```{.bash}
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
then create the venv and activate
```{.bash}
export UV_TORCH_BACKEND=cu128 # or cu130
export UV_TORCH_BACKEND=cu126
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
```
Install PyTorch
- PyTorch 2.6.0 recommended
```{.bash}
uv pip install packaging setuptools wheel
uv pip install torch==2.6.0
uv pip install awscli pydantic
```
Install axolotl from PyPi
```{.bash}
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
```
### Edge/Development Build {#sec-edge-build}
@@ -48,17 +82,14 @@ For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed
source .venv/bin/activate
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
### Docker {#sec-docker}
```{.bash}
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
For development with Docker:
@@ -75,12 +106,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
--ulimit memlock=-1 --ulimit stack=67108864 \
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
axolotlai/axolotl-uv:main-latest
axolotlai/axolotl:main-latest
```
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
@@ -91,7 +122,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
For providers supporting Docker:
- Use `axolotlai/axolotl-cloud-uv:main-latest`
- Use `axolotlai/axolotl-cloud:main-latest`
- Available on:
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
@@ -110,7 +141,7 @@ For providers supporting Docker:
### macOS {#sec-macos}
```{.bash}
uv pip install --no-build-isolation -e '.'
pip3 install --no-build-isolation -e '.'
```
See @sec-troubleshooting for Mac-specific issues.
@@ -121,44 +152,21 @@ See @sec-troubleshooting for Mac-specific issues.
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
:::
## Migrating from pip to uv {#sec-migrating}
## Environment Managers {#sec-env-managers}
If you have an existing pip-based Axolotl installation, you can migrate to uv:
### Conda/Pip venv {#sec-conda}
```{.bash}
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
# Create a fresh venv (recommended for a clean start)
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
source .venv/bin/activate
# Reinstall axolotl
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
```
## Using pip (Alternative) {#sec-pip}
If you are unable to install uv, you can still use pip directly.
::: {.callout-important}
Please make sure to have PyTorch installed before installing Axolotl with pip.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
For editable/development installs:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
1. Install Python ≥3.11
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
hf auth login
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -320,10 +320,8 @@ The input format is a simple JSON input with customizable fields based on the ab
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
```yaml
rl: dpo
dpo_loss_type: ["ipo"]
rl: ipo
```
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
### ORPO

53
docs/unsloth.qmd Normal file
View File

@@ -0,0 +1,53 @@
---
title: "Unsloth"
description: "Hyper-optimized QLoRA finetuning for single GPUs"
---
### Overview
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
standard industry baselines.
::: {.callout-important}
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
:::
### Installation
The following will install the correct unsloth and extras from source.
```bash
python scripts/unsloth_install.py | sh
```
### Usage
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
Our unsloth integration is currently limited to the following model architectures:
- llama
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
```yaml
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```
### Limitations
- Single GPU only; e.g. no multi-gpu support
- No deepspeed or FSDP support (requires multi-gpu)
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
- No MoE support.

View File

@@ -15,7 +15,8 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
@@ -34,7 +35,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
**LFM2-MoE**
```bash
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
@@ -44,7 +45,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
uv pip uninstall causal-conv1d
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -15,7 +15,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
@@ -30,7 +31,7 @@ python scripts/cutcrossentropy_install.py | sh
# For those using our Docker image, use the below path.
export CUDA_HOME=/usr/local/cuda
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
@@ -66,7 +67,7 @@ If those didn't help, please try the below solutions:
1. Pass env for CMAKE and try install again:
```bash
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
2. Git clone the repo and manually hardcode python path:
@@ -91,7 +92,7 @@ If those didn't help, please try the below solutions:
```
```bash
uv pip install . --no-build-isolation --no-deps
pip3 install . --no-build-isolation --no-deps
```
## Optimization Guides

View File

@@ -17,7 +17,8 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -16,7 +16,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -10,16 +10,17 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:
```bash
uv pip install timm==1.0.17
pip3 install timm==1.0.17
# for loading audio data
uv pip install librosa==0.11.0
pip3 install librosa==0.11.0
```
3. Download sample dataset files

View File

@@ -14,7 +14,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
@@ -86,7 +87,7 @@ for more information about using a special vllm-openai docker image for inferenc
Optionally, vLLM can be installed from nightly:
```bash
uv pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash

View File

@@ -15,7 +15,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,7 +13,8 @@ Tencent released a family of opensource models called HunYuan with varying param
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl.
2. Install `timm` for vision model support:
```bash
uv pip install timm==1.0.19
pip install timm==1.0.19
```
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

View File

@@ -14,7 +14,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
1. Install the required vision lib:
```bash
uv pip install 'mistral-common[opencv]==1.8.5'
pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:

View File

@@ -23,7 +23,7 @@ Note: This is still experimental given it is based on transformers v5 RC.
git checkout transformers-v5
# Install packages for transformers v5
uv pip install -e .
pip install -e .
```
4. Run the fine-tuning:

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
1. Install the required vision lib:
```bash
uv pip install 'mistral-common[opencv]==1.8.6'
pip install 'mistral-common[opencv]==1.8.6'
```
2. Download the example dataset image:

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
1. Install the required vision lib:
```bash
uv pip install 'mistral-common[opencv]==1.8.5'
pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:

View File

@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
3. Install transformers from main
```bash
uv pip install git+https://github.com/huggingface/transformers.git
pip install git+https://github.com/huggingface/transformers.git
```
4. Run one of the example configs:

View File

@@ -12,7 +12,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
3. Install FLA for improved performance
```bash
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
4. Run the finetuning example:

View File

@@ -10,7 +10,7 @@
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.

View File

@@ -11,7 +11,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,13 +13,14 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
uv pip install num2words==0.5.14
pip3 install num2words==0.5.14
```
3. Run the finetuning example:

View File

@@ -12,15 +12,16 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Please install the below.
```bash
# audio
uv pip install librosa==0.11.0
uv pip install 'mistral_common[audio]==1.8.3'
pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -1,165 +1,15 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
build-backend = "setuptools.build_meta"
[project]
name = "axolotl"
dynamic = ["version"]
dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
# license = "Apache-2.0"
dependencies = [
# Core ML stack
"torch>=2.6.0",
"packaging==26.0",
"huggingface_hub>=1.1.7",
"peft>=0.19.1,<0.20.0",
"tokenizers>=0.22.1",
"transformers==5.5.4",
"accelerate==1.13.0",
"datasets>=4.8.4,<4.9.0",
"trl==1.1.0",
"hf_xet==1.4.3",
"kernels==0.13.0",
"trackio>=0.16.1",
"typing-extensions>=4.15.0",
"optimum==1.16.2",
"hf_transfer",
"sentencepiece",
"gradio>=6.2.0,<7.0",
"modal==1.3.0.post1",
"pydantic>=2.10.6",
"addict",
"fire",
"PyYAML>=6.0",
"requests",
"wandb",
"einops",
"colorama",
"numba>=0.61.2",
"numpy>=2.2.6",
# Evaluation & metrics
"evaluate==0.4.1",
"scipy",
"nvidia-ml-py==12.560.30",
"art",
"tensorboard",
"python-dotenv==1.0.1",
# Remote filesystems
"s3fs>=2024.5.0",
"gcsfs>=2025.3.0",
"adlfs>=2024.5.0",
"ocifs==1.3.2",
"zstandard==0.22.0",
"fastcore",
# lm eval harness
"lm_eval==0.4.11",
"langdetect==1.0.9",
"immutabledict==4.2.0",
"antlr4-python3-runtime==4.13.2",
"schedulefree==1.4.1",
"openenv-core==0.1.0",
# Axolotl contribs
"axolotl-contribs-lgpl==0.0.7",
"axolotl-contribs-mit==0.0.6",
# Telemetry
"posthog==6.7.11",
"mistral-common==1.11.0",
# Platform-specific (Linux only)
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
"triton>=3.4.0 ; sys_platform != 'darwin'",
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
# Architecture-specific
"fla-core==0.4.1 ; platform_machine != 'aarch64'",
"flash-linear-attention==0.4.1 ; platform_machine != 'aarch64'",
]
[project.optional-dependencies]
flash-attn = ["flash-attn==2.8.3"]
ring-flash-attn = [
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
]
deepspeed = [
"deepspeed>=0.18.6,<0.19.0",
"deepspeed-kernels",
]
mamba-ssm = [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
]
auto-gptq = [
"auto-gptq==0.5.1",
]
mlflow = [
"mlflow",
]
galore = [
"galore_torch",
]
apollo = [
"apollo-torch",
]
optimizers = [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
]
ray = [
"ray[train]>=2.52.1",
]
vllm = [
"vllm>=0.15.0",
]
llmcompressor = [
"llmcompressor>=0.10.0",
]
fbgemm-gpu = ["fbgemm-gpu-genai>=1.3.0"]
opentelemetry = [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
]
[dependency-groups]
dev = [
"black",
"mypy",
"pre-commit",
"types-requests",
"quartodoc",
"jupyter",
"blobfile",
"tiktoken",
]
test = [
"codecov",
"codecov-cli",
"pytest",
"pytest-cov",
"pytest-retry",
"pytest-sugar",
"pytest-xdist",
"tbparse",
]
[project.scripts]
axolotl = "axolotl.cli.main:main"
@@ -168,15 +18,18 @@ Homepage = "https://axolotl.ai/"
Documentation = "https://docs.axolotl.ai/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
[tool.setuptools]
include-package-data = true
[tool.setuptools_scm]
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools]
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.dynamic]
version = { file = "VERSION" }
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
target-version = "py310"
@@ -214,43 +67,5 @@ markers = [
"slow: marks tests as slow",
]
# UV specific configuration
[tool.uv]
prerelease = "allow"
conflicts = [
[
{ package = "axolotl" },
{ extra = "vllm" },
],
[
{ package = "axolotl" },
{ extra = "flash-attn" },
],
[
{ package = "axolotl" },
{ extra = "ring-flash-attn" },
],
[
{ package = "axolotl" },
{ extra = "mamba-ssm" },
],
[
{ package = "axolotl" },
{ extra = "auto-gptq" },
],
[
{ package = "axolotl" },
{ extra = "fbgemm-gpu" },
],
[
{ package = "axolotl" },
{ extra = "llmcompressor" },
],
]
[tool.uv.extra-build-dependencies]
mamba-ssm = [{ requirement = "torch", match-runtime = true }]
causal-conv1d = [{ requirement = "torch", match-runtime = true }]
flash-attn = [{ requirement = "torch", match-runtime = true }]
deepspeed = [{ requirement = "torch", match-runtime = true }]
auto-gptq = [{ requirement = "torch", match-runtime = true }]
axolotl = ["huggingface_hub"]

8
requirements-dev.txt Normal file
View File

@@ -0,0 +1,8 @@
black
mypy
pre-commit
types-requests
quartodoc
jupyter
blobfile
tiktoken

8
requirements-tests.txt Normal file
View File

@@ -0,0 +1,8 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-retry
pytest-sugar
pytest-xdist
tbparse

78
requirements.txt Normal file
View File

@@ -0,0 +1,78 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
triton>=3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.7.0
# END section
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.19.0,<0.20.0
tokenizers>=0.22.1
transformers==5.5.4
accelerate==1.13.0
datasets>=4.8.4,<4.9.0
deepspeed>=0.18.6,<0.19.0
trl==1.1.0
hf_xet==1.4.3
kernels==0.13.0
fla-core==0.4.1
flash-linear-attention==0.4.1
trackio>=0.16.1
typing-extensions>=4.15.0
optimum==1.16.2
hf_transfer
sentencepiece
gradio>=6.2.0,<7.0
modal==1.3.0.post1
pydantic>=2.10.6
addict
fire
PyYAML>=6.0
requests
wandb
einops
colorama
numba>=0.61.2
numpy>=2.2.6
# qlora things
evaluate==0.4.1
scipy
nvidia-ml-py==12.560.30
art
tensorboard
python-dotenv==1.0.1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2025.3.0
adlfs>=2024.5.0
ocifs==1.3.2
zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.11
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.17.0
openenv-core==0.1.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.11.0

View File

@@ -0,0 +1,40 @@
# noqa
import sys
try:
import torch
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
uv_prefix = "uv " if use_uv else ""
print(
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
)

230
setup.py Normal file
View File

@@ -0,0 +1,230 @@
"""setup.py for axolotl"""
import os
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from setuptools import find_packages, setup
def parse_requirements(extras_require_map):
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = "deepspeed" in line or "mamba-ssm" in line
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip on ARM64
skip_packages = [
"torchao",
"fla-core",
"flash-linear-attention",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"xformers",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.8.0" # default to torch 2.8.0
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 10):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.5.0",
"fbgemm-gpu-genai==1.5.0",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm>=0.19.0"]
elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:
extras_require_map["vllm"] = ["vllm==0.14.0"]
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
if install_xformers:
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else:
if install_xformers:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
if install_xformers:
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if install_xformers:
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if install_xformers:
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links, extras_require_map
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
"r",
encoding="utf-8",
) as fin:
version_ = fin.read().strip()
return version_
extras_require = {
"flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
],
"deepspeed": [
"deepspeed==0.18.2",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]>=2.52.1",
],
"vllm": [
"vllm==0.10.0",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
"opentelemetry": [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)
setup(
version=get_package_version(),
package_dir={"": "src"},
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
entry_points={
"console_scripts": [
"axolotl=axolotl.cli.main:main",
],
},
extras_require=extras_require_build,
)

View File

@@ -339,11 +339,7 @@ def _build_peft_layer_and_get_delta(
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
delta = layer.get_delta_weight(adapter_name)
# peft >=0.19.1 may return delta with transposed dims for 3D params
if delta.shape != base_tensor.shape and delta.ndim == 3:
delta = delta.transpose(1, 2).contiguous()
return delta
return layer.get_delta_weight(adapter_name)
elif (
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
):

View File

@@ -20,16 +20,8 @@ class DPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl is RLType.DPO:
if cfg.dpo_loss_type is not None:
training_args_kwargs["loss_type"] = cfg.dpo_loss_type
if cfg.dpo_loss_weights is not None:
training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = ["ipo"]
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing

View File

@@ -1,27 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Hatchery/Tinker remote training integration for Axolotl.
Routes axolotl's preprocessed data to a remote training API (Tinker or
Hatchery) instead of running forward/backward locally. The remote
service handles model weights, LoRA adapters, and gradient updates.
"""
from .args import HatcheryArgs, HatcheryConfig
from .plugin import HatcheryPlugin
__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"]
# Usage:
# plugins:
# - axolotl.integrations.hatchery.HatcheryPlugin
#
# hatchery:
# backend: tinker # or "hatchery"
# lora_rank: 32
# loss_fn: cross_entropy # SFT
# # loss_fn: ppo # RL (auto-selects HatcheryRLTrainer)
#
# learning_rate: 1e-4 # top-level, not under hatchery:

View File

@@ -1,62 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Pydantic config schema for the Hatchery integration."""
from __future__ import annotations
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field
class HatcheryConfig(BaseModel):
"""Nested config under `hatchery:` in the axolotl YAML.
Only contains hatchery-specific settings. Standard training params
(learning_rate, weight_decay, adam_beta1/2, max_grad_norm,
gradient_accumulation_steps) are read from axolotl's top-level config.
"""
# Backend & connection
backend: Literal["tinker", "hatchery"] = "tinker"
base_url: Optional[str] = None
api_key: Optional[str] = None
project_id: Optional[str] = None
# LoRA config sent to remote
lora_rank: int = Field(32, ge=1, le=256)
train_attn: bool = True
train_mlp: bool = True
train_unembed: bool = True
# Loss function
loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = (
"cross_entropy"
)
loss_fn_config: Optional[dict[str, Any]] = None
# Pipelining: submit next batch before awaiting previous result
pipeline: bool = True
# Sampling params (for RL flows)
max_sample_tokens: int = 256
sample_temperature: float = 1.0
num_samples: int = 4
# Reward functions (for RL) — list of fully qualified names
reward_funcs: Optional[list[str]] = None
# Checkpointing
save_steps: Optional[int] = None
save_name_prefix: str = "checkpoint"
# Timeout per future (seconds)
future_timeout: float = 600.0
class HatcheryArgs(BaseModel):
"""Top-level mixin that adds the nested `hatchery:` field."""
hatchery: Optional[HatcheryConfig] = None

View File

@@ -1,160 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Convert axolotl batch tensors to Tinker/Hatchery Datum format.
Both Tinker and Hatchery expect the client to apply the causal LM shift:
Original tokens: [t0, t1, t2, ..., t_{L-1}]
model_input: [t0, t1, ..., t_{L-2}] (last token dropped)
target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped)
weights: [w1, w2, ..., w_{L-1}] (aligned to targets)
At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}.
"""
from __future__ import annotations
from typing import Any
import torch
def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]:
"""Serialize a tensor to the TensorData wire dict."""
flat = t.detach().cpu().flatten()
dtype_map = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
}
return {
"dtype": dtype_map.get(flat.dtype, "float32"),
"shape": list(t.shape),
"data": flat.tolist(),
}
def _make_datum(
tokens: list[int],
loss_fn_inputs: dict[str, torch.Tensor],
) -> dict[str, Any]:
"""Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery)."""
return {
"model_input": {
"chunks": [{"type": "encoded_text", "tokens": tokens}],
},
"loss_fn_inputs": {
key: _tensor_to_wire(tensor) for key, tensor in loss_fn_inputs.items()
},
}
def datums_to_tinker(datums: list[dict[str, Any]]):
"""Wrap plain-dict datums into tinker.types.Datum objects.
Both the Tinker SDK and updated Hatchery client accept these.
"""
import tinker.types as tt
result = []
for d in datums:
tokens = d["model_input"]["chunks"][0]["tokens"]
tinker_inputs = {}
for key, wire in d["loss_fn_inputs"].items():
tinker_inputs[key] = tt.TensorData(
data=wire["data"],
dtype=wire["dtype"],
shape=wire["shape"],
)
result.append(
tt.Datum(
model_input=tt.ModelInput.from_ints(tokens),
loss_fn_inputs=tinker_inputs,
)
)
return result
def batch_to_datums_sft(
input_ids: torch.Tensor,
labels: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
"""Convert an axolotl SFT batch to Datum dicts with causal shift."""
batch_size = input_ids.size(0)
datums = []
for i in range(batch_size):
ids = input_ids[i]
lbl = labels[i]
if attention_mask is not None:
seq_len = int(attention_mask[i].sum().item())
ids = ids[:seq_len]
lbl = lbl[:seq_len]
model_tokens = ids[:-1].tolist()
shifted_labels = lbl[1:]
target_tokens = shifted_labels.clone()
weights = (shifted_labels != -100).float()
target_tokens[target_tokens == -100] = 0
datums.append(
_make_datum(
model_tokens,
{
"target_tokens": target_tokens,
"weights": weights,
},
)
)
return datums
def batch_to_datums_rl(
input_ids: torch.Tensor,
labels: torch.Tensor,
logprobs: torch.Tensor,
advantages: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> list[dict[str, Any]]:
"""Convert an RL batch to importance_sampling/ppo Datum dicts with causal shift."""
batch_size = input_ids.size(0)
datums = []
for i in range(batch_size):
ids = input_ids[i]
lbl = labels[i]
if attention_mask is not None:
seq_len = int(attention_mask[i].sum().item())
else:
seq_len = ids.size(0)
ids = ids[:seq_len]
lbl = lbl[:seq_len]
lp = logprobs[i, :seq_len]
adv = advantages[i, :seq_len]
model_tokens = ids[:-1].tolist()
target_tokens = lbl[1:].clone()
target_tokens[target_tokens == -100] = 0
datums.append(
_make_datum(
model_tokens,
{
"target_tokens": target_tokens,
"logprobs": lp[1:],
"advantages": adv[1:],
},
)
)
return datums

View File

@@ -1,87 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Prepare hendrycks_math for RL training with Hatchery/Tinker.
Creates a dataset with chat-formatted prompts that include
a hidden gold answer tag for the reward function.
Run:
python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
"""
import os
import re
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
def extract_boxed(text: str) -> str:
match = re.search(r"\\boxed\{", text)
if not match:
return ""
start = match.end()
depth = 1
i = start
while i < len(text) and depth > 0:
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
i += 1
return text[start : i - 1] if depth == 0 else ""
def main():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test")
level = os.environ.get("MATH_LEVEL", "Level 1")
filtered_rows = [x for x in ds if x["level"] == level]
print(f"{level} algebra: {len(filtered_rows)} problems")
rows = []
for prob in filtered_rows:
gold = extract_boxed(prob["solution"])
if not gold:
continue
# Format as chat prompt with hidden gold tag
prompt = (
f"Solve the following math problem. "
f"Show your work and put your final answer in \\boxed{{}}.\n\n"
f"{prob['problem']}"
f"<|gold|>{gold}<|/gold|>"
)
# Tokenize the prompt
text = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
)
prompt_ids = tokenizer.encode(text, add_special_tokens=False)
rows.append(
{
"input_ids": prompt_ids,
"labels": [-100] * len(prompt_ids),
"attention_mask": [1] * len(prompt_ids),
}
)
out = Dataset.from_list(rows)
out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}"
out.save_to_disk(out_dir)
print(f"Saved {len(out)} examples to {out_dir}")
if rows:
print(
f"Prompt length range: {min(len(r['input_ids']) for r in rows)}"
f"-{max(len(r['input_ids']) for r in rows)}"
)
if __name__ == "__main__":
main()

View File

@@ -1,47 +0,0 @@
# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B
#
# Prep:
# python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
#
# Run:
# export TINKER_API_KEY="your-key"
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker
lora_rank: 16
loss_fn: importance_sampling
max_sample_tokens: 2048
sample_temperature: 0.7
num_samples: 4
pipeline: true
save_steps: 5
reward_funcs:
- axolotl.integrations.hatchery.rewards.math_reward.math_reward
datasets:
- path: ./data/math_rl_level1
ds_type: arrow
type: completion
sequence_len: 2048
learning_rate: 5.0e-5
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0
max_steps: 10
num_epochs: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
logging_steps: 1
output_dir: ./outputs/tinker-rl-math

View File

@@ -1,42 +0,0 @@
# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B
#
# Usage:
# export TINKER_API_KEY="your-key"
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker
lora_rank: 16
loss_fn: cross_entropy
pipeline: true
save_steps: 10
datasets:
- path: TeichAI/kimi-k2-thinking-1000x
split: train[:50]
type: chat_template
chat_template: qwen3
split_thinking: true
chat_template: qwen3
sequence_len: 2048
learning_rate: 3.0e-4
optimizer: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.01
max_grad_norm: 1.0
num_epochs: 1
max_steps: 20
micro_batch_size: 2
gradient_accumulation_steps: 1
logging_steps: 1
output_dir: ./outputs/tinker-sft

View File

@@ -1,147 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Axolotl plugin that routes training to a remote Hatchery/Tinker API."""
from __future__ import annotations
import torch
from peft import PeftModel
from transformers import AutoConfig, PreTrainedModel, Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class HatcheryPlugin(BasePlugin):
"""Plugin that replaces local training with remote API calls.
Activated by adding to the axolotl YAML:
plugins:
- axolotl.integrations.hatchery.HatcheryPlugin
hatchery:
backend: tinker # or "hatchery"
lora_rank: 32
loss_fn: cross_entropy
# ... see HatcheryConfig for full options
"""
def get_input_args(self) -> str:
return "axolotl.integrations.hatchery.args.HatcheryArgs"
def register(self, cfg: dict):
"""Auto-set config values needed for remote training."""
if cfg.get("remove_unused_columns") is None:
cfg["remove_unused_columns"] = False
def pre_model_load(self, cfg: DictDefault):
"""Replace model loading with a tiny stub."""
hcfg = cfg.hatchery or {}
backend = (
hcfg.get("backend", "tinker")
if isinstance(hcfg, dict)
else getattr(hcfg, "backend", "tinker")
)
LOG.info(
f"Hatchery plugin active: training dispatched to remote "
f"{backend} API. Skipping local model weight loading."
)
from axolotl.loaders import ModelLoader
def _stub_build_model(loader_self) -> bool:
base_model = loader_self.cfg.base_model
LOG.info(f"Skipping model weight loading for: {base_model}")
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=loader_self.cfg.get("trust_remote_code", False),
)
class _Stub(PreTrainedModel):
config_class = type(config)
_no_split_modules: list[str] = []
supports_gradient_checkpointing = False
def __init__(self, cfg):
super().__init__(cfg)
vocab_size = getattr(cfg, "vocab_size", 32000)
self.embed_tokens = torch.nn.Embedding(vocab_size, 1)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
pass
def get_output_embeddings(self):
return None
loader_self.model = _Stub(config)
return True
ModelLoader._build_model = _stub_build_model # type: ignore[method-assign,assignment]
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Return the appropriate remote trainer class."""
hcfg = cfg.hatchery
loss_fn = getattr(hcfg, "loss_fn", "cross_entropy") if hcfg else "cross_entropy"
if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"):
from .rl_trainer import HatcheryRLTrainer
return HatcheryRLTrainer
from .trainer import HatcheryTrainer
return HatcheryTrainer
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
model._hatchery_remote = True
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
LOG.info(
"Hatchery: skipping local model save (weights are on remote API). "
"Use `tinker checkpoint download` or hatchery CLI to retrieve."
)
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Inject hatchery config + axolotl training params into the trainer."""
from .args import HatcheryConfig
from .rl_trainer import HatcheryRLTrainer
from .trainer import HatcheryTrainer
if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)):
return
hcfg = cfg.hatchery
if isinstance(hcfg, dict):
hatchery_config = HatcheryConfig(**hcfg)
elif hcfg is None:
hatchery_config = HatcheryConfig()
else:
hatchery_config = hcfg
trainer.hatchery_args = hatchery_config
trainer._base_model_name = cfg.base_model
# Pull standard training params from axolotl config so they
# don't need to be duplicated under hatchery:
trainer._optim_params = {
"learning_rate": cfg.learning_rate
if cfg.learning_rate is not None
else 1e-4,
"beta1": cfg.adam_beta1 if cfg.adam_beta1 is not None else 0.9,
"beta2": cfg.adam_beta2 if cfg.adam_beta2 is not None else 0.95,
"eps": cfg.adam_epsilon if cfg.adam_epsilon is not None else 1e-12,
"weight_decay": cfg.weight_decay if cfg.weight_decay is not None else 0.0,
"grad_clip_norm": cfg.max_grad_norm
if cfg.max_grad_norm is not None
else 0.0,
}

View File

@@ -1,3 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

View File

@@ -1,78 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Math reward function for hendrycks_math GRPO training.
Uses math_verify for robust answer comparison. Falls back to
exact string match of \\boxed{} content only when math_verify
is unavailable.
"""
from __future__ import annotations
import logging
import re
LOG = logging.getLogger(__name__)
def extract_boxed(text: str) -> str | None:
"""Extract \\boxed{...} answer handling nested braces."""
match = re.search(r"\\boxed\{", text)
if not match:
return None
start = match.end()
depth = 1
i = start
while i < len(text) and depth > 0:
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
i += 1
return text[start : i - 1] if depth == 0 else None
def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
"""Score completions by checking if \\boxed{} answer matches the gold answer.
The gold answer is extracted from the prompt (appended as a hidden
tag by the dataset preprocessing). Format:
... <|gold|>ANSWER<|/gold|>
"""
rewards = []
for prompt, completion in zip(prompts, completions, strict=True):
gold_match = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt)
if not gold_match:
rewards.append(0.0)
continue
gold_answer = gold_match.group(1).strip()
pred_answer = extract_boxed(completion)
if pred_answer is None:
rewards.append(0.0)
continue
verified = None
try:
from math_verify import parse, verify
gold_parsed = parse(gold_answer)
pred_parsed = parse(pred_answer)
verified = verify(gold_parsed, pred_parsed)
except Exception:
LOG.debug(
"math_verify unavailable or failed, using string fallback",
exc_info=True,
)
if verified is not None:
rewards.append(1.0 if verified else 0.0)
elif pred_answer.strip() == gold_answer.strip():
rewards.append(1.0)
else:
rewards.append(0.0)
return rewards

View File

@@ -1,409 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API.
Full RL loop per step:
1. Extract prompts from dataset batch
2. Sample N completions per prompt via remote SamplingClient
3. Score completions with local reward functions
4. Compute GRPO-style advantages (per-group normalization)
5. Send (prompt+completion, logprobs, advantages) as forward_backward
6. Optimizer step
"""
from __future__ import annotations
import importlib
import inspect
import re
import time
from typing import Any, Callable, Optional
import torch
from transformers.trainer_utils import TrainOutput
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.logging import get_logger
from .args import HatcheryConfig
from .data import batch_to_datums_rl, datums_to_tinker
from .trainer import _create_training_client
LOG = get_logger(__name__)
def _load_reward_func(fqn: str) -> Callable:
"""Load a reward function from a fully qualified name like 'module.func'."""
module_path = ".".join(fqn.split(".")[:-1])
func_name = fqn.split(".")[-1]
mod = importlib.import_module(module_path)
func = getattr(mod, func_name)
if len(inspect.signature(func).parameters) < 2:
raise ValueError(f"Reward function {fqn} must accept (prompts, completions)")
return func
class HatcheryRLTrainer(AxolotlTrainer):
"""Remote RL trainer using Tinker/Hatchery for sampling and training."""
hatchery_args: Optional[HatcheryConfig]
_base_model_name: Optional[str]
_training_client: Any
_reward_functions: list[Callable]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hatchery_args = None
self._base_model_name = None
self._training_client = None
self._reward_functions = []
def _ensure_reward_functions(self):
if self._reward_functions:
return
args = self.hatchery_args
if not args or not args.reward_funcs:
raise ValueError(
"No reward functions configured. Set hatchery.reward_funcs "
"in YAML, e.g. reward_funcs: ['my_module.my_reward']"
)
for fqn in args.reward_funcs:
self._reward_functions.append(_load_reward_func(fqn))
LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)")
def _get_training_client(self):
if self._training_client is not None:
return self._training_client
self._training_client = _create_training_client(
self.hatchery_args, self._base_model_name
)
LOG.info(
f"Remote RL session created: backend={self.hatchery_args.backend}, "
f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}"
)
return self._training_client
def _sample_completions(self, prompt_ids_list: list[list[int]]):
"""Sample completions for prompts via remote API."""
import tinker.types as tt
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
results = []
sc = tc.save_weights_and_get_sampling_client()
for prompt_ids in prompt_ids_list:
if hasattr(sc, "sampling_session_id"):
sample_result = sc.sample(
prompt_ids,
max_tokens=args.max_sample_tokens,
temperature=args.sample_temperature,
n=args.num_samples,
).result(timeout=args.future_timeout)
else:
mi = tt.ModelInput.from_ints(prompt_ids)
sp = tt.SamplingParams(
max_tokens=args.max_sample_tokens,
temperature=args.sample_temperature,
top_p=0.95,
top_k=-1,
)
sample_result = sc.sample(
prompt=mi,
num_samples=args.num_samples,
sampling_params=sp,
).result(timeout=args.future_timeout)
sequences = (
sample_result.sequences
if hasattr(sample_result, "sequences")
else sample_result.get("sequences", [])
)
for seq in sequences:
tokens = (
list(seq.tokens)
if hasattr(seq, "tokens")
else seq.get("tokens", [])
)
logprobs = (
list(seq.logprobs)
if hasattr(seq, "logprobs") and seq.logprobs
else seq.get("logprobs", [])
)
results.append(
{
"tokens": list(prompt_ids) + tokens,
"completion_tokens": tokens,
"logprobs": logprobs,
"prompt_len": len(prompt_ids),
}
)
return results
def _compute_rewards(
self, prompts: list[str], completions: list[str]
) -> list[float]:
total_rewards = [0.0] * len(completions)
for reward_fn in self._reward_functions:
rewards = reward_fn(prompts, completions)
for i, r in enumerate(rewards):
total_rewards[i] += r
return total_rewards
@staticmethod
def _compute_advantages(rewards: list[float], group_size: int) -> list[float]:
advantages = []
for i in range(0, len(rewards), group_size):
group = rewards[i : i + group_size]
mean = sum(group) / len(group)
var = sum((r - mean) ** 2 for r in group) / max(len(group), 1)
std = var**0.5 if var > 1e-8 else 1.0
advantages.extend([(r - mean) / std for r in group])
return advantages
def _do_optim_step(self):
import tinker.types as tt
tc = self._get_training_client()
return tc.optim_step(tt.AdamParams(**self._optim_params))
def train(
self,
resume_from_checkpoint: Optional[str] = None,
trial: Any = None,
ignore_keys_for_eval: Optional[list[str]] = None,
**kwargs,
) -> TrainOutput:
args = self.hatchery_args
if args is None:
raise RuntimeError("hatchery_args not configured")
self._ensure_reward_functions()
train_dataloader = self.get_train_dataloader()
num_train_epochs = int(self.args.num_train_epochs)
max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000
LOG.info(
f"Remote RL training: max_steps={max_steps}, "
f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}"
)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = True
self.state.is_world_process_zero = True
self.control = self.callback_handler.on_train_begin(
self.args,
self.state,
self.control, # type: ignore[has-type]
)
tokenizer = self.processing_class
global_step = 0
total_loss = 0.0
total_reward = 0.0
start_time = time.time()
for _epoch in range(num_train_epochs):
if global_step >= max_steps:
break
for batch in train_dataloader:
if global_step >= max_steps:
break
self.control = self.callback_handler.on_step_begin(
self.args, self.state, self.control
)
prompt_ids_batch = batch["input_ids"]
# Full prompt text (with gold tag) for reward scoring
prompt_texts = tokenizer.batch_decode(
prompt_ids_batch, skip_special_tokens=False
)
# Strip <|gold|>...<|/gold|> from token ids before
# sending to the model for sampling — the gold answer
# must only be visible to the local reward function.
sampling_prompts = []
for prompt_text in prompt_texts:
clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text)
clean_ids = tokenizer.encode(clean, add_special_tokens=False)
sampling_prompts.append(clean_ids)
# 1. Sample completions (without gold answer)
t0 = time.time()
samples = self._sample_completions(sampling_prompts)
t_sample = time.time() - t0
if not samples:
LOG.warning("No samples generated, skipping step")
continue
LOG.info(
f"Sampled {len(samples)} completions, "
f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok"
)
# 2. Decode and score
completion_texts = [
tokenizer.decode(s["completion_tokens"], skip_special_tokens=False)
for s in samples
]
sample_prompts = []
for prompt_text in prompt_texts:
sample_prompts.extend([prompt_text] * args.num_samples)
rewards = self._compute_rewards(sample_prompts, completion_texts)
# 3. GRPO advantages
advantages_list = self._compute_advantages(
rewards, group_size=args.num_samples
)
# 4. Build training data
all_datums = []
for i, sample in enumerate(samples):
full_tokens = sample["tokens"]
prompt_len = sample["prompt_len"]
seq_len = len(full_tokens)
input_ids = torch.tensor([full_tokens], dtype=torch.long)
labels = torch.full((1, seq_len), -100, dtype=torch.long)
labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:])
logprobs_t = torch.zeros(1, seq_len)
if sample["logprobs"]:
lp = sample["logprobs"][: seq_len - prompt_len]
logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor(
lp
)
adv_t = torch.zeros(1, seq_len)
adv_t[0, prompt_len:] = advantages_list[i]
all_datums.extend(
batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t)
)
# 5. Forward backward (one datum at a time for memory) + optim
t0 = time.time()
tc = self._get_training_client()
step_loss = 0.0
for datum in all_datums:
fb_future = tc.forward_backward(
datums_to_tinker([datum]),
loss_fn=args.loss_fn,
loss_fn_config=args.loss_fn_config,
)
fb_result = fb_future.result(timeout=args.future_timeout)
if hasattr(fb_result, "metrics"):
step_loss += float(
(fb_result.metrics or {}).get("loss:sum", 0.0)
)
elif isinstance(fb_result, dict):
step_loss += float(
fb_result.get("metrics", {}).get("loss:sum", 0.0)
)
optim_future = self._do_optim_step()
if not args.pipeline:
optim_future.result(timeout=args.future_timeout)
t_train = time.time() - t0
mean_reward = sum(rewards) / len(rewards)
accuracy = sum(1 for r in rewards if r > 0) / len(rewards)
mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list)
global_step += 1
total_loss += step_loss
total_reward += mean_reward
self.state.global_step = global_step
log_interval = self.args.logging_steps or 1
if global_step % log_interval == 0:
elapsed = time.time() - start_time
LOG.info(
f"[step {global_step}/{max_steps}] "
f"acc={accuracy:.2f} reward={mean_reward:.3f} "
f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} "
f"sample={t_sample:.1f}s train={t_train:.1f}s "
f"{elapsed / global_step:.1f}s/step"
)
self.log(
{
"loss": step_loss,
"reward": mean_reward,
"accuracy": accuracy,
"mean_abs_advantage": mean_adv,
"learning_rate": self._optim_params["learning_rate"],
}
)
if args.save_steps and global_step % args.save_steps == 0:
self._save_remote_checkpoint(global_step)
self.control = self.callback_handler.on_step_end(
self.args, self.state, self.control
)
if self.control.should_training_stop:
break
if self.control.should_training_stop:
break
if global_step > 0:
self._save_remote_checkpoint(global_step, name="final")
elapsed = time.time() - start_time
avg_loss = total_loss / max(global_step, 1)
avg_reward = total_reward / max(global_step, 1)
LOG.info(
f"RL training complete: {global_step} steps, {elapsed:.1f}s, "
f"avg_reward={avg_reward:.4f}"
)
self.control = self.callback_handler.on_train_end(
self.args, self.state, self.control
)
return TrainOutput(
global_step=global_step,
training_loss=avg_loss,
metrics={
"train_loss": avg_loss,
"train_reward": avg_reward,
"train_runtime": elapsed,
},
)
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
try:
future = tc.save_state(ckpt_name)
future.result(timeout=args.future_timeout)
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
except Exception:
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
if name == "final":
raise
def save_model(self, output_dir=None, _internal_call=False):
self._save_remote_checkpoint(
step=self.state.global_step,
name=output_dir or "hf-save",
)
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
raise NotImplementedError(
"HatcheryRLTrainer uses remote API; compute_loss not called locally."
)

View File

@@ -1,327 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Remote trainer that dispatches to Tinker or Hatchery API."""
from __future__ import annotations
import os
import time
from typing import Any, Optional
import torch
from transformers.trainer_utils import TrainOutput
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.logging import get_logger
from .args import HatcheryConfig
from .data import batch_to_datums_sft, datums_to_tinker
LOG = get_logger(__name__)
def _extract_loss(result) -> float:
"""Extract loss:sum from a forward_backward result.
Tinker's cross_entropy (and other losses) return the SUM of per-token
losses, not the mean. This is by design — it lets users control
normalization via the weights tensor. The trainer logs this raw sum;
users who want per-token loss should divide by number of active tokens.
"""
if hasattr(result, "metrics"):
metrics = result.metrics or {}
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
if isinstance(result, dict):
metrics = result.get("metrics", {})
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
return 0.0
def _create_training_client(args: HatcheryConfig, base_model: str):
"""Create a training client for either Tinker or Hatchery backend."""
if args.backend == "tinker":
import tinker
api_key = args.api_key or os.environ.get("TINKER_API_KEY")
if not api_key:
raise ValueError(
"Tinker API key required. Set `hatchery.api_key` in config "
"or TINKER_API_KEY env var."
)
os.environ["TINKER_API_KEY"] = api_key
service = tinker.ServiceClient(project_id=args.project_id)
return service.create_lora_training_client(
base_model=base_model,
rank=args.lora_rank,
train_mlp=args.train_mlp,
train_attn=args.train_attn,
train_unembed=args.train_unembed,
)
from hatchery.core.client import HatcheryClient
base_url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420")
token = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev")
client = HatcheryClient(base_url=base_url, token=token, timeout=args.future_timeout)
return client.create_lora_training_client(
base_model=base_model,
rank=args.lora_rank,
train_attn=args.train_attn,
train_mlp=args.train_mlp,
train_unembed=args.train_unembed,
)
class HatcheryTrainer(AxolotlTrainer):
"""Trainer that sends preprocessed batches to a remote training API.
Replaces local forward/backward with remote API calls to Tinker or
Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization,
chat templates, packing, etc.) but offloads compute to remote GPUs.
"""
hatchery_args: Optional[HatcheryConfig]
_base_model_name: Optional[str]
_training_client: Any
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hatchery_args = None
self._base_model_name = None
self._training_client = None
def _get_training_client(self):
"""Lazily create the remote training session."""
if self._training_client is not None:
return self._training_client
args = self.hatchery_args
if args is None:
raise RuntimeError(
"HatcheryTrainer.hatchery_args not set. "
"Ensure the HatcheryPlugin is registered."
)
base_model = self._base_model_name
if not base_model:
raise RuntimeError("HatcheryTrainer._base_model_name not set.")
self._training_client = _create_training_client(args, base_model)
LOG.info(
f"Remote training session created: backend={args.backend}, "
f"model={base_model}, rank={args.lora_rank}"
)
return self._training_client
def _send_batch(self, batch: dict[str, torch.Tensor]):
"""Convert batch to datums and send forward_backward to remote.
Returns (future, n_active_tokens) where n_active_tokens counts
the completion tokens in this batch (for loss normalization).
"""
input_ids = batch["input_ids"]
labels = batch["labels"]
attention_mask = batch.get("attention_mask")
n_active = int((labels[:, 1:] != -100).sum().item())
datums = batch_to_datums_sft(input_ids, labels, attention_mask)
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
send_datums = datums_to_tinker(datums)
future = tc.forward_backward(
send_datums,
loss_fn=args.loss_fn,
loss_fn_config=args.loss_fn_config,
)
return future, n_active
def _do_optim_step(self):
"""Send optimizer step to remote using axolotl's training params."""
import tinker.types as tt
tc = self._get_training_client()
return tc.optim_step(tt.AdamParams(**self._optim_params))
def train(
self,
resume_from_checkpoint: Optional[str] = None,
trial: Any = None,
ignore_keys_for_eval: Optional[list[str]] = None,
**kwargs,
) -> TrainOutput:
"""Main training loop — sends batches to remote API."""
args = self.hatchery_args
if args is None:
raise RuntimeError("hatchery_args not configured")
train_dataloader = self.get_train_dataloader()
num_batches = len(train_dataloader)
grad_accum = self.args.gradient_accumulation_steps
num_train_epochs = int(self.args.num_train_epochs)
steps_per_epoch = max(num_batches // grad_accum, 1)
max_steps = (
self.args.max_steps
if self.args.max_steps > 0
else steps_per_epoch * num_train_epochs
)
LOG.info(
f"Remote training: {num_batches} batches/epoch, "
f"{grad_accum} grad_accum, {max_steps} max steps, "
f"{num_train_epochs} epochs"
)
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = True
self.state.is_world_process_zero = True
self.control = self.callback_handler.on_train_begin(
self.args,
self.state,
self.control, # type: ignore[has-type]
)
global_step = 0
total_loss = 0.0
start_time = time.time()
for _epoch in range(num_train_epochs):
if global_step >= max_steps:
break
self.control = self.callback_handler.on_epoch_begin(
self.args, self.state, self.control
)
pending_fb_futures = []
accum_count = 0
for batch_idx, batch in enumerate(train_dataloader):
if global_step >= max_steps:
break
self.control = self.callback_handler.on_step_begin(
self.args, self.state, self.control
)
fb_future, n_active = self._send_batch(batch)
pending_fb_futures.append((fb_future, n_active))
accum_count += 1
if accum_count >= grad_accum:
step_loss_sum = 0.0
step_active = 0
for fut, n_act in pending_fb_futures:
result = fut.result(timeout=args.future_timeout)
step_loss_sum += _extract_loss(result)
step_active += n_act
optim_future = self._do_optim_step()
if not args.pipeline:
optim_future.result(timeout=args.future_timeout)
step_loss = (
step_loss_sum / step_active
if step_active > 0
else step_loss_sum
)
global_step += 1
total_loss += step_loss
self.state.global_step = global_step
self.state.epoch = _epoch + (batch_idx + 1) / num_batches
log_interval = self.args.logging_steps or 1
if global_step % log_interval == 0:
elapsed = time.time() - start_time
avg_loss = total_loss / global_step
LOG.info(
f"[step {global_step}/{max_steps}] "
f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} "
f"active={step_active} "
f"{elapsed / global_step:.2f}s/step"
)
self.log(
{
"loss": step_loss,
"learning_rate": self._optim_params["learning_rate"],
"epoch": self.state.epoch,
}
)
if args.save_steps and global_step % args.save_steps == 0:
self._save_remote_checkpoint(global_step)
self.control = self.callback_handler.on_step_end(
self.args, self.state, self.control
)
pending_fb_futures = []
accum_count = 0
if self.control.should_training_stop:
break
self.control = self.callback_handler.on_epoch_end(
self.args, self.state, self.control
)
if self.control.should_training_stop:
break
if global_step > 0:
self._save_remote_checkpoint(global_step, name="final")
elapsed = time.time() - start_time
avg_loss = total_loss / max(global_step, 1)
LOG.info(
f"Training complete: {global_step} steps, {elapsed:.1f}s total, "
f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}"
)
self.control = self.callback_handler.on_train_end(
self.args, self.state, self.control
)
return TrainOutput(
global_step=global_step,
training_loss=avg_loss,
metrics={"train_loss": avg_loss, "train_runtime": elapsed},
)
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
"""Save a checkpoint on the remote service."""
tc = self._get_training_client()
args = self.hatchery_args
assert args is not None # validated by _get_training_client
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
try:
future = tc.save_state(ckpt_name)
future.result(timeout=args.future_timeout)
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
except Exception:
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
if name == "final":
raise
def save_model(self, output_dir=None, _internal_call=False):
"""Delegate to remote checkpoint save so HF callbacks create checkpoints."""
self._save_remote_checkpoint(
step=self.state.global_step,
name=output_dir or "hf-save",
)
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
raise NotImplementedError(
"HatcheryTrainer uses remote API; compute_loss should not be called."
)

View File

@@ -2,17 +2,35 @@
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
from .lora_layout import (
peft_down_proj_lora_to_scattermoe,
peft_lora_B_to_scattermoe,
peft_lora_to_scattermoe,
validate_scattermoe_lora_shapes,
)
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
"peft_down_proj_lora_to_scattermoe",
"peft_lora_B_to_scattermoe",
"peft_lora_to_scattermoe",
"validate_scattermoe_lora_shapes",
]
try:
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
except ModuleNotFoundError as exc:
if exc.name != "triton":
raise
else:
__all__ += [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]

View File

@@ -35,46 +35,19 @@ import torch
from torch import nn
from torch.nn import functional as F
from .lora_layout import (
peft_down_proj_lora_to_scattermoe,
peft_lora_B_to_scattermoe,
peft_lora_to_scattermoe,
)
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
# =============================================================================
# LoRA layout conversion utilities (peft <-> scattermoe)
# =============================================================================
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout.
peft >=0.19.1 assigns in/out features for 3D params such that
A and B already align with scattermoe's convention (no A<->B swap).
Only B needs rank-major → expert-major layout conversion.
"""
smoe_A = peft_A
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
return smoe_A, smoe_B
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
__all__ = [
"peft_down_proj_lora_to_scattermoe",
"peft_lora_B_to_scattermoe",
"peft_lora_to_scattermoe",
]
# =============================================================================
# ParamWrapper unwrapping
@@ -164,7 +137,7 @@ def _unwrap_experts_lora(experts_module):
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
# Extract gate_up_proj LoRA
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
@@ -173,7 +146,7 @@ def _unwrap_experts_lora(experts_module):
rank = lora_A.shape[0] // num_experts
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
# Extract down_proj LoRA (needs A<->B swap due to transposition)
# Extract down_proj LoRA
down_lora = None
down_wrapper = wrappers.get("down_proj")
if down_wrapper is not None:

View File

@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Pure tensor layout helpers for ScatterMoE LoRA weights."""
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout.
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``out_features=dim1, in_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``), giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``.
peft gives:
lora_A ``[r*E, dim2]``, lora_B ``[dim1, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
peft's A already matches ScatterMoE's A shape. Only B needs conversion from
peft's rank-major layout to ScatterMoE's expert-major layout.
"""
smoe_A = peft_A
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
return smoe_A, smoe_B
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
def validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B):
"""Validate LoRA tensor layout before dispatching ScatterMoE kernels."""
E, K, N = expert_weights.shape
if lora_A.dim() != 2 or lora_B.dim() != 2:
raise ValueError(
"ScatterMoE LoRA expects 2D lora_A and lora_B tensors, got "
f"lora_A={tuple(lora_A.shape)} and lora_B={tuple(lora_B.shape)}."
)
if lora_A.size(0) % E != 0:
raise ValueError(
"ScatterMoE LoRA expects lora_A rows to be divisible by the number "
f"of experts ({E}), got lora_A={tuple(lora_A.shape)}."
)
rank = lora_A.size(0) // E
expected_A = (E * rank, K)
expected_B = (N, E * rank)
if tuple(lora_A.shape) != expected_A or tuple(lora_B.shape) != expected_B:
raise ValueError(
"Invalid ScatterMoE LoRA layout for expert_weights "
f"{tuple(expert_weights.shape)}. Expected lora_A={expected_A} and "
f"lora_B={expected_B}, got lora_A={tuple(lora_A.shape)} and "
f"lora_B={tuple(lora_B.shape)}. For PEFT target_parameters, keep "
"lora_A as [E*r, K] and only convert lora_B from rank-major to "
"expert-major layout."
)

View File

@@ -34,6 +34,7 @@ from .kernels.lora_ops import (
scatter2scatter_lora,
scatter2scatter_lora_dX,
)
from .lora_layout import validate_scattermoe_lora_shapes
class ScatterMoELoRA(torch.autograd.Function):
@@ -422,11 +423,6 @@ def get_lora_params_from_wrapper(module) -> tuple:
return lora_A, lora_B, scaling
# =============================================================================
# Drop-in replacement for parallel_linear
# =============================================================================
def parallel_linear_lora(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
@@ -451,6 +447,7 @@ def parallel_linear_lora(
Otherwise falls back to standard scatter2scatter.
"""
if lora_A is not None and lora_B is not None:
validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B)
return ScatterMoELoRA.apply(
inputs,
expert_weights,

View File

@@ -170,6 +170,7 @@ class PatchManager:
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
self._apply_unsloth_patches(model)
self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
@@ -422,16 +423,7 @@ class PatchManager:
patch_gemma4_fused_attn,
)
# Shared-KV side channel when activation checkpointing (PR #3611).
fsdp_cfg = self.cfg.fsdp_config
needs_shared_kv_workaround = (not self.inference) and bool(
self.cfg.gradient_checkpointing
or self.cfg.activation_offloading
or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
)
patch_gemma4_fused_attn(
install_shared_kv_workaround=needs_shared_kv_workaround
)
patch_gemma4_fused_attn()
@staticmethod
def _fix_nemotron_h_conversion_mapping():
@@ -709,10 +701,24 @@ class PatchManager:
)
patch_fa_llama_cross_entropy()
elif self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
patch_llama_rms_norm()
elif self.cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def _patch_llama_flash_attention(self):
"""Apply Flash Attention patches for LLaMA models."""
@@ -779,6 +785,23 @@ class PatchManager:
LOG.info("Patching with SwiGLU...")
replace_llama_mlp_with_swiglu(model)
def _apply_unsloth_patches(self, model):
"""Apply unsloth optimization patches."""
if self.cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
integrate_lora_mlp_patch(peft_model=model)
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
integrate_lora_patch(peft_model=model, cfg=self.cfg)
if self.cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()
def _apply_lora_kernel_patch(self, model):
"""Apply LoRA kernel patches."""
if (

View File

@@ -23,8 +23,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.processor_kwargs:
processor_kwargs.update(cfg.processor_kwargs)
if cfg.tokenizer_use_mistral_common:

View File

@@ -6,29 +6,15 @@ kernels, eliminating intermediate tensor allocations from rotate_half / apply_ro
Usage:
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
patch_gemma4_fused_attn(install_shared_kv_workaround=True)
patch_gemma4_fused_attn()
"""
import logging
from typing import Callable
import torch
from axolotl.utils.logging import get_logger
logger = get_logger(__name__)
# Module-level dict used as a side channel for shared KV states avoiding kwarg and TLS
# to prevent memory leak on gradient checkpoint enabled training (PR #3611)
_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
def _set_shared_kv_states(store):
_GEMMA4_SHARED_KV_STORE["store"] = store
def _get_shared_kv_states():
return _GEMMA4_SHARED_KV_STORE["store"]
logger = logging.getLogger(__name__)
def _make_fused_forward(original_forward):
@@ -44,7 +30,7 @@ def _make_fused_forward(original_forward):
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
past_key_values=None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -53,10 +39,6 @@ def _make_fused_forward(original_forward):
eager_attention_forward,
)
store = _get_shared_kv_states()
if store is not None:
shared_kv_states = store
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
eps = self.config.rms_norm_eps
@@ -151,44 +133,15 @@ def _make_fused_forward(original_forward):
return fused_forward
def _patch_decoder_layer_call():
"""Strip `shared_kv_states` from decoder-layer kwargs and route via the
module-level side channel so the checkpoint partial cannot pin it (PR #3611).
def patch_gemma4_fused_attn():
"""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer
if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
return
original_call = Gemma4TextDecoderLayer.__call__
def patched_call(self, *args, **kwargs):
shared_kv = kwargs.pop("shared_kv_states", None)
# Overwrite unconditionally (including with None) so a previous step's
# dict cannot leak into a later call without shared_kv_states (PR #3611).
_set_shared_kv_states(shared_kv)
return original_call(self, *args, **kwargs)
Gemma4TextDecoderLayer.__call__ = patched_call
Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True
def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
"""
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
and optionally route `shared_kv_states` via a module-level side channel to
avoid a VRAM leak under activation checkpointing (PR #3611).
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
"""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
original_forward = Gemma4TextAttention.forward
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
if install_shared_kv_workaround:
_patch_decoder_layer_call()
logger.info(
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
)
if install_shared_kv_workaround:
logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")

View File

@@ -0,0 +1,252 @@
"""module for patching with unsloth optimizations"""
import inspect
import types
import torch
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
""".lstrip("\n")
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
""".lstrip("\n")
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip("\n")
PATCHED_O_CODE = """
attn_output = self.apply_o(self, attn_output)
""".lstrip("\n")
def original_apply_qkv(self, hidden_states):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def original_apply_o(self, hidden_states):
attn_output = self.o_proj(hidden_states)
return attn_output
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
def check_self_attn_is_patchable() -> bool:
qkv = get_self_attn_code()
qkv, _ = detab_code(qkv)
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss(
logits,
labels,
vocab_size: int,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama":
from transformers.loss import loss_utils
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
else:
raise ValueError("Unsupported model type")
self_attn_lora_patched = False
def patch_self_attn_lora():
global self_attn_lora_patched
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def unsloth_attn_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in self_attn_forward:
items_to_import.append(item)
exec(
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(self_attn_forward, globals())
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora")
LlamaFlashAttention2.forward = unsloth_attn_forward
def integrate_rope_embeddings():
import transformers.models.llama.modeling_llama
from unsloth.kernels.rope_embedding import fast_rope_embedding
def apply_rotary_pos_emb(
q,
k,
cos,
sin,
position_ids=None,
unsqueeze_dim=1,
):
return fast_rope_embedding(q, k, cos, sin)
LOG.info("patching unsloth RoPE embeddings")
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
apply_lora_mlp = apply_lora_mlp_swiglu
elif peft_model.base_model.config.model_type == "gemma":
from unsloth.kernels import apply_lora_mlp_geglu_approx
apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(
f"Model type {peft_model.base_model.config.model_type} not supported"
)
for idx, layer in enumerate(peft_model.model.model.layers):
layer_modules = [
getattr(layer.mlp, linear_proj)
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
]
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
mlp_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
mlp_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
from unsloth.kernels import apply_lora_o, apply_lora_qkv
for idx, layer in enumerate(peft_model.model.model.layers):
if cfg.unsloth_lora_qkv:
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
qkv_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
qkv_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
o_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
o_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if is_o_lora and o_no_bias and o_not_dora:
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")
def patch_unsloth_layernorm():
try:
import transformers.models.llama.modeling_llama
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
class LlamaRMSNorm(nn.Module):
"""LlamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return Fast_RMS_Layernorm.apply(
hidden_states, self.weight, self.variance_epsilon, False
)
LOG.info("patching with unsloth.kernels.rms_layernorm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning("missing unsloth library")

View File

@@ -394,8 +394,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
try:
return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
return all(isinstance(v, list) for v in prompt.values()) and all(
isinstance(v, list) for v in prompt[self.prompter.field_messages]
)
except KeyError:
return False
@@ -1004,13 +1004,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if tools is None:
return None
# Some datasets have tools set to str
if isinstance(tools, str):
try:
tools = json.loads(tools)
except json.JSONDecodeError as e:
LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
raise
if isinstance(tools, list):
# Process each tool to handle JSON string parameters
for tool in tools:
@@ -1041,22 +1034,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if messages is None:
raise ValueError("Messages is null. Please check `field_messages`.")
if isinstance(messages, str):
try:
messages = json.loads(messages)
except json.JSONDecodeError as e:
LOG.error(f"Error parsing messages as JSON. Error: {e}")
raise
assert isinstance(messages, list), (
f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
)
# Extra check here to make sure decoded json is a list of dicts.
for i, message in enumerate(messages):
assert isinstance(message, dict), (
f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
)
if isinstance(messages, list):
return messages

View File

@@ -309,16 +309,6 @@ class AxolotlInputConfig(
dpo_padding_free: bool | None = None
dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "List of DPO losses to use."},
)
dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
default=None,
json_schema_extra={"description": "Weights for each DPO loss."},
)
datasets: (
Annotated[
list[
@@ -833,6 +823,13 @@ class AxolotlInputConfig(
},
)
unsloth_cross_entropy_loss: bool | None = None
unsloth_lora_mlp: bool | None = None
unsloth_lora_qkv: bool | None = None
unsloth_lora_o: bool | None = None
unsloth_rms_norm: bool | None = None
unsloth_rope: bool | None = None
lora_mlp_kernel: bool | None = Field(
default=None,
json_schema_extra={
@@ -1472,6 +1469,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities")
if capabilities and capabilities.get("n_gpu", 0) > 1:
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
@@ -1525,7 +1537,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
# RL trainers not tested so don't enable kernels by default
return data
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set or using 8-bit
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = [
"lora_mlp_kernel",
"lora_qkv_kernel",
@@ -1534,6 +1547,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
or data.get("adapter") == "lora"
and data.get("load_in_8bit")
):

View File

@@ -64,12 +64,6 @@ class ModelInputConfig(BaseModel):
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
processor_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
},
)
tokenizer_save_jinja_files: bool | None = Field(
default=True, # match the default behavior from transformers
json_schema_extra={
@@ -113,22 +107,6 @@ class ModelInputConfig(BaseModel):
)
return trust_remote_code
@field_validator("processor_kwargs")
@classmethod
def reject_reserved_processor_kwargs(cls, processor_kwargs):
if not processor_kwargs:
return processor_kwargs
reserved = {"revision", "trust_remote_code"}
conflicts = reserved.intersection(processor_kwargs)
if conflicts:
raise ValueError(
"Do not set reserved keys "
f"{sorted(conflicts)} inside `processor_kwargs`; "
"use the top-level `revision_of_model` / `trust_remote_code` "
"config keys instead."
)
return processor_kwargs
class ModelOutputConfig(BaseModel):
"""model save configuration subset"""

View File

@@ -52,26 +52,6 @@ class DatasetValidationMixin:
return datasets
@model_validator(mode="before")
@classmethod
def check_deprecated_unsloth_fields(cls, data):
deprecated_fields = [
"unsloth_cross_entropy_loss",
"unsloth_lora_mlp",
"unsloth_lora_qkv",
"unsloth_lora_o",
"unsloth_rms_norm",
"unsloth_rope",
]
found = [f for f in deprecated_fields if data.get(f)]
if found:
raise ValueError(
f"`{'`, `'.join(found)}` {'has' if len(found) == 1 else 'have'} been removed. "
"Please use `lora_mlp_kernel`, `lora_qkv_kernel`, `lora_o_kernel` instead. "
"See: https://docs.axolotl.ai/docs/lora_optims.html"
)
return data
@model_validator(mode="before")
@classmethod
def check_dataset_or_pretraining_dataset(cls, data):
@@ -578,11 +558,6 @@ class TrainingValidationMixin:
"Setting chat_template is not supported with mistral-common tokenizer"
)
if data.get("processor_kwargs"):
raise ValueError(
"processor_kwargs is not supported with mistral-common tokenizer"
)
return data
@model_validator(mode="before")
@@ -632,6 +607,36 @@ class LoRAValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
is_lora_kernel = any(
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
is_unsloth_lora = any(
data.get(k)
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
)
if is_lora_kernel and is_unsloth_lora:
raise ValueError(
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
)
return data
@model_validator(mode="after")
def check_fused_lora(self):
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
@@ -765,40 +770,6 @@ class RLValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_dpo(cls, data):
dpo_loss_type = data.get("dpo_loss_type")
dpo_loss_weights = data.get("dpo_loss_weights")
rl = data.get("rl")
if rl == "ipo":
LOG.warning(
"rl: ipo will soon be deprecated. Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead."
)
if rl == "dpo":
if dpo_loss_weights is not None and dpo_loss_type is None:
raise ValueError(
"`dpo_loss_weights` requires `dpo_loss_type` to be set"
)
if (
dpo_loss_type is not None
and dpo_loss_weights is not None
and len(dpo_loss_type) != len(dpo_loss_weights)
):
raise ValueError(
f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
)
elif dpo_loss_type is not None or dpo_loss_weights is not None:
raise ValueError(
f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only,"
f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}"
)
return data
@model_validator(mode="before")
@classmethod
def check_grpo_batch_size_divisibility(cls, data):
@@ -971,6 +942,17 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
if data.get("flash_attn_cross_entropy") and data.get(
"unsloth_cross_entropy_loss"
):
raise ValueError(
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
)
return data
@model_validator(mode="before")
@classmethod
def check_cross_entropy_conflicts(cls, data):

View File

@@ -0,0 +1,102 @@
"""
dynamic requirements for axolotl
"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools.command.build_py import build_py as _build_py
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
class BuildPyCommand(_build_py):
"""
custom build_py command to parse dynamic requirements
"""
def finalize_options(self):
super().finalize_options()
install_requires, _ = parse_requirements()
self.distribution.install_requires = install_requires

View File

@@ -118,52 +118,18 @@ def download_smollm2_135m_gptq_model():
snapshot_download_w_retry("lilmeaty/SmolLM2-135M-Instruct-GPTQ", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_qwen3_half_billion_model():
# download the model (still used as the KD teacher in tests/e2e/integrations/test_kd.py)
# download the model
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_llama_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-llama-50m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_mistral_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-mistral-25m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_mixtral_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-mixtral-30m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_phi_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-phi-64m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_falcon_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-falcon-42m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_qwen2_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-qwen2-129m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_qwen3_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-qwen3-129m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_gemma2_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-gemma2-137m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
@@ -359,10 +325,10 @@ def download_phi_4_reasoning_model_fixture():
@pytest.fixture(scope="session", autouse=True)
def download_phi_3_mini_model_fixture():
def download_phi_3_medium_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3-mini-4k-instruct",
"microsoft/Phi-3-medium-128k-instruct",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@@ -654,15 +620,7 @@ def fixture_min_base_cfg():
)
def test_load_fixtures(
download_smollm2_135m_model,
download_qwen3_half_billion_model,
download_tiny_llama_model,
download_tiny_mistral_model,
download_tiny_mixtral_model,
download_tiny_phi_model,
download_tiny_falcon_model,
download_tiny_qwen2_model,
download_tiny_qwen3_model,
download_tiny_gemma2_model,
download_qwen_2_5_half_billion_model,
download_tatsu_lab_alpaca_dataset,
download_mhenrichsen_alpaca_2k_dataset,
download_mhenrichsen_alpaca_2k_w_revision_dataset,

View File

@@ -96,8 +96,6 @@ def fixture_dpo_cfg(base_cfg):
"dpo_use_weighting": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
"dpo_loss_type": ["sigmoid", "sft"],
"dpo_loss_weights": [1.0, 0.5],
}
)
return cfg
@@ -166,8 +164,7 @@ def fixture_ipo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.DPO,
"dpo_loss_type": ["ipo"],
"rl": RLType.IPO,
"dpo_label_smoothing": 0,
"beta": 0.1,
}
@@ -303,8 +300,6 @@ class TestHFRLTrainerBuilder:
assert training_arguments.use_weighting is True
assert training_arguments.label_smoothing == 0.1
assert training_arguments.precompute_ref_log_probs is True
assert training_arguments.loss_type == ["sigmoid", "sft"]
assert training_arguments.loss_weights == [1.0, 0.5]
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)

View File

@@ -10,10 +10,7 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
)
from tests.e2e.utils import check_model_output_exists
@pytest.fixture()
@@ -38,16 +35,13 @@ def min_cfg(temp_dir):
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 5e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"max_steps": 40,
"warmup_steps": 5,
"max_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
@@ -70,18 +64,11 @@ class TestCutCrossEntropyIntegration:
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=2.2,
max_final=2.0,
)
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
@@ -100,15 +87,13 @@ class TestCutCrossEntropyIntegration:
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"max_steps": 50,
"max_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -123,13 +108,6 @@ class TestCutCrossEntropyIntegration:
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@pytest.mark.parametrize(
"attention_type",
@@ -158,10 +136,3 @@ class TestCutCrossEntropyIntegration:
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=2.2,
max_final=2.0,
)

View File

@@ -54,9 +54,7 @@ except (ImportError, ModuleNotFoundError):
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
smoe_A = peft_A
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
return smoe_A, smoe_B
return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
def _unwrap_experts_lora(experts_module):
return experts_module, None, None
@@ -129,11 +127,7 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA for gate_up_proj to scattermoe layout.
Both gate_up_proj and down_proj need the A<->B swap because
scattermoe transposes the parameter (W = param.T).
"""
"""Convert peft LoRA for gate_up_proj to scattermoe layout."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
@@ -306,8 +300,6 @@ class TestLoRABLayoutConversion:
hidden, inter = 32, 16
scaling = 2.0
# peft >=0.19.1 for down_proj [E, hidden, inter]:
# swaps in/out, lora_A [r*E, inter], lora_B [hidden, r*E]
peft_A = torch.randn(E * r, inter)
peft_B = torch.randn(hidden, E * r)
@@ -316,6 +308,8 @@ class TestLoRABLayoutConversion:
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
assert smoe_A.shape == (E * r, inter)
assert smoe_B.shape == (hidden, E * r)
for e in range(E):
A_e = smoe_A[e * r : (e + 1) * r, :]
B_e = smoe_B[:, e * r : (e + 1) * r]
@@ -325,22 +319,30 @@ class TestLoRABLayoutConversion:
)
def test_gate_up_proj_conversion(self):
"""Verify gate_up_proj LoRA conversion with non-square dims.
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
gate_up_proj param: [E, 2*inter, hidden].
peft swaps in/out for 3D: lora_A [r*E, hidden], lora_B [2*inter, r*E].
peft: in_features=hidden, out_features=2*inter.
peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E].
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs.
"""
E, r = 4, 2
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
scaling = 2.0
peft_A = torch.randn(E * r, hidden) # [r*E, in=hidden]
peft_B = torch.randn(2 * inter, E * r) # [out=2*inter, r*E]
# peft assigns: in_features=hidden, out_features=2*inter
peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden]
peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E]
A_r = peft_A.reshape(E, r, hidden)
B_r = peft_B.reshape(2 * inter, r, E)
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
# delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden]
# = param[e] shape [2*inter, hidden]
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
# smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E]
@@ -398,7 +400,8 @@ class TestPeftLoRAWeightExtraction:
r,
)
# gate_up_proj [E, 2*inter, hidden] — peft swaps in/out for 3D
# gate_up_proj [E, 2*inter, hidden]
# peft: in_features=hidden (last dim), out_features=2*inter (middle dim)
assert trainable[
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
].shape == (E * r, config.hidden_size)
@@ -406,7 +409,8 @@ class TestPeftLoRAWeightExtraction:
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
].shape == (2 * config.intermediate_size, E * r)
# down_proj [E, hidden, inter] — peft swaps in/out for 3D
# down_proj [E, hidden, inter]
# peft: in_features=inter (last dim), out_features=hidden (middle dim)
assert trainable[
"base_model.model.moe.experts.lora_A.default.weight"
].shape == (E * r, config.intermediate_size)
@@ -463,26 +467,29 @@ class TestPeftLoRAWeightExtraction:
assert gup_lora is not None, "gate_up_proj LoRA not detected"
assert down_lora is not None, "down_proj LoRA not detected"
# gate_up_proj: K=hidden, N=2*inter
# Check shapes after peft->scattermoe conversion.
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
# scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r]
E, r = config.num_experts, 4
gup_A, gup_B, gup_s = gup_lora
assert gup_A.shape == (E * r, config.hidden_size), (
f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
f"got {gup_A.shape}"
)
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
f"gate_up_proj smoe_B: expected [2*inter, r*E]="
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
)
# down_proj: K=inter, N=hidden
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
# scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r]
down_A, down_B, down_s = down_lora
assert down_A.shape == (E * r, config.intermediate_size), (
f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, "
f"got {down_A.shape}"
)
assert down_B.shape == (config.hidden_size, E * r), (
f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, "
f"got {down_B.shape}"
)

View File

@@ -24,7 +24,7 @@ from axolotl.monkeypatch.lora_kernels import (
)
from axolotl.utils.dict import DictDefault
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"
MODEL_NAME = "Qwen/Qwen3-0.6B"
DEVICE = "cuda"
DTYPE = torch.bfloat16

View File

@@ -1,22 +1,23 @@
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
import os
from pathlib import Path
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully artifacts, no-NaN, loss
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
"""
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
@@ -29,13 +30,19 @@ def verify_training_success(temp_dir):
"No checkpoint files found - training may have failed"
)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestDistMuon:
@@ -45,7 +52,7 @@ class TestDistMuon:
def test_fft_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -56,12 +63,11 @@ class TestDistMuon:
},
],
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-3,
"learning_rate": 0.02,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
@@ -76,9 +82,6 @@ class TestDistMuon:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -106,7 +109,7 @@ class TestDistMuon:
def test_lora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -119,15 +122,14 @@ class TestDistMuon:
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-3,
"learning_rate": 0.02,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
@@ -142,9 +144,6 @@ class TestDistMuon:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)

View File

@@ -1,23 +1,24 @@
"""Test module for FSDP1 multi-GPU functionality."""
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard_loss_decreased
from tests.e2e.utils import most_recent_subdir
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully artifacts, no-NaN, loss
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
"""
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
@@ -30,13 +31,19 @@ def verify_training_success(temp_dir):
"No checkpoint files found - training may have failed"
)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestFSDP1:
@@ -49,7 +56,7 @@ class TestFSDP1:
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -60,12 +67,11 @@ class TestFSDP1:
},
],
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -81,9 +87,6 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -123,7 +126,7 @@ class TestFSDP1:
def test_lora_sft(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -137,15 +140,14 @@ class TestFSDP1:
"load_in_4bit": adapter_config["load_in_4bit"],
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -161,9 +163,6 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -191,7 +190,7 @@ class TestFSDP1:
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
@@ -204,11 +203,11 @@ class TestFSDP1:
},
],
"num_epochs": 1,
"max_steps": 20,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -224,9 +223,6 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
}
)
@@ -266,7 +262,7 @@ class TestFSDP1:
def test_dpo_lora(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"load_in_4bit": adapter_config["load_in_4bit"],
"rl": "dpo",
"chat_template": "chatml",
@@ -285,11 +281,11 @@ class TestFSDP1:
},
],
"num_epochs": 1,
"max_steps": 20,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -305,9 +301,6 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": "auto",
"tf32": True,
}

View File

@@ -1,23 +1,24 @@
"""Test module for FSDP2 multi-GPU functionality."""
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully artifacts, no-NaN, loss
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
"""
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
@@ -30,13 +31,19 @@ def verify_training_success(temp_dir):
"No checkpoint files found - training may have failed"
)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestFSDP2:
@@ -50,7 +57,7 @@ class TestFSDP2:
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -61,12 +68,11 @@ class TestFSDP2:
},
],
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -80,9 +86,6 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -111,7 +114,7 @@ class TestFSDP2:
def test_lora_sft(self, temp_dir, peft_use_dora):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -125,15 +128,14 @@ class TestFSDP2:
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -147,9 +149,6 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
# explicitly disable LORA kernels, as they may be auto-enabled
"lora_mlp_kernel": False,
@@ -181,7 +180,7 @@ class TestFSDP2:
def test_lora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -196,12 +195,11 @@ class TestFSDP2:
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -215,9 +213,6 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
@@ -248,7 +243,7 @@ class TestFSDP2:
def test_qlora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -262,15 +257,14 @@ class TestFSDP2:
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -284,9 +278,6 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -314,7 +305,7 @@ class TestFSDP2:
def test_qlora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -330,12 +321,11 @@ class TestFSDP2:
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 80,
"warmup_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -349,9 +339,6 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
@@ -383,7 +370,7 @@ class TestFSDP2:
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
@@ -396,11 +383,11 @@ class TestFSDP2:
},
],
"num_epochs": 1,
"max_steps": 20,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -414,9 +401,6 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
}
)
@@ -444,7 +428,7 @@ class TestFSDP2:
def test_dpo_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"rl": "dpo",
"chat_template": "chatml",
@@ -461,11 +445,11 @@ class TestFSDP2:
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 20,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -479,9 +463,6 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
}
)

View File

@@ -40,7 +40,7 @@ def _run_training(temp_dir, cfg):
def _base_lora_fsdp2_config(temp_dir, **overrides):
"""Base config for LoRA + FSDP2 + kernel tests."""
cfg = {
"base_model": "axolotl-ai-co/tiny-qwen3-129m",
"base_model": "Qwen/Qwen3-0.6B",
"sequence_len": 512,
"val_set_size": 0.0,
"datasets": [

View File

@@ -8,7 +8,7 @@ from accelerate.test_utils import execute_subprocess_async, get_torch_dist_uniqu
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
class TestTensorParallel:
@@ -21,7 +21,7 @@ class TestTensorParallel:
def test_fft_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -63,6 +63,6 @@ class TestTensorParallel:
]
)
check_tensorboard_loss_decreased(
temp_dir + "/runs", max_initial=5.0, max_final=4.7
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
)

View File

@@ -32,12 +32,12 @@ from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "axolotl-ai-co/tiny-mistral-25m",
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "axolotl-ai-co/tiny-qwen2-129m",
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
@@ -47,7 +47,7 @@ MODEL_CONFIGS = [
"dtype": torch.float32,
},
{
"name": "axolotl-ai-co/tiny-gemma2-137m",
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
@@ -159,7 +159,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"axolotl-ai-co/tiny-gemma2-137m",
"trl-internal-testing/tiny-Gemma2ForCausalLM",
dtype=torch.float16,
device_map="cuda:0",
)

View File

@@ -4,16 +4,14 @@ E2E tests for falcon
import unittest
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
from ..utils import check_model_output_exists, with_temp_dir
class TestFalconPatched(unittest.TestCase):
@@ -21,12 +19,13 @@ class TestFalconPatched(unittest.TestCase):
Test case for Falcon models
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_qlora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
@@ -48,20 +47,17 @@ class TestFalconPatched(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -70,20 +66,14 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.05,
@@ -98,20 +88,17 @@ class TestFalconPatched(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -120,10 +107,3 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)

View File

@@ -9,12 +9,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
require_torch_2_6_0,
with_temp_dir,
)
from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir
class TestMistral(unittest.TestCase):
@@ -27,7 +22,7 @@ class TestMistral(unittest.TestCase):
def test_lora_packing(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -50,20 +45,17 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -72,19 +64,12 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.5,
max_final=4.3,
)
@with_temp_dir
def test_ft_packing(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -101,20 +86,17 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -123,10 +105,3 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.5,
max_final=4.3,
)

View File

@@ -9,11 +9,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
from ..utils import check_model_output_exists, with_temp_dir
class TestMixtral(unittest.TestCase):
@@ -25,7 +21,8 @@ class TestMixtral(unittest.TestCase):
def test_qlora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -33,7 +30,7 @@ class TestMixtral(unittest.TestCase):
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.0,
"lora_dropout": 0.1,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {},
@@ -44,21 +41,17 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 3e-3,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 80,
"eval_steps": 80,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -67,19 +60,13 @@ class TestMixtral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=6.0,
max_final=4.7,
)
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -92,21 +79,17 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 5e-4,
"optimizer": "adamw_torch_fused",
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 80,
"eval_steps": 80,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -115,10 +98,3 @@ class TestMixtral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)

View File

@@ -22,7 +22,8 @@ class TestModelPatches(unittest.TestCase):
def test_mixtral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -56,7 +57,7 @@ class TestModelPatches(unittest.TestCase):
def test_mistral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -9,11 +9,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
from ..utils import check_model_output_exists, with_temp_dir
class TestPhiMultipack(unittest.TestCase):
@@ -25,7 +21,7 @@ class TestPhiMultipack(unittest.TestCase):
def test_ft_packed(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-phi-64m",
"base_model": "microsoft/phi-1_5",
"model_type": "PhiForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
@@ -47,20 +43,17 @@ class TestPhiMultipack(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"eval_steps": 50,
"save_steps": 50,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -70,19 +63,12 @@ class TestPhiMultipack(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)
@with_temp_dir
def test_qlora_packed(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-phi-64m",
"base_model": "microsoft/phi-1_5",
"model_type": "PhiForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
@@ -108,20 +94,17 @@ class TestPhiMultipack(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"eval_steps": 50,
"save_steps": 50,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -131,10 +114,3 @@ class TestPhiMultipack(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)

View File

@@ -0,0 +1,21 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
import pytest
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_self_attn_patchable(self):
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_self_attn_is_patchable(),
"HF transformers self attention code has changed and isn't patchable",
)

View File

@@ -0,0 +1,184 @@
"""
e2e tests for unsloth qlora
"""
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models
"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": sample_packing,
"flash_attention": True,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
"sdp_attention",
[True, False],
)
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"sdp_attention": sdp_attention,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -18,7 +18,7 @@ from transformers import AutoModelForCausalLM
# Import the actual trainer methods we want to test
from axolotl.core.trainers.grpo.async_trainer import AsyncGRPOTrainer
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"
MODEL_NAME = "Qwen/Qwen3-0.6B"
def _fix_patched_attention(model):

View File

@@ -116,58 +116,6 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
def test_rpo(self, temp_dir):
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
# replaces loss_type="rpo", rpo_alpha=alpha.
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"dpo_loss_type": ["sigmoid", "sft"],
"dpo_loss_weights": [1.0, 1.0],
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
def test_kto_pair_lora(self, temp_dir):
@@ -233,8 +181,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"dpo_loss_type": ["ipo"],
"rl": "ipo",
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",

View File

@@ -4,16 +4,14 @@ E2E tests for falcon
import unittest
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
from .utils import check_model_output_exists, with_temp_dir
class TestFalcon(unittest.TestCase):
@@ -21,12 +19,13 @@ class TestFalcon(unittest.TestCase):
Test case for falcon
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -50,21 +49,17 @@ class TestFalcon(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -74,20 +69,14 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -115,21 +104,17 @@ class TestFalcon(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -139,20 +124,14 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
@@ -166,23 +145,17 @@ class TestFalcon(unittest.TestCase):
},
],
"num_epochs": 2,
"sample_packing": True,
"pad_to_sequence_len": True,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 5e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 80,
"eval_steps": 80,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -192,10 +165,3 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)

View File

@@ -11,11 +11,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
from .utils import check_model_output_exists, with_temp_dir
class TestMistral(unittest.TestCase):
@@ -27,7 +23,7 @@ class TestMistral(unittest.TestCase):
def test_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -49,18 +45,16 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -70,19 +64,12 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=4.5,
max_final=4.3,
)
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,
@@ -98,18 +85,16 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
@@ -123,10 +108,3 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=4.5,
max_final=4.3,
)

View File

@@ -12,11 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
from .utils import check_model_output_exists, with_temp_dir
class TestMixtral(unittest.TestCase):
@@ -28,7 +24,8 @@ class TestMixtral(unittest.TestCase):
def test_qlora_w_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -54,18 +51,16 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -79,19 +74,13 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_qlora_wo_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -117,18 +106,16 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -142,19 +129,13 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_16bit_lora_w_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"adapter": "lora",
@@ -179,18 +160,16 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
@@ -208,19 +187,13 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_16bit_lora_wo_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 1024,
"adapter": "lora",
@@ -245,18 +218,16 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -274,19 +245,13 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,
@@ -298,18 +263,16 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
@@ -323,10 +286,3 @@ class TestMixtral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)

View File

@@ -13,7 +13,6 @@ from axolotl.utils.dict import DictDefault
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
@@ -244,18 +243,20 @@ class TestCustomOptimizers(unittest.TestCase):
def test_came_pytorch(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-llama-50m",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -264,22 +265,16 @@ class TestCustomOptimizers(unittest.TestCase):
},
],
"num_epochs": 1,
"sample_packing": True,
"pad_to_sequence_len": True,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-4,
"learning_rate": 0.00001,
"optimizer": "came_pytorch",
"adam_beta3": 0.9999,
"adam_epsilon2": 1e-16,
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"max_steps": 5,
"lr_scheduler": "cosine",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -289,13 +284,6 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=4.0,
max_final=3.0,
)
@require_torch_2_7_0

View File

@@ -9,11 +9,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
from .utils import check_model_output_exists, with_temp_dir
class TestPhi(unittest.TestCase):
@@ -25,7 +21,7 @@ class TestPhi(unittest.TestCase):
def test_phi_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-phi-64m",
"base_model": "microsoft/phi-1_5",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
@@ -45,22 +41,18 @@ class TestPhi(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 4,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 10,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -69,19 +61,12 @@ class TestPhi(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_phi_qlora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/tiny-phi-64m",
"base_model": "microsoft/phi-1_5",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
@@ -105,22 +90,18 @@ class TestPhi(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 4,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 2e-4,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"max_steps": 10,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -129,10 +110,3 @@ class TestPhi(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)

Some files were not shown because too many files have changed in this diff Show More