Compare commits
27 Commits
08fc7de87e...kernelize-

| SHA1 |
|---|
| 8495c79fb1 |
| 9a0d3016df |
| e562e149ce |
| 9de5b76336 |
| 323da791eb |
| 6990478163 |
| 63a58cfec1 |
| 3985ec2f67 |
| a44edda6d7 |
| 66c3e5a3fd |
| b8358aa5ab |
| e079cf16a2 |
| e2f69828d2 |
| 122b50bad6 |
| e77a185e86 |
| 29fa4dedbb |
| 315cdeede9 |
| e7a6a5b529 |
| bfb4da1d25 |
| 4dfa0a59b2 |
| 4ef608dda3 |
| 7daf7d96f1 |
| 7c56809c7f |
| 149178ddb7 |
| dc638e723f |
| 6f15da4cac |
| 900eec7988 |
.github/CONTRIBUTING.md | 5

@@ -31,7 +31,10 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
# Install axolotl + dev and test dependencies from lockfile
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
pre-commit install

# test
.github/workflows/lint.yml | 2

@@ -6,7 +6,7 @@ on:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
- 'pyproject.toml'
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"
.github/workflows/multi-gpu-e2e.yml | 35

@@ -3,17 +3,15 @@ name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'scripts/cutcrossentropy_install.py'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
- "tests/e2e/multigpu/**.py"
- "pyproject.toml"
- ".github/workflows/multi-gpu-e2e.yml"
- "scripts/cutcrossentropy_install.py"
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
- "src/axolotl/utils/distributed.py"
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday

# Cancel jobs on the same ref if a new one is triggered
concurrency:

@@ -33,19 +31,19 @@ jobs:
fail-fast: false
matrix:
include:
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
# axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1

@@ -53,7 +51,6 @@ jobs:
pytorch: 2.10.0
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

@@ -75,7 +72,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
.github/workflows/pypi.yml | 13

@@ -8,6 +8,9 @@ on:

permissions: {}

env:
UV_SYSTEM_PYTHON: "1"

jobs:
setup_release:
name: Create Release

@@ -41,11 +44,15 @@ jobs:
with:
python-version: "3.11"

- name: Install uv
uses: astral-sh/setup-uv@v7

- name: Install dependencies
run: |
pip3 install wheel packaging==26.0
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
uv pip install wheel packaging
uv pip install --no-build-isolation -e .
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse

- name: Extract tag name
id: tag
.github/workflows/tests-nightly.yml | 55

@@ -2,15 +2,18 @@ name: Tests Nightly against upstream main
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '.github/workflows/tests-nightly.yml'
- ".github/workflows/tests-nightly.yml"

permissions:
contents: read

env:
UV_SYSTEM_PYTHON: "1"

jobs:
pre-commit:
name: pre-commit

@@ -20,7 +23,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
cache: "pip" # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch

@@ -43,7 +46,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.9.1", "2.10.0"]
timeout-minutes: 20

@@ -61,36 +64,34 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies

- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
- name: Install uv
uses: astral-sh/setup-uv@v7

- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} torchvision

- name: Update requirements.txt
run: |
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt

- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse

- name: Override with nightly HF packages
run: |
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"

- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"

- name: Ensure axolotl CLI was installed
run: |

@@ -102,9 +103,6 @@ jobs:
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/

- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;

docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'

@@ -136,7 +134,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
steps:
- name: Checkout

@@ -157,7 +154,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
.github/workflows/tests.yml | 95

@@ -6,21 +6,19 @@ on:
branches:
- "main"
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
types: [opened, synchronize, reopened, ready_for_review]
paths:
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
workflow_dispatch:

# Cancel jobs on the same ref if a new one is triggered

@@ -33,6 +31,7 @@ permissions:

env:
TRANSFORMERS_IS_CI: "yes"
UV_SYSTEM_PYTHON: "1"

jobs:
pre-commit:

@@ -44,7 +43,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
cache: "pip" # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch

@@ -94,32 +93,25 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies

- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install uv
uses: astral-sh/setup-uv@v7

- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt

- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-cache-dir --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt

- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse

- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"

- name: Ensure axolotl CLI was installed
run: |

@@ -188,38 +180,42 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies

- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install uv
uses: astral-sh/setup-uv@v7

- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt

- name: Install dependencies
run: |
pip3 show torch
uv pip install packaging setuptools_scm build wheel psutil
python -m build --no-isolation --sdist
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt

- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse

- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"

- name: Ensure axolotl CLI was installed
run: |
axolotl --help

- name: Verify agent docs are discoverable
run: |
# Agent docs live in docs/agents/ (source of truth) and are resolved
# at runtime from the repo checkout or via `axolotl fetch docs`
axolotl agent-docs --list
axolotl agent-docs | grep -q "Fine-tuning framework"
axolotl agent-docs grpo | grep -q "GRPO"
axolotl agent-docs sft | grep -q "SFT"
python -c "from axolotl.cli.agent_docs import get_doc, list_topics; assert len(list_topics()) >= 5; assert 'GRPO' in get_doc('grpo')"

- name: Show HF cache
run: hf cache ls

@@ -281,7 +277,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4

@@ -302,7 +297,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

@@ -364,7 +359,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
@@ -16,6 +16,9 @@ axolotl inference config.yaml # Interactive inference
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
axolotl fetch examples # Download example configs
axolotl agent-docs # Show agent-optimized docs (bundled with pip package)
axolotl agent-docs grpo # Topic-specific agent reference
axolotl config-schema # Dump config JSON schema
```

## Training Methods

@@ -35,6 +38,8 @@ Agent-specific references:
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures

## Config Pattern
@@ -1,6 +1,7 @@
include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include VERSION
include src/axolotl/utils/chat_templates/templates/*.jinja
include AGENTS.md
recursive-include docs/agents *.md
recursive-include axolotl *.py
README.md | 43

@@ -86,7 +86,7 @@ Features:
**Requirements**:

- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1

### Google Colab

@@ -95,11 +95,19 @@ Features:

### Installation

#### Using pip

```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# install uv if you don't already have it installed (restart shell after)
curl -LsSf https://astral.sh/uv/install.sh | sh

# change depending on system
export UV_TORCH_BACKEND=cu128

# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate

uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]

# Download example axolotl configs, deepspeed configs
axolotl fetch examples

@@ -110,7 +118,7 @@ axolotl fetch deepspeed_configs # OPTIONAL

Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
```

Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).

@@ -157,6 +165,29 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions

## AI Agent Support

Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.

```bash
# Show overview and available training methods
axolotl agent-docs

# Topic-specific references
axolotl agent-docs sft # supervised fine-tuning
axolotl agent-docs grpo # GRPO online RL
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
axolotl agent-docs reward_modelling # outcome and process reward models
axolotl agent-docs pretraining # continual pretraining
axolotl agent-docs --list # list all topics

# Dump config schema for programmatic use
axolotl config-schema
axolotl config-schema --field adapter
```

If you're working with the source repo, agent docs are also available at `docs/agents/` and the project overview is in `AGENTS.md`.

## 🤝 Getting Help

- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
@@ -134,7 +134,6 @@ quartodoc:
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu

@@ -327,7 +326,6 @@ website:
- section: "Advanced Features"
contents:
- docs/fsdp_qlora.qmd
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
@@ -22,15 +22,6 @@ WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi

RUN uv pip install packaging==26.0 setuptools==78.1.1
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d

@@ -40,11 +31,21 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py --uv | sh
# Override with nightly HF packages for nightly builds
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
fi

RUN python scripts/cutcrossentropy_install.py --uv | sh

# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse

# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
@@ -1,54 +0,0 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}

ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"

RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm

WORKDIR /workspace

RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git

WORKDIR /workspace/axolotl

RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi

RUN pip install packaging==26.0 setuptools==78.1.1 psutil
RUN pip uninstall -y causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh

# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt

# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch

# helper for huggingface-login cli
RUN git config --global credential.helper store
@@ -1,7 +1,7 @@
#!/bin/bash
set -e

python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"

set -o pipefail
for i in 1 2 3; do
@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
df_template = template_env.get_template(dockerfile)

df_args = {

@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
df_template = template_env.get_template(dockerfile)

df_args = {
@@ -32,7 +32,7 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py | sh && \
fi && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

@@ -33,7 +33,6 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
docs/agents/model_architectures.md | 198 (new file)

@@ -0,0 +1,198 @@
# Model Architectures — Agent Reference

Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.

## VLM (Vision Language Model) Quick Start

All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```

Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│       Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│       LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model

Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│       Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```

## Plugins & Optimizations

### Cut Cross Entropy (CCE)

Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:

```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```

### ScatterMoE Kernels

Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.

```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe

# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```

Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.

## Gemma 4

**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`

**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.

### Required settings

```yaml
# Always needed for Gemma4:
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
  use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant

# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```

### Auto-detection

Axolotl auto-detects Gemma4 and applies:
- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)

### Multi-GPU

| Strategy | Works? | Notes |
|----------|--------|-------|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |

FSDP2 config:
```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```

### MoE (26B-A4B)

- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
```yaml
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```
- ScatterMoE kernel acceleration:
```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```

### VLM (Vision) Training

All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.

```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4

skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```

A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.

For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```

### Common issues

| Symptom | Cause | Fix |
|---------|-------|-----|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |

### E2B/E4B dense models

These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.

## Gemma 3

**Models**: `google/gemma-3-*`

- `ddp_find_unused_parameters: true` needed (multimodal unused params)
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)

## Qwen 3.5 MoE

**Models**: `Qwen/Qwen3.5-35B-A3B`

- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
```

## General MoE Notes

- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
docs/agents/new_model_support.md | 181 (new file)

@@ -0,0 +1,181 @@
# New Model Support — Agent Reference

Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.

## Quick Validation Checklist

When testing a new model, run through these checks in order:

1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5–2.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower

## Loss Debugging

### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.5–2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
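As a quick sanity check, the random-prediction baseline is simply the natural log of the vocabulary size (a small illustrative snippet; the 262K figure above is one such vocabulary):

```python
import math

# Cross-entropy of a uniform guess over the vocabulary is ln(vocab_size).
# A first-step loss near this value means the model is effectively predicting at random.
for vocab_size in (32_000, 128_256, 262_144):
    print(f"vocab={vocab_size:>7}  random baseline ≈ {math.log(vocab_size):.2f}")
# 32000 -> 10.37, 128256 -> 11.76, 262144 -> 12.48
```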
### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:

```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader

cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()

# Forward pass on preprocessed data
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
```

If direct loss is correct (~1.0) but trainer reports 3–4x higher, check `model_accepts_loss_kwargs` (see below).

### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.

**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.

**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).

**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
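A minimal sketch of that check, assuming a trainer with the standard HF attributes (the helper name and unwrapping steps are illustrative, not the exact axolotl code):

```python
import inspect

def maybe_disable_loss_kwargs(trainer, model):
    """Disable loss-kwargs handling when forward() cannot actually accept num_items_in_batch."""
    base = getattr(model, "module", model)                    # unwrap DDP, if wrapped (assumption)
    base = getattr(base, "get_base_model", lambda: base)()    # unwrap PEFT, if wrapped (assumption)
    if "num_items_in_batch" not in inspect.signature(base.forward).parameters:
        trainer.model_accepts_loss_kwargs = False             # logged loss is normalized per step again
```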
## Multimodal Models (ForConditionalGeneration)

Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`

### Why this matters

| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|-----------|----------------------|--------------------------------|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |

### Required extra inputs
Multimodal models require special inputs during training even for text-only data:

| Model | Required Input | Value for Text-Only |
|-------|---------------|-------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |

Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
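A sketch of what that injection can look like inside `compute_loss()` (the mapping and helper name here are illustrative assumptions, not the exact axolotl hook):

```python
import torch

def inject_text_only_mm_inputs(inputs: dict, model_type: str) -> dict:
    """Add the token-type tensor a multimodal model expects when a batch is text-only."""
    key = {"gemma4": "mm_token_type_ids", "gemma3": "token_type_ids"}.get(model_type)
    if key is not None and key not in inputs:
        # zeros mean "every position is a plain text token"
        inputs[key] = torch.zeros_like(inputs["input_ids"])
    return inputs
```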
### Custom layer types and PEFT
Vision towers often use custom module wrappers that PEFT doesn't support:

| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |

Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.

## Sample Packing

### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:

```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
    packed_sequence_mask = find_packed_sequence_indices(position_ids)
```

If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.

### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:

```python
# In compute_loss():
if (
    self.args.sample_packing
    and model_type in ("gemma4", "gemma3")
    and "attention_mask" in inputs
    and "position_ids" in inputs
):
    del inputs["attention_mask"]
```

Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.

### Models that DON'T need this fix
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.

## Attention Backend Selection

| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |

**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
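For a quick look from a Python shell (illustrative; any transformers model class can be inspected the same way, and missing flags simply print `None`):

```python
from transformers.models.llama.modeling_llama import LlamaForCausalLM

for flag in ("_supports_flash_attn_2", "_supports_flex_attn", "_supports_sdpa"):
    print(flag, getattr(LlamaForCausalLM, flag, None))
```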
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.

**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.

## Cut Cross Entropy (CCE)

### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.

### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`

### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
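One quick way to confirm where the patch actually landed is plain Python introspection (not an axolotl API; assumes `model` is the loaded training model):

```python
# After loading the model with the CCE plugin enabled:
fwd = type(model).forward
print(type(model).__name__, "-> forward defined in", fwd.__module__)
# If this still points at a transformers module rather than a cut_cross_entropy /
# axolotl patch module, the class you are training was never patched.
```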
## MoE Models

### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)

LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").

### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```

## Where to Add Model-Specific Fixes

| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |
@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |

## Profiling

To profile training and identify optimization opportunities:

```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```

This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.

To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```

The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead
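If you want a rough summary without the helper script, the Chrome trace is plain JSON and can be aggregated in a few lines (a sketch that assumes the standard trace-event layout and the output path shown above):

```python
import json
from collections import Counter

with open("output_dir/profiler_trace.json") as f:
    data = json.load(f)
events = data["traceEvents"] if isinstance(data, dict) else data

# Sum wall time per op/kernel name over complete ("X") events; durations are microseconds.
totals = Counter()
for ev in events:
    if ev.get("ph") == "X" and "dur" in ev:
        totals[ev.get("name", "?")] += ev["dur"]

for name, dur_us in totals.most_common(15):
    print(f"{dur_us / 1e3:10.1f} ms  {name}")
```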
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)

## File Map
@@ -108,6 +108,14 @@ datasets:
    type: chat_template
```

::: {.callout-tip}
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:

```yaml
chat_template_jinja: ./path/to/my_template.jinja2
```
:::

::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
:::

@@ -294,6 +302,113 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::

#### Content parts with per-part training control

Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).

```{.json filename="data.jsonl"}
{
  "messages": [
    {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
    {
      "role": "assistant",
      "content": [
        {"type": "text", "text": "Let me think step by step...", "train": false},
        {"type": "text", "text": " The answer is 4.", "train": true}
      ]
    }
  ]
}
```

The configuration is the same as standard `chat_template` — no extra fields needed:

```yaml
datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```

Each content part supports:

- `type`: `"text"` (required)
- `text`: the text value (also accepts `content` or `value` as the key)
- `train`: `true`/`false` (optional) — whether to train on this part
- `weight`: `0`/`1` (optional) — alternative to `train`

If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).

::: {.callout-warning title="Whitespace at part boundaries"}
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:

**Split BEFORE spaces** (space goes with the next part):

```json
[
  {"type": "text", "text": "Let me think...", "train": false},
  {"type": "text", "text": " The answer is 4.", "train": true}
]
```

**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):

```json
[
  {"type": "text", "text": "Let me think... ", "train": false},
  {"type": "text", "text": "The answer is 4.", "train": true}
]
```

In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
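You can see the straddling token directly with any fast BPE tokenizer (an illustrative snippet; the checkpoint name is just an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # any BPE tokenizer behaves similarly

text = "Let me think... The answer is 4."
enc = tok(text, return_offsets_mapping=True, add_special_tokens=False)
for tok_id, (start, end) in zip(enc["input_ids"], enc["offset_mapping"]):
    print(repr(text[start:end]))
# " The" prints as one token. If the first content part ends with the trailing space,
# that token straddles the part boundary and is conservatively masked.
```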
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:

```json
[
  {"type": "text", "text": "Thinking:\n", "train": false},
  {"type": "text", "text": "The answer is 4.", "train": true}
]
```

Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
:::

::: {.callout-note}
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
:::

##### Per-part training on reasoning_content

For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:

```{.json filename="data.jsonl"}
{
  "messages": [
    {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
    {
      "role": "assistant",
      "reasoning_content": [
        {"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
        {"type": "text", "text": " Wait no, 2+2=4.", "train": true}
      ],
      "content": [
        {"type": "text", "text": "The answer is 4.", "train": true}
      ]
    }
  ]
}
```

The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.

::: {.callout-tip}
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
:::

The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.

#### Reasoning split

(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
@@ -76,8 +76,9 @@ datasets:
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:

```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
```

#### Remote Hosts

@@ -208,17 +209,17 @@ cd axolotl
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]

```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
```

>[!Tip]
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).

You will now be in the container. Next, perform an editable install of Axolotl:
You will now be in the container. Next, install Axolotl with dev dependencies:

```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
```

### Attach To Container
@@ -6,23 +6,30 @@ format:
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
This section describes the different Docker images that are released by AxolotlAI at
|
||||
[Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
::: {.callout-tip}
|
||||
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
|
||||
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
|
||||
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
|
||||
:::
|
||||
|
||||
## Base
|
||||
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
|
||||
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-base
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
|
||||
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
|
||||
|
||||
#### Tags format
|
||||
|
||||
@@ -32,8 +39,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.8.0`
|
||||
- `main-base-py3.11-cu128-2.9.1`
|
||||
- `main-base-py3.12-cu128-2.10.0`
|
||||
- `main-base-py3.12-cu130-2.9.1`
|
||||
- `main-base-py3.12-cu130-2.10.0`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -41,11 +50,10 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
|
||||
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
|
||||
|
||||
#### Tags format {#sec-main-tags}
|
||||
|
||||
@@ -53,7 +61,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
# on push to main
|
||||
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
|
||||
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
|
||||
main-latest
|
||||
|
||||
# nightly build
|
||||
@@ -71,11 +79,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.8.0`
|
||||
- `main-py3.11-cu128-2.9.1`
|
||||
- `main-py3.12-cu128-2.10.0`
|
||||
- `main-py3.12-cu130-2.9.1`
|
||||
- `main-py3.12-cu130-2.10.0`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `main-20260315-py3.11-cu128-2.9.1`
|
||||
- `0.12.0`
|
||||
|
||||
## Cloud
|
||||
@@ -90,11 +99,10 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-cloud
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
|
||||
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
|
||||
|
||||
#### Tags format
|
||||
|
||||
|
||||
@@ -15,64 +15,30 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
|
||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- PyTorch ≥2.9.0
|
||||
|
||||
## Installation Methods {#sec-installation-methods}
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
|
||||
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
## Installation {#sec-installation}
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
### Quick Install {#sec-uv}
|
||||
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
|
||||
|
||||
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
|
||||
installed) in order not to clobber it, and so that we set the correct version of
|
||||
dependencies that are specific to the PyTorch version or other installed
|
||||
co-dependencies.
|
||||
|
||||
### uv Installation {#sec-uv}
|
||||
|
||||
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
|
||||
|
||||
Install uv if not already installed
|
||||
Install uv if not already installed:
|
||||
```{.bash}
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
```
|
||||
|
||||
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
|
||||
then create the venv and activate
|
||||
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
|
||||
```{.bash}
|
||||
export UV_TORCH_BACKEND=cu126
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Install PyTorch
|
||||
- PyTorch 2.6.0 recommended
|
||||
```{.bash}
|
||||
uv pip install packaging setuptools wheel
|
||||
uv pip install torch==2.6.0
|
||||
uv pip install awscli pydantic
|
||||
```
|
||||
|
||||
Install axolotl from PyPi
|
||||
```{.bash}
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
|
||||
|
||||
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
|
||||
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
### Edge/Development Build {#sec-edge-build}
|
||||
@@ -82,14 +48,17 @@ For the latest features between releases:
|
||||
```{.bash}
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv sync --extra flash-attn --extra deepspeed
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
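
To confirm the editable install is picking up your checkout rather than a previously installed copy, a quick check along these lines can help. A minimal sketch, assuming you run it from inside the activated `.venv`:

```python
# Sanity check: the package should resolve to the cloned repo, not site-packages.
from importlib.metadata import version

import axolotl
import torch

print(axolotl.__file__)    # expect a path under your axolotl checkout's src/ tree
print(version("axolotl"))  # version of the installed (editable) distribution
print(torch.__version__)   # should match the UV_TORCH_BACKEND build you selected
```
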
|
||||
|
||||
### Docker {#sec-docker}
|
||||
|
||||
```{.bash}
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
|
||||
```
|
||||
|
||||
For development with Docker:
|
||||
@@ -106,12 +75,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
--ulimit memlock=-1 --ulimit stack=67108864 \
|
||||
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
|
||||
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
|
||||
axolotlai/axolotl:main-latest
|
||||
axolotlai/axolotl-uv:main-latest
|
||||
```
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
@@ -122,7 +91,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
|
||||
|
||||
For providers supporting Docker:
|
||||
|
||||
- Use `axolotlai/axolotl-cloud:main-latest`
|
||||
- Use `axolotlai/axolotl-cloud-uv:main-latest`
|
||||
- Available on:
|
||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
||||
@@ -141,7 +110,7 @@ For providers supporting Docker:
|
||||
### macOS {#sec-macos}
|
||||
|
||||
```{.bash}
|
||||
pip3 install --no-build-isolation -e '.'
|
||||
uv pip install --no-build-isolation -e '.'
|
||||
```
|
||||
|
||||
See @sec-troubleshooting for Mac-specific issues.
|
||||
@@ -152,21 +121,44 @@ See @sec-troubleshooting for Mac-specific issues.
|
||||
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
:::
|
||||
|
||||
## Environment Managers {#sec-env-managers}
|
||||
## Migrating from pip to uv {#sec-migrating}
|
||||
|
||||
### Conda/Pip venv {#sec-conda}
|
||||
If you have an existing pip-based Axolotl installation, you can migrate to uv:
|
||||
|
||||
1. Install Python ≥3.11
|
||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||
3. Install Axolotl:
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Hugging Face:
|
||||
```{.bash}
|
||||
hf auth login
|
||||
```
|
||||
```{.bash}
|
||||
# Install uv
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
|
||||
# Create a fresh venv (recommended for a clean start)
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
|
||||
# Reinstall axolotl
|
||||
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
## Using pip (Alternative) {#sec-pip}
|
||||
|
||||
If you are unable to install uv, you can still use pip directly.
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure to have PyTorch installed before installing Axolotl with pip.
|
||||
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
For editable/development installs:
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ format:
|
||||
|
||||
## Supported Models
|
||||
|
||||
- [Gemma-4](#sec-gemma-4) *(NEW)*
|
||||
- [Mllama](#sec-mllama)
|
||||
- [Llama4](#sec-llama4)
|
||||
- [Pixtral](#sec-pixtral)
|
||||
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
|
||||
processor_type: VoxtralProcessor
|
||||
```
|
||||
|
||||
### Gemma-4 {#sec-gemma-4}
|
||||
|
||||
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
|
||||
|
||||
```yaml
|
||||
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
|
||||
|
||||
chat_template: gemma4
|
||||
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
|
||||
|
||||
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.kernels.KernelsPlugin
|
||||
use_kernels: true
|
||||
use_scattermoe: true
|
||||
experts_implementation: scattermoe
|
||||
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
|
||||
lora_target_parameters:
|
||||
- experts.gate_up_proj
|
||||
- experts.down_proj
|
||||
```
|
||||
|
||||
::: {.callout-warning}
|
||||
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
|
||||
:::
|
||||
|
||||
::: {.callout-tip}
|
||||
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
|
||||
:::
|
||||
|
||||
### Gemma-3 {#sec-gemma-3}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
---
|
||||
title: "Unsloth"
|
||||
description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
||||
---
|
||||
|
||||
### Overview
|
||||
|
||||
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
|
||||
standard industry baselines.
|
||||
|
||||
::: {.callout-important}
|
||||
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
|
||||
|
||||
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
|
||||
:::
|
||||
|
||||
|
||||
### Installation
|
||||
|
||||
The following will install the correct unsloth and extras from source.
|
||||
|
||||
```bash
|
||||
python scripts/unsloth_install.py | sh
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
||||
|
||||
Our unsloth integration is currently limited to the following model architectures:
|
||||
- llama
|
||||
|
||||
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
|
||||
```yaml
|
||||
unsloth_lora_mlp: true
|
||||
unsloth_lora_qkv: true
|
||||
unsloth_lora_o: true
|
||||
```
|
||||
|
||||
These options are composable and can be used with multi-gpu finetuning
|
||||
```yaml
|
||||
unsloth_cross_entropy_loss: true
|
||||
unsloth_rms_norm: true
|
||||
unsloth_rope: true
|
||||
```
|
||||
|
||||
### Limitations
|
||||
|
||||
- Single GPU only; e.g. no multi-gpu support
|
||||
- No deepspeed or FSDP support (requires multi-gpu)
|
||||
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
|
||||
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
|
||||
- No MoE support.
|
||||
@@ -15,8 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run one of the finetuning examples below.
|
||||
@@ -35,7 +34,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
|
||||
**LFM2-MoE**
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
|
||||
# LoRA SFT (1x48GB @ 16.2GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
||||
@@ -45,7 +44,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
|
||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||
```bash
|
||||
pip uninstall -y causal-conv1d
|
||||
uv pip uninstall causal-conv1d
|
||||
```
|
||||
|
||||
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
@@ -15,8 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
@@ -31,7 +30,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
# For those using our Docker image, use the below path.
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||
@@ -67,7 +66,7 @@ If those didn't help, please try the below solutions:
|
||||
1. Pass env for CMAKE and try install again:
|
||||
|
||||
```bash
|
||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
2. Git clone the repo and manually hardcode python path:
|
||||
@@ -92,7 +91,7 @@ If those didn't help, please try the below solutions:
|
||||
```
|
||||
|
||||
```bash
|
||||
pip3 install . --no-build-isolation --no-deps
|
||||
uv pip install . --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
@@ -17,8 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -16,8 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
@@ -26,8 +26,8 @@ output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
- ^model.language_model.*
|
||||
- ^lm_head.*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
|
||||
@@ -26,8 +26,8 @@ output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
- ^model.language_model.*
|
||||
- ^lm_head.*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
|
||||
@@ -22,8 +22,8 @@ output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
- ^model.language_model.*
|
||||
- ^lm_head.*
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
@@ -10,17 +10,16 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||
|
||||
```bash
|
||||
pip3 install timm==1.0.17
|
||||
uv pip install timm==1.0.17
|
||||
|
||||
# for loading audio data
|
||||
pip3 install librosa==0.11.0
|
||||
uv pip install librosa==0.11.0
|
||||
```
|
||||
|
||||
3. Download sample dataset files
|
||||
|
||||
@@ -1,19 +1,12 @@
|
||||
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
|
||||
#
|
||||
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
|
||||
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
|
||||
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
|
||||
#
|
||||
# Key notes:
|
||||
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
|
||||
# Use sdp_attention instead.
|
||||
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
|
||||
# LoRA to the text backbone via lora_target_linear_modules regex.
|
||||
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
|
||||
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
|
||||
# via the transformers ExpertsInterface.
|
||||
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
|
||||
# (no `mlp.` prefix, unlike Qwen/Mixtral).
|
||||
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
|
||||
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
|
||||
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
|
||||
# 4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix.
|
||||
# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.
|
||||
|
||||
base_model: google/gemma-4-26B-A4B
|
||||
|
||||
@@ -24,7 +17,7 @@ plugins:
|
||||
use_kernels: true
|
||||
use_scattermoe: true
|
||||
experts_implementation: scattermoe
|
||||
torch_compile: false
|
||||
torch_compile: true
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
@@ -54,12 +47,9 @@ lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders).
|
||||
# lora_target_modules is intentionally empty — all module targeting is done
|
||||
# via regex in lora_target_linear_modules below.
|
||||
lora_target_modules: []
|
||||
lora_target_linear_modules:
|
||||
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||
# using regex to match only the text decoder attention projections.
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
|
||||
lora_target_parameters:
|
||||
@@ -73,7 +63,7 @@ lora_o_kernel: false
|
||||
bnb_config_kwargs:
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
wandb_project: gemma4-qlora
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
@@ -93,8 +83,7 @@ gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
|
||||
flash_attention: false
|
||||
# FA2 not supported
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
71
examples/gemma4/31b-qlora-flex.yaml
Normal file
71
examples/gemma4/31b-qlora-flex.yaml
Normal file
@@ -0,0 +1,71 @@
|
||||
base_model: google/gemma-4-31B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
torch_compile: true
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
strict: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:10%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/gemma4-31b-qlora-flex
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
bnb_config_kwargs:
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA not supported
|
||||
flex_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
69
examples/gemma4/31b-qlora.yaml
Normal file
69
examples/gemma4/31b-qlora.yaml
Normal file
@@ -0,0 +1,69 @@
|
||||
base_model: google/gemma-4-31B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
torch_compile: false
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
strict: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:10%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/gemma4-31b-qlora
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
|
||||
# Restrict LoRA to text backbone only (skip vision/audio encoders)
|
||||
# using regex to match only the text decoder attention projections.
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
bnb_config_kwargs:
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA not supported
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
60
examples/gemma4/README.md
Normal file
60
examples/gemma4/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Finetune Google's Gemma 4 with Axolotl
|
||||
|
||||
[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
|
||||
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
|
||||
|
||||
# 31B Dense QLoRA (1x80GB @ ~44 GiB)
|
||||
axolotl train examples/gemma4/31b-qlora.yaml
|
||||
|
||||
# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
|
||||
axolotl train examples/gemma4/31b-qlora-flex.yaml
|
||||
```
|
||||
|
||||
### MoE Expert Quantization & Expert LoRA (26B-A4B only)
|
||||
|
||||
The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
|
||||
|
||||
## Flex Attention
|
||||
|
||||
Reduce VRAM usage by ~40% (at the cost of up to half the throughput) by setting the options below (as shown in `examples/gemma4/31b-qlora-flex.yaml`):
|
||||
|
||||
```yaml
|
||||
torch_compile: true
|
||||
flex_attention: true
|
||||
```
|
||||
|
||||
This works for both the MoE and dense models.
|
||||
|
||||
## Limitations
|
||||
|
||||
- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
|
||||
- **LoRA kernels**: Not supported due to KV-sharing layers.
|
||||
- **lora_target_linear**: Incompatible for multimodal models — use `lora_target_modules` with a regex to restrict LoRA to the text backbone.
|
||||
|
||||
### Tips
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy and has not been tested.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
62
examples/gemma4/e2b-vision-lora.yaml
Normal file
62
examples/gemma4/e2b-vision-lora.yaml
Normal file
@@ -0,0 +1,62 @@
|
||||
# Gemma 4 E2B Vision LoRA
|
||||
#
|
||||
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
|
||||
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
|
||||
|
||||
base_model: google/gemma-4-E2B-it
|
||||
processor_type: AutoProcessor
|
||||
freeze_mm_modules: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
# Required for vision/multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: gemma4
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:100]
|
||||
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gemma4-e2b-vision-lora
|
||||
|
||||
adapter: lora
|
||||
sequence_len: 2048
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
# Target language model only — vision encoder is frozen via freeze_mm_modules
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
max_steps: 10
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
@@ -14,8 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
||||
@@ -87,7 +86,7 @@ for more information about using a special vllm-openai docker image for inferenc
|
||||
Optionally, vLLM can be installed from nightly:
|
||||
|
||||
```bash
|
||||
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||
uv pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||
```
|
||||
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
|
||||
```bash
|
||||
|
||||
@@ -15,8 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -13,8 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl.
|
||||
2. Install `timm` for vision model support:
|
||||
|
||||
```bash
|
||||
pip install timm==1.0.19
|
||||
uv pip install timm==1.0.19
|
||||
```
|
||||
|
||||
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
@@ -14,8 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
uv pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
|
||||
@@ -23,7 +23,7 @@ Note: This is still experimental given it is based on transformers v5 RC.
|
||||
git checkout transformers-v5
|
||||
|
||||
# Install packages for transformers v5
|
||||
pip install -e .
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
4. Run the fine-tuning:
|
||||
|
||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.6'
|
||||
uv pip install 'mistral-common[opencv]==1.8.6'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
|
||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
uv pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
|
||||
@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
3. Install transformers from main
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git
|
||||
uv pip install git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
4. Run one of the example configs:
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
|
||||
# LoRA kernel patches are incompatible with this architecture — see README.
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
@@ -22,8 +32,6 @@ dataset_prepared_path: last_run_prepared
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
use_cut_cross_entropy: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
@@ -31,16 +39,16 @@ lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
# Attention projection layers (present in ~12 attention layers out of 88)
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
# To also train MoE expert weights, add them via lora_target_parameters
|
||||
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
||||
# lora_target_parameters:
|
||||
# - up_proj
|
||||
# - down_proj
|
||||
|
||||
# To also train MoE expert weights, add them via lora_target_parameters
|
||||
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
||||
# lora_target_parameters:
|
||||
# - up_proj
|
||||
# - down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
# See examples/nemotron-h/README.md for architecture notes and requirements.
|
||||
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_layer_norm: true
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm_gated: true
|
||||
|
||||
# LoRA kernel patches are incompatible with this architecture — see README.
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
@@ -23,8 +33,6 @@ dataset_prepared_path: last_run_prepared
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
use_cut_cross_entropy: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
@@ -36,11 +44,12 @@ lora_target_modules:
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
# To also train MoE expert weights, add them via lora_target_parameters
|
||||
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
||||
# lora_target_parameters:
|
||||
# - up_proj
|
||||
# - down_proj
|
||||
|
||||
# To also train MoE expert weights, add them via lora_target_parameters
|
||||
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
|
||||
# lora_target_parameters:
|
||||
# - up_proj
|
||||
# - down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -12,7 +12,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
@@ -26,8 +26,8 @@ sample_packing: true
|
||||
|
||||
# Freeze vision encoder
|
||||
unfrozen_parameters:
|
||||
- model\.language_model\..*
|
||||
- lm_head\..*
|
||||
- model.language_model.*
|
||||
- lm_head.*
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
62
examples/qwen3.5/35b-a3b-moe-vision-lora.yaml
Normal file
62
examples/qwen3.5/35b-a3b-moe-vision-lora.yaml
Normal file
@@ -0,0 +1,62 @@
|
||||
# Qwen 3.5 35B-A3B MoE Vision LoRA
|
||||
#
|
||||
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
|
||||
# 256 experts, 8 active per token, with early-fusion vision support.
|
||||
|
||||
base_model: Qwen/Qwen3.5-35B-A3B
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Required for vision/multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:100]
|
||||
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/qwen35-35b-a3b-vision-lora
|
||||
|
||||
adapter: lora
|
||||
sequence_len: 4096
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
max_steps: 10
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
|
||||
```
|
||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||
|
||||
|
||||
@@ -11,8 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
|
||||
# Install Cut Cross Entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -13,14 +13,13 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install an extra dependency:
|
||||
|
||||
```bash
|
||||
pip3 install num2words==0.5.14
|
||||
uv pip install num2words==0.5.14
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
@@ -12,16 +12,15 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install the additional dependencies below:
|
||||
|
||||
```bash
|
||||
# audio
|
||||
pip3 install librosa==0.11.0
|
||||
pip3 install 'mistral_common[audio]==1.8.3'
|
||||
uv pip install librosa==0.11.0
|
||||
uv pip install 'mistral_common[audio]==1.8.3'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
203
pyproject.toml
203
pyproject.toml
@@ -1,15 +1,165 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "axolotl"
|
||||
dynamic = ["version", "dependencies", "optional-dependencies"]
|
||||
dynamic = ["version"]
|
||||
description = "LLM Trainer"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
# license = "Apache-2.0"
|
||||
|
||||
dependencies = [
|
||||
# Core ML stack
|
||||
"torch>=2.6.0",
|
||||
"packaging==26.0",
|
||||
"huggingface_hub>=1.1.7",
|
||||
"peft>=0.19.1,<0.20.0",
|
||||
"tokenizers>=0.22.1",
|
||||
"transformers==5.5.4",
|
||||
"accelerate==1.13.0",
|
||||
"datasets>=4.8.4,<4.9.0",
|
||||
"trl==1.1.0",
|
||||
"hf_xet==1.4.3",
|
||||
"kernels==0.13.0",
|
||||
"trackio>=0.16.1",
|
||||
"typing-extensions>=4.15.0",
|
||||
"optimum==1.16.2",
|
||||
"hf_transfer",
|
||||
"sentencepiece",
|
||||
"gradio>=6.2.0,<7.0",
|
||||
"modal==1.3.0.post1",
|
||||
"pydantic>=2.10.6",
|
||||
"addict",
|
||||
"fire",
|
||||
"PyYAML>=6.0",
|
||||
"requests",
|
||||
"wandb",
|
||||
"einops",
|
||||
"colorama",
|
||||
"numba>=0.61.2",
|
||||
"numpy>=2.2.6",
|
||||
|
||||
# Evaluation & metrics
|
||||
"evaluate==0.4.1",
|
||||
"scipy",
|
||||
"nvidia-ml-py==12.560.30",
|
||||
"art",
|
||||
"tensorboard",
|
||||
"python-dotenv==1.0.1",
|
||||
|
||||
# Remote filesystems
|
||||
"s3fs>=2024.5.0",
|
||||
"gcsfs>=2025.3.0",
|
||||
"adlfs>=2024.5.0",
|
||||
"ocifs==1.3.2",
|
||||
|
||||
"zstandard==0.22.0",
|
||||
"fastcore",
|
||||
|
||||
# lm eval harness
|
||||
"lm_eval==0.4.11",
|
||||
"langdetect==1.0.9",
|
||||
"immutabledict==4.2.0",
|
||||
"antlr4-python3-runtime==4.13.2",
|
||||
|
||||
"schedulefree==1.4.1",
|
||||
"openenv-core==0.1.0",
|
||||
|
||||
# Axolotl contribs
|
||||
"axolotl-contribs-lgpl==0.0.7",
|
||||
"axolotl-contribs-mit==0.0.6",
|
||||
|
||||
# Telemetry
|
||||
"posthog==6.7.11",
|
||||
|
||||
"mistral-common==1.11.0",
|
||||
|
||||
# Platform-specific (Linux only)
|
||||
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
|
||||
"triton>=3.4.0 ; sys_platform != 'darwin'",
|
||||
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
|
||||
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
|
||||
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
||||
|
||||
# Architecture-specific
|
||||
"fla-core==0.4.1 ; platform_machine != 'aarch64'",
|
||||
"flash-linear-attention==0.4.1 ; platform_machine != 'aarch64'",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
flash-attn = ["flash-attn==2.8.3"]
|
||||
ring-flash-attn = [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
]
|
||||
deepspeed = [
|
||||
"deepspeed>=0.18.6,<0.19.0",
|
||||
"deepspeed-kernels",
|
||||
]
|
||||
mamba-ssm = [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
]
|
||||
auto-gptq = [
|
||||
"auto-gptq==0.5.1",
|
||||
]
|
||||
mlflow = [
|
||||
"mlflow",
|
||||
]
|
||||
galore = [
|
||||
"galore_torch",
|
||||
]
|
||||
apollo = [
|
||||
"apollo-torch",
|
||||
]
|
||||
optimizers = [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
"came_pytorch==0.1.3",
|
||||
]
|
||||
ray = [
|
||||
"ray[train]>=2.52.1",
|
||||
]
|
||||
vllm = [
|
||||
"vllm>=0.15.0",
|
||||
]
|
||||
llmcompressor = [
|
||||
"llmcompressor>=0.10.0",
|
||||
]
|
||||
fbgemm-gpu = ["fbgemm-gpu-genai>=1.3.0"]
|
||||
opentelemetry = [
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-prometheus",
|
||||
"prometheus-client",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"black",
|
||||
"mypy",
|
||||
"pre-commit",
|
||||
"types-requests",
|
||||
"quartodoc",
|
||||
"jupyter",
|
||||
"blobfile",
|
||||
"tiktoken",
|
||||
]
|
||||
test = [
|
||||
"codecov",
|
||||
"codecov-cli",
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"pytest-retry",
|
||||
"pytest-sugar",
|
||||
"pytest-xdist",
|
||||
"tbparse",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
axolotl = "axolotl.cli.main:main"
|
||||
|
||||
@@ -18,18 +168,15 @@ Homepage = "https://axolotl.ai/"
|
||||
Documentation = "https://docs.axolotl.ai/"
|
||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = { file = "VERSION" }
|
||||
|
||||
[tool.setuptools.cmdclass]
|
||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py310"
|
||||
@@ -67,5 +214,43 @@ markers = [
|
||||
"slow: marks tests as slow",
|
||||
]
|
||||
|
||||
# UV specific configuration
|
||||
[tool.uv]
|
||||
prerelease = "allow"
|
||||
conflicts = [
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "vllm" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "flash-attn" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "ring-flash-attn" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "mamba-ssm" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "auto-gptq" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "fbgemm-gpu" },
|
||||
],
|
||||
[
|
||||
{ package = "axolotl" },
|
||||
{ extra = "llmcompressor" },
|
||||
],
|
||||
]
|
||||
|
||||
[tool.uv.extra-build-dependencies]
|
||||
axolotl = ["huggingface_hub"]
|
||||
mamba-ssm = [{ requirement = "torch", match-runtime = true }]
|
||||
causal-conv1d = [{ requirement = "torch", match-runtime = true }]
|
||||
flash-attn = [{ requirement = "torch", match-runtime = true }]
|
||||
deepspeed = [{ requirement = "torch", match-runtime = true }]
|
||||
auto-gptq = [{ requirement = "torch", match-runtime = true }]
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
black
|
||||
mypy
|
||||
pre-commit
|
||||
types-requests
|
||||
quartodoc
|
||||
jupyter
|
||||
blobfile
|
||||
tiktoken
|
||||
@@ -1,8 +0,0 @@
|
||||
codecov
|
||||
codecov-cli
|
||||
pytest
|
||||
pytest-cov
|
||||
pytest-retry
|
||||
pytest-sugar
|
||||
pytest-xdist
|
||||
tbparse
|
||||
@@ -1,78 +0,0 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.49.1
|
||||
triton>=3.4.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
liger-kernel==0.7.0
|
||||
# END section
|
||||
|
||||
packaging==26.0
|
||||
huggingface_hub>=1.1.7
|
||||
peft>=0.18.1
|
||||
tokenizers>=0.22.1
|
||||
transformers==5.5.0
|
||||
accelerate==1.13.0
|
||||
datasets==4.5.0
|
||||
deepspeed>=0.18.6,<0.19.0
|
||||
trl==0.29.0
|
||||
hf_xet==1.3.2
|
||||
kernels==0.12.2
|
||||
|
||||
fla-core==0.4.1
|
||||
flash-linear-attention==0.4.1
|
||||
|
||||
trackio>=0.16.1
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio>=6.2.0,<7.0
|
||||
|
||||
modal==1.3.0.post1
|
||||
pydantic>=2.10.6
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
wandb
|
||||
einops
|
||||
colorama
|
||||
numba>=0.61.2
|
||||
numpy>=2.2.6
|
||||
|
||||
# qlora things
|
||||
evaluate==0.4.1
|
||||
scipy
|
||||
nvidia-ml-py==12.560.30
|
||||
art
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# remote filesystems
|
||||
s3fs>=2024.5.0
|
||||
gcsfs>=2025.3.0
|
||||
adlfs>=2024.5.0
|
||||
ocifs==1.3.2
|
||||
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
# lm eval harness
|
||||
lm_eval==0.4.11
|
||||
langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.17.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.7
|
||||
axolotl-contribs-mit==0.0.6
|
||||
# telemetry
|
||||
posthog==6.7.11
|
||||
|
||||
mistral-common==1.11.0
|
||||
1518
scripts/analyze_profile.py
Normal file
File diff suppressed because it is too large
479
scripts/build_scattermoe_lora_kernel.py
Normal file
@@ -0,0 +1,479 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build a disposable Hugging Face Kernel Hub package for ScatterMoE LoRA.
|
||||
|
||||
This script does not move or edit the in-tree Axolotl kernel sources. It copies
|
||||
``src/axolotl/integrations/kernels/libs/scattermoe_lora`` into an ignored
|
||||
build directory and emits a universal HF kernels project that can be pushed to
|
||||
the Hub.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import fnmatch
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from importlib import metadata
|
||||
from pathlib import Path
|
||||
|
||||
PACKAGE_NAME = "scattermoe_lora"
|
||||
BUILD_VARIANT = "torch-universal"
|
||||
DEFAULT_REPO_ID = "kernels-community/scattermoe-lora"
|
||||
HF_REPO_TYPE = "kernel"
|
||||
HF_KERNEL_URL_PREFIX = "https://hf.co/kernels"
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
DEFAULT_SOURCE_DIR = (
|
||||
REPO_ROOT / "src" / "axolotl" / "integrations" / "kernels" / "libs" / PACKAGE_NAME
|
||||
)
|
||||
DEFAULT_OUTPUT_DIR = REPO_ROOT / "build" / "hf-kernels" / PACKAGE_NAME
|
||||
|
||||
EXCLUDED_DIRS = {
|
||||
"__pycache__",
|
||||
".mypy_cache",
|
||||
".pytest_cache",
|
||||
".ruff_cache",
|
||||
}
|
||||
EXCLUDED_FILE_PATTERNS = {
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
"*.so",
|
||||
".DS_Store",
|
||||
}
|
||||
|
||||
TEXT_REPLACEMENTS = {
|
||||
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import": (
|
||||
"from .selective_dequant import"
|
||||
),
|
||||
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import": (
|
||||
"from .selective_dequant_kernel import"
|
||||
),
|
||||
"from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import": (
|
||||
"from .ops import"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Copy Axolotl's ScatterMoE LoRA Triton kernels into a disposable "
|
||||
"HF Kernel Hub universal package."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source-dir",
|
||||
type=Path,
|
||||
default=DEFAULT_SOURCE_DIR,
|
||||
help=f"ScatterMoE LoRA source package to copy. Default: {DEFAULT_SOURCE_DIR}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=DEFAULT_OUTPUT_DIR,
|
||||
help=f"Destination build/dist directory. Default: {DEFAULT_OUTPUT_DIR}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
default=DEFAULT_REPO_ID,
|
||||
help=f"HF Hub repo id to write into build.toml. Default: {DEFAULT_REPO_ID}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Kernel major version written to build.toml and metadata.json.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Delete the output directory first if it already exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-source-layout",
|
||||
action="store_true",
|
||||
help="Only write the shippable build/ tree, not torch-ext/ sources.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Upload the generated universal kernel package with huggingface_hub. "
|
||||
"This bypasses kernel-builder and is intended for pure Python/Triton "
|
||||
"universal kernels."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Create the HF Hub repo as private when used with --upload.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-version-branch",
|
||||
action="store_true",
|
||||
help="With --upload, only upload main and skip the v<version> branch.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def should_skip_file(path: Path) -> bool:
|
||||
return any(
|
||||
fnmatch.fnmatch(path.name, pattern) for pattern in EXCLUDED_FILE_PATTERNS
|
||||
)
|
||||
|
||||
|
||||
def iter_source_files(source_dir: Path) -> list[Path]:
|
||||
files: list[Path] = []
|
||||
for root, dirs, filenames in os.walk(source_dir):
|
||||
dirs[:] = sorted(d for d in dirs if d not in EXCLUDED_DIRS)
|
||||
for filename in sorted(filenames):
|
||||
path = Path(root) / filename
|
||||
if not should_skip_file(path):
|
||||
files.append(path)
|
||||
return files
|
||||
|
||||
|
||||
def content_hash(source_dir: Path) -> str:
|
||||
digest = hashlib.sha1()
|
||||
for path in iter_source_files(source_dir):
|
||||
rel = path.relative_to(source_dir).as_posix()
|
||||
digest.update(rel.encode("utf-8"))
|
||||
digest.update(b"\0")
|
||||
digest.update(path.read_bytes())
|
||||
digest.update(b"\0")
|
||||
return digest.hexdigest()[:10]
|
||||
|
||||
|
||||
def git_revision() -> str:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
cwd=REPO_ROOT,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except (OSError, subprocess.CalledProcessError):
|
||||
return "unknown"
|
||||
return result.stdout.strip() or "unknown"
|
||||
|
||||
|
||||
def transform_python_source(text: str, rel_path: Path, op_namespace: str) -> str:
|
||||
for old, new in TEXT_REPLACEMENTS.items():
|
||||
text = text.replace(old, new)
|
||||
|
||||
if rel_path.as_posix() == "gemma4_experts.py":
|
||||
text = text.replace(
|
||||
" from axolotl.integrations.kernels.constants import resolve_experts_class",
|
||||
(
|
||||
" raise RuntimeError(\n"
|
||||
' "patch_gemma4_scattermoe is only available from the in-tree Axolotl "\n'
|
||||
' "integration. Use register_scattermoe_experts() with the standalone "\n'
|
||||
' "HF kernel package."\n'
|
||||
" )"
|
||||
),
|
||||
)
|
||||
|
||||
return text.replace("scattermoe::", f"{op_namespace}::")
|
||||
|
||||
|
||||
def copy_package(source_dir: Path, package_dir: Path, op_namespace: str) -> None:
|
||||
for source in iter_source_files(source_dir):
|
||||
rel_path = source.relative_to(source_dir)
|
||||
destination = package_dir / rel_path
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if source.suffix == ".py":
|
||||
text = source.read_text(encoding="utf-8")
|
||||
text = transform_python_source(text, rel_path, op_namespace)
|
||||
destination.write_text(text, encoding="utf-8")
|
||||
else:
|
||||
shutil.copy2(source, destination)
|
||||
|
||||
write_ops_module(package_dir / "_ops.py", op_namespace)
|
||||
|
||||
|
||||
def write_ops_module(path: Path, op_namespace: str) -> None:
|
||||
path.write_text(
|
||||
"\n".join(
|
||||
[
|
||||
"import torch",
|
||||
"",
|
||||
f"ops = torch.ops.{op_namespace}",
|
||||
"",
|
||||
"",
|
||||
"def add_op_namespace_prefix(op_name: str) -> str:",
|
||||
f' return f"{op_namespace}::{{op_name}}"',
|
||||
"",
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def write_build_toml(path: Path, repo_id: str, version: int) -> None:
|
||||
lines = [
|
||||
"[general]",
|
||||
f'name = "{PACKAGE_NAME}"',
|
||||
"universal = true",
|
||||
f"version = {version}",
|
||||
"",
|
||||
]
|
||||
if repo_id:
|
||||
lines.extend(
|
||||
[
|
||||
"[general.hub]",
|
||||
f'repo-id = "{repo_id}"',
|
||||
"",
|
||||
]
|
||||
)
|
||||
path.write_text("\n".join(lines), encoding="utf-8")
|
||||
|
||||
|
||||
def write_flake(path: Path) -> None:
|
||||
path.write_text(
|
||||
"""{
|
||||
description = "Flake for scattermoe_lora kernel";
|
||||
|
||||
inputs = {
|
||||
builder.url = "github:huggingface/kernels";
|
||||
};
|
||||
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
builder,
|
||||
}:
|
||||
builder.lib.genKernelFlakeOutputs {
|
||||
inherit self;
|
||||
path = ./.;
|
||||
};
|
||||
}
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def write_readme(path: Path, repo_id: str, source_hash: str, op_namespace: str) -> None:
|
||||
repo_display = repo_id or "<your-org>/scattermoe-lora"
|
||||
path.write_text(
|
||||
f"""---
|
||||
library_name: kernels
|
||||
license: apache-2.0
|
||||
tags:
|
||||
- kernel
|
||||
- kernels
|
||||
---
|
||||
|
||||
# ScatterMoE LoRA
|
||||
|
||||
Standalone Hugging Face Kernel Hub package for Axolotl's ScatterMoE LoRA Triton kernels.
|
||||
|
||||
This package is generated from Axolotl's in-tree `scattermoe_lora` sources and is exported as a universal kernel because the implementation is Python/Triton rather than a precompiled C++/CUDA extension.
|
||||
|
||||
```python
|
||||
from kernels import get_kernel
|
||||
|
||||
scattermoe_lora = get_kernel("{repo_display}")
|
||||
```
|
||||
|
||||
Export metadata:
|
||||
|
||||
- source package: `src/axolotl/integrations/kernels/libs/scattermoe_lora`
|
||||
- source revision: `{git_revision()}`
|
||||
- source content hash: `{source_hash}`
|
||||
- torch custom op namespace: `{op_namespace}`
|
||||
|
||||
The generated `build/torch-universal/{PACKAGE_NAME}` directory is the shippable Hub artifact. `torch-ext/{PACKAGE_NAME}` is included so `kernel-builder build-and-copy` can regenerate the universal build tree if desired.
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def write_metadata(path: Path, version: int) -> None:
|
||||
path.write_text(
|
||||
json.dumps({"version": version}, indent=2, sort_keys=True) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def prepare_output_dir(output_dir: Path, force: bool) -> None:
|
||||
if output_dir.exists():
|
||||
if not force:
|
||||
raise FileExistsError(
|
||||
f"{output_dir} already exists. Re-run with --force to replace it."
|
||||
)
|
||||
shutil.rmtree(output_dir)
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
|
||||
def build_package(args: argparse.Namespace) -> Path:
|
||||
source_dir = args.source_dir.resolve()
|
||||
output_dir = args.output_dir.resolve()
|
||||
|
||||
if not source_dir.is_dir():
|
||||
raise FileNotFoundError(f"source package does not exist: {source_dir}")
|
||||
if not (source_dir / "__init__.py").is_file():
|
||||
raise FileNotFoundError(f"source package is missing __init__.py: {source_dir}")
|
||||
|
||||
source_hash = content_hash(source_dir)
|
||||
op_namespace = f"_{PACKAGE_NAME}_{source_hash}"
|
||||
|
||||
prepare_output_dir(output_dir, args.force)
|
||||
|
||||
write_build_toml(output_dir / "build.toml", args.repo_id, args.version)
|
||||
write_flake(output_dir / "flake.nix")
|
||||
write_readme(output_dir / "README.md", args.repo_id, source_hash, op_namespace)
|
||||
|
||||
if not args.no_source_layout:
|
||||
copy_package(source_dir, output_dir / "torch-ext" / PACKAGE_NAME, op_namespace)
|
||||
|
||||
build_package_dir = output_dir / "build" / BUILD_VARIANT / PACKAGE_NAME
|
||||
copy_package(source_dir, build_package_dir, op_namespace)
|
||||
write_metadata(build_package_dir.parent / "metadata.json", args.version)
|
||||
|
||||
return output_dir
|
||||
|
||||
|
||||
def upload_package(args: argparse.Namespace, output_dir: Path) -> None:
|
||||
if not args.repo_id:
|
||||
raise ValueError("--repo-id is required when using --upload")
|
||||
|
||||
try:
|
||||
from huggingface_hub import HfApi, constants as hf_constants
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"--upload requires huggingface_hub. Install it or run the upload "
|
||||
"manually with the Hugging Face CLI."
|
||||
) from exc
|
||||
|
||||
try:
|
||||
hub_version = metadata.version("huggingface_hub")
|
||||
except metadata.PackageNotFoundError:
|
||||
hub_version = "unknown"
|
||||
|
||||
accepted_repo_types = getattr(
|
||||
hf_constants,
|
||||
"REPO_TYPES_WITH_KERNEL",
|
||||
getattr(hf_constants, "REPO_TYPES", ()),
|
||||
)
|
||||
if HF_REPO_TYPE not in accepted_repo_types:
|
||||
raise RuntimeError(
|
||||
"Your huggingface_hub installation does not support "
|
||||
f"repo_type={HF_REPO_TYPE!r} (found huggingface_hub {hub_version}). "
|
||||
f"Upgrade this interpreter with: {sys.executable} -m pip install --upgrade "
|
||||
"'huggingface_hub>=1.10.0'"
|
||||
)
|
||||
|
||||
# huggingface_hub 1.11.0 has partial kernel support: create_repo accepts
|
||||
# "kernel", but upload_folder/create_commit still validate against the
|
||||
# older REPO_TYPES list. Extend it in-process so those helpers use the
|
||||
# /api/kernels/... endpoints until upstream broadens that check.
|
||||
if HF_REPO_TYPE not in hf_constants.REPO_TYPES:
|
||||
hf_constants.REPO_TYPES.append(HF_REPO_TYPE)
|
||||
|
||||
api = HfApi()
|
||||
try:
|
||||
repo_id = api.create_repo(
|
||||
repo_id=args.repo_id,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
private=args.private,
|
||||
exist_ok=True,
|
||||
).repo_id
|
||||
except ValueError as exc:
|
||||
if "Invalid repo type" in str(exc):
|
||||
raise RuntimeError(
|
||||
"huggingface_hub rejected repo_type='kernel'. "
|
||||
f"This usually means the command is running with an older Hub "
|
||||
f"client than expected (found huggingface_hub {hub_version} at "
|
||||
f"{sys.executable}). Upgrade with: {sys.executable} -m pip "
|
||||
"install --upgrade 'huggingface_hub>=1.10.0'"
|
||||
) from exc
|
||||
raise
|
||||
|
||||
delete_patterns = [
|
||||
"build/**",
|
||||
"torch-ext/**",
|
||||
"build.toml",
|
||||
"flake.nix",
|
||||
"README.md",
|
||||
]
|
||||
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
folder_path=output_dir,
|
||||
revision="main",
|
||||
delete_patterns=delete_patterns,
|
||||
commit_message="Upload ScatterMoE LoRA universal kernel",
|
||||
)
|
||||
print(f"Uploaded main branch: {HF_KERNEL_URL_PREFIX}/{repo_id}")
|
||||
|
||||
if args.skip_version_branch:
|
||||
return
|
||||
|
||||
version_branch = f"v{args.version}"
|
||||
api.create_branch(
|
||||
repo_id=repo_id,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
branch=version_branch,
|
||||
revision="main",
|
||||
exist_ok=True,
|
||||
)
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=HF_REPO_TYPE,
|
||||
folder_path=output_dir,
|
||||
revision=version_branch,
|
||||
delete_patterns=delete_patterns,
|
||||
commit_message=f"Upload ScatterMoE LoRA universal kernel {version_branch}",
|
||||
)
|
||||
print(
|
||||
f"Uploaded version branch: "
|
||||
f"{HF_KERNEL_URL_PREFIX}/{repo_id}/tree/{version_branch}"
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
try:
|
||||
output_dir = build_package(args)
|
||||
if args.upload:
|
||||
upload_package(args, output_dir)
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"Wrote ScatterMoE LoRA HF kernel package to: {output_dir}")
|
||||
print(f"Shippable artifact: {output_dir / 'build' / BUILD_VARIANT / PACKAGE_NAME}")
|
||||
if args.upload:
|
||||
print(f'Load it with: get_kernel("{args.repo_id}", version={args.version})')
|
||||
print(f"Uploaded as Hugging Face repo_type={HF_REPO_TYPE!r}.")
|
||||
return 0
|
||||
|
||||
print("Next step:")
|
||||
print(" upload this universal Python/Triton kernel directly:")
|
||||
print(
|
||||
f" python3 {Path(__file__).as_posix()} "
|
||||
f"--repo-id {args.repo_id} --force --upload"
|
||||
)
|
||||
if shutil.which("kernel-builder") is None:
|
||||
print(" optional: install kernel-builder for full Nix-based builds:")
|
||||
print(
|
||||
" curl -fsSL "
|
||||
"https://raw.githubusercontent.com/huggingface/kernels/main/install.sh "
|
||||
"| bash"
|
||||
)
|
||||
else:
|
||||
print(" optional: upload with kernel-builder:")
|
||||
print(f" cd {output_dir}")
|
||||
print(" kernel-builder build-and-upload")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
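A minimal sketch of consuming the uploaded artifact, mirroring the hint the script prints after `--upload`; the repo id and version below are the script's defaults, so adjust them if you push elsewhere:

```python
# Load the universal ScatterMoE LoRA kernel from the Hub via the kernels library.
from kernels import get_kernel

# Default repo id / version from the script above; treat as placeholders.
scattermoe_lora = get_kernel("kernels-community/scattermoe-lora", version=1)
```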
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"'
|
||||
)
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
# noqa
|
||||
import sys
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError as error:
|
||||
raise ImportError("Install torch via `pip install torch`") from error
|
||||
from packaging.version import Version as V
|
||||
|
||||
use_uv = "--uv" in sys.argv[1:]
|
||||
|
||||
v = V(torch.__version__)
|
||||
cuda = str(torch.version.cuda)
|
||||
try:
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
except RuntimeError:
|
||||
is_ampere = False
|
||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||
if v <= V("2.1.0"):
|
||||
raise RuntimeError(f"Torch = {v} too old!")
|
||||
elif v <= V("2.1.1"):
|
||||
x = "cu{}{}-torch211"
|
||||
elif v <= V("2.1.2"):
|
||||
x = "cu{}{}-torch212"
|
||||
elif v < V("2.3.0"):
|
||||
x = "cu{}{}-torch220"
|
||||
elif v < V("2.4.0"):
|
||||
x = "cu{}{}-torch230"
|
||||
elif v < V("2.5.0"):
|
||||
x = "cu{}{}-torch240"
|
||||
elif v < V("2.6.0"):
|
||||
x = "cu{}{}-torch250"
|
||||
else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
uv_prefix = "uv " if use_uv else ""
|
||||
print(
|
||||
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
|
||||
)
|
||||
230
setup.py
@@ -1,230 +0,0 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements(extras_require_map):
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
_dependency_links.append(url)
|
||||
elif not is_extras and line and line[0] != "#":
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
install_xformers = platform.machine() != "aarch64"
|
||||
if platform.machine() == "aarch64":
|
||||
# skip on ARM64
|
||||
skip_packages = [
|
||||
"torchao",
|
||||
"fla-core",
|
||||
"flash-linear-attention",
|
||||
]
|
||||
_install_requires = [
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
if "Darwin" in platform.system():
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
"bitsandbytes",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"xformers",
|
||||
"liger-kernel",
|
||||
]
|
||||
_install_requires = [
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
print(
|
||||
_install_requires, [req in skip_packages for req in _install_requires]
|
||||
)
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.8.0" # default to torch 2.8.0
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
torch_parts = torch_version.split("+")
|
||||
if len(torch_parts) == 2:
|
||||
torch_cuda_version = torch_parts[1]
|
||||
_dependency_links.append(
|
||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||
)
|
||||
|
||||
if (major, minor) >= (2, 10):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.5.0",
|
||||
"fbgemm-gpu-genai==1.5.0",
|
||||
]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm>=0.17.1"]
|
||||
elif (major, minor) >= (2, 9):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.4.0",
|
||||
"fbgemm-gpu-genai==1.4.2",
|
||||
]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
else:
|
||||
extras_require_map["vllm"] = ["vllm==0.14.0"]
|
||||
elif (major, minor) >= (2, 8):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
elif (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.30")
|
||||
# vllm 0.9.x is incompatible with latest transformers
|
||||
extras_require_map.pop("vllm")
|
||||
else:
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.31")
|
||||
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
# since we only support 2.6.0+cu126
|
||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if install_xformers:
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 4):
|
||||
extras_require_map.pop("vllm")
|
||||
if install_xformers:
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
else:
|
||||
raise ValueError("axolotl requires torch>=2.4")
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links, extras_require_map
|
||||
|
||||
|
||||
def get_package_version():
|
||||
with open(
|
||||
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
|
||||
"r",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
version_ = fin.read().strip()
|
||||
return version_
|
||||
|
||||
|
||||
extras_require = {
|
||||
"flash-attn": ["flash-attn==2.8.3"],
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.18.2",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"apollo": [
|
||||
"apollo-torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
"came_pytorch==0.1.3",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]>=2.52.1",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.10.0",
|
||||
],
|
||||
"llmcompressor": [
|
||||
"llmcompressor==0.5.1",
|
||||
],
|
||||
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
|
||||
"opentelemetry": [
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-prometheus",
|
||||
"prometheus-client",
|
||||
],
|
||||
}
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
)
|
||||
|
||||
setup(
|
||||
version=get_package_version(),
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
install_requires=install_requires,
|
||||
dependency_links=dependency_links,
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"axolotl=axolotl.cli.main:main",
|
||||
],
|
||||
},
|
||||
extras_require=extras_require_build,
|
||||
)
|
||||
108
src/axolotl/cli/agent_docs/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Bundled agent documentation for axolotl.
|
||||
|
||||
These docs are optimized for consumption by AI coding agents.
|
||||
The source of truth is docs/agents/*.md and AGENTS.md in the repo root.
|
||||
This module resolves those paths at runtime — no files are duplicated
|
||||
into the package.
|
||||
|
||||
For pip-only installs (no repo checkout), run `axolotl fetch docs` first
|
||||
to download the docs locally.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# Topic name -> path relative to the repo root (AGENTS.md or docs/agents/*.md)
|
||||
TOPICS = {
|
||||
"overview": "AGENTS.md",
|
||||
"sft": "docs/agents/sft.md",
|
||||
"grpo": "docs/agents/grpo.md",
|
||||
"preference_tuning": "docs/agents/preference_tuning.md",
|
||||
"reward_modelling": "docs/agents/reward_modelling.md",
|
||||
"pretraining": "docs/agents/pretraining.md",
|
||||
"model_architectures": "docs/agents/model_architectures.md",
|
||||
"new_model_support": "docs/agents/new_model_support.md",
|
||||
}
|
||||
|
||||
|
||||
def _find_repo_root() -> Path | None:
|
||||
"""Walk up from this file to find the repo root (contains AGENTS.md)."""
|
||||
# In an editable install or repo checkout, walk up from
|
||||
# src/axolotl/cli/agent_docs/ to find the repo root
|
||||
current = Path(__file__).resolve().parent
|
||||
while current != current.parent:
|
||||
if (current / "AGENTS.md").exists() and (current / "docs" / "agents").is_dir():
|
||||
return current
|
||||
current = current.parent
|
||||
return None
|
||||
|
||||
|
||||
def _find_docs_dir() -> Path | None:
|
||||
"""Find a fetched docs directory (from `axolotl fetch docs`)."""
|
||||
# axolotl fetch docs --dest defaults to ./docs/ in cwd
|
||||
cwd_docs = Path.cwd() / "docs" / "agents"
|
||||
if cwd_docs.is_dir():
|
||||
return Path.cwd()
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_path(topic: str) -> Path:
|
||||
"""Resolve a topic name to the actual file path."""
|
||||
if topic not in TOPICS:
|
||||
available = ", ".join(sorted(TOPICS.keys()))
|
||||
raise FileNotFoundError(f"Unknown topic: {topic!r}. Available: {available}")
|
||||
|
||||
relative_path = TOPICS[topic]
|
||||
|
||||
# Try repo root first (editable install / repo checkout)
|
||||
repo_root = _find_repo_root()
|
||||
if repo_root:
|
||||
candidate = repo_root / relative_path
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
|
||||
# Try cwd (fetched docs via `axolotl fetch docs`)
|
||||
docs_root = _find_docs_dir()
|
||||
if docs_root:
|
||||
candidate = docs_root / relative_path
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
|
||||
# Also check cwd directly for AGENTS.md
|
||||
if topic == "overview":
|
||||
cwd_agents = Path.cwd() / "AGENTS.md"
|
||||
if cwd_agents.exists():
|
||||
return cwd_agents
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find {relative_path!r}. "
|
||||
f"If you installed axolotl via pip, run `axolotl fetch docs` first "
|
||||
f"to download the documentation."
|
||||
)
|
||||
|
||||
|
||||
def get_doc(topic: str = "overview") -> str:
|
||||
"""Return the content of an agent doc by topic name.
|
||||
|
||||
Args:
|
||||
topic: One of the keys in TOPICS, or "overview" (default).
|
||||
|
||||
Returns:
|
||||
The markdown content of the doc.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the topic can't be found.
|
||||
"""
|
||||
return _resolve_path(topic).read_text()
|
||||
|
||||
|
||||
def list_topics() -> dict[str, str]:
|
||||
"""Return a dict of topic name -> first line (title) of each doc."""
|
||||
result = {}
|
||||
for topic in sorted(TOPICS.keys()):
|
||||
try:
|
||||
path = _resolve_path(topic)
|
||||
first_line = path.read_text().split("\n", 1)[0].lstrip("# ").strip()
|
||||
result[topic] = first_line
|
||||
except FileNotFoundError:
|
||||
result[topic] = "(not found — run `axolotl fetch docs`)"
|
||||
return result
|
||||
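A short usage sketch of the two helpers above; the import path is the same one the CLI command uses further down:

```python
from axolotl.cli.agent_docs import get_doc, list_topics

# Enumerate available topics and their titles.
for name, title in list_topics().items():
    print(f"{name:25s} {title}")

# Read one doc; raises FileNotFoundError when neither the repo checkout
# nor the output of `axolotl fetch docs` can be found.
print(get_doc("grpo").splitlines()[0])
```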
@@ -294,7 +294,9 @@ def merge_lora(config: str, **kwargs):
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||
@click.argument(
|
||||
"directory", type=click.Choice(["examples", "deepspeed_configs", "docs"])
|
||||
)
|
||||
@click.option("--dest", help="Destination directory")
|
||||
def fetch(directory: str, dest: Optional[str]):
|
||||
"""
|
||||
@@ -303,9 +305,10 @@ def fetch(directory: str, dest: Optional[str]):
|
||||
Available directories:
|
||||
- examples: Example configuration files
|
||||
- deepspeed_configs: DeepSpeed configuration files
|
||||
- docs: Full documentation (Quarto markdown files)
|
||||
|
||||
Args:
|
||||
directory: One of `examples`, `deepspeed_configs`.
|
||||
directory: One of `examples`, `deepspeed_configs`, `docs`.
|
||||
dest: Optional destination directory.
|
||||
"""
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
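The new `docs` choice can also be exercised programmatically. Only the call signature is taken from the line above; the module path of `fetch_from_github` is an assumption:

```python
# Hypothetical programmatic equivalent of `axolotl fetch docs --dest ./`
from axolotl.cli.utils import fetch_from_github  # assumed module path

fetch_from_github("docs/", "./")  # directory ends with "/", dest may be None
```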
@@ -340,6 +343,112 @@ def delinearize_llama4(model: str, output: str):
|
||||
do_delinearize_llama4(model, output)
|
||||
|
||||
|
||||
@cli.command("agent-docs")
|
||||
@click.argument("topic", required=False, default=None)
|
||||
@click.option("--list", "list_topics", is_flag=True, help="List available topics")
|
||||
def agent_docs(topic: Optional[str], list_topics: bool):
|
||||
"""Show agent-optimized documentation.
|
||||
|
||||
Prints reference docs designed for AI coding agents.
|
||||
These docs ship with the repo checkout; pip-only installs can run `axolotl fetch docs` once to download them.
|
||||
|
||||
\b
|
||||
Examples:
|
||||
axolotl agent-docs # overview (start here)
|
||||
axolotl agent-docs grpo # GRPO reference
|
||||
axolotl agent-docs sft # SFT reference
|
||||
axolotl agent-docs --list # list all topics
|
||||
"""
|
||||
from axolotl.cli.agent_docs import get_doc, list_topics as _list_topics
|
||||
|
||||
if list_topics:
|
||||
for name, title in _list_topics().items():
|
||||
click.echo(f" {name:25s} {title}")
|
||||
return
|
||||
|
||||
if topic is None:
|
||||
topic = "overview"
|
||||
|
||||
try:
|
||||
click.echo(get_doc(topic))
|
||||
except FileNotFoundError as exc:
|
||||
raise click.BadParameter(str(exc)) from exc
|
||||
|
||||
|
||||
@cli.command("config-schema")
|
||||
@click.option(
|
||||
"--format",
|
||||
"output_format",
|
||||
type=click.Choice(["json", "yaml"]),
|
||||
default="json",
|
||||
help="Output format (default: json)",
|
||||
)
|
||||
@click.option("--field", help="Show schema for a specific field only")
|
||||
def config_schema(output_format: str, field: Optional[str]):
|
||||
"""Dump the full config JSON schema.
|
||||
|
||||
Useful for AI agents and tooling to discover all available config options,
|
||||
their types, defaults, and descriptions.
|
||||
|
||||
\b
|
||||
Examples:
|
||||
axolotl config-schema # full JSON schema
|
||||
axolotl config-schema --format yaml # YAML format
|
||||
axolotl config-schema --field adapter # single field
|
||||
"""
|
||||
import json
|
||||
|
||||
try:
|
||||
schema = AxolotlInputConfig.model_json_schema()
|
||||
except (TypeError, ValueError, AttributeError) as exc:
|
||||
# Fallback: dump field names, types, and defaults when full schema
|
||||
# generation fails (e.g. torch.dtype not JSON-serializable)
|
||||
LOG.warning(
|
||||
"Full JSON schema generation failed, using simplified fallback: %s", exc
|
||||
)
|
||||
fields = {}
|
||||
for name, field_info in AxolotlInputConfig.model_fields.items():
|
||||
entry = {}
|
||||
if field_info.description:
|
||||
entry["description"] = field_info.description
|
||||
if field_info.default is not None:
|
||||
try:
|
||||
json.dumps(field_info.default)
|
||||
entry["default"] = field_info.default
|
||||
except (TypeError, ValueError):
|
||||
entry["default"] = str(field_info.default)
|
||||
annotation = field_info.annotation
|
||||
if annotation is not None:
|
||||
entry["type"] = str(annotation)
|
||||
fields[name] = entry
|
||||
schema = {
|
||||
"properties": fields,
|
||||
"_note": "simplified schema (full generation failed)",
|
||||
}
|
||||
|
||||
if field:
|
||||
props = schema.get("properties", {})
|
||||
if field not in props:
|
||||
# Try case-insensitive match
|
||||
matches = [k for k in props if k.lower() == field.lower()]
|
||||
if matches:
|
||||
field = matches[0]
|
||||
else:
|
||||
raise click.BadParameter(
|
||||
f"Unknown field: {field!r}. "
|
||||
f"Omit --field to dump the full schema, "
|
||||
f"or pipe to jq: axolotl config-schema | jq '.properties | keys'"
|
||||
)
|
||||
schema = {field: props[field]}
|
||||
|
||||
if output_format == "yaml":
|
||||
import yaml # pylint: disable=import-outside-toplevel
|
||||
|
||||
click.echo(yaml.dump(schema, default_flow_style=False, sort_keys=False))
|
||||
else:
|
||||
click.echo(json.dumps(schema, indent=2))
|
||||
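Under the hood the command is a thin wrapper over pydantic's schema export. A minimal sketch, assuming the import path of `AxolotlInputConfig`:

```python
import json

from axolotl.utils.schemas.config import AxolotlInputConfig  # assumed import path

schema = AxolotlInputConfig.model_json_schema()
# Roughly what `axolotl config-schema --field adapter` prints.
print(json.dumps(schema["properties"]["adapter"], indent=2))
```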
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
|
||||
simulate_nf4_experts=simulate_nf4_experts,
|
||||
nf4_blocksize=nf4_blocksize,
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
|
||||
)
|
||||
|
||||
LOG.debug("Memory-efficient LoRA merge completed successfully!")
|
||||
|
||||
@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _build_layer_type_map(
|
||||
base_model_path: Path, trust_remote_code: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Build a map of module_name -> layer_type using a meta-device model.
|
||||
|
||||
Instantiates the model architecture on the meta device (zero memory)
|
||||
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
|
||||
This avoids relying on weight tensor ndim heuristics.
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
config_path = base_model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
model_config = _json.load(f)
|
||||
except (OSError, _json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
architectures = model_config.get("architectures", [])
|
||||
if not architectures:
|
||||
return {}
|
||||
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
str(base_model_path), trust_remote_code=trust_remote_code
|
||||
)
|
||||
except Exception:
|
||||
LOG.debug("Could not load config for layer type introspection")
|
||||
return {}
|
||||
|
||||
# Determine the right Auto class from architectures
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
|
||||
auto_classes = [AutoModelForCausalLM, AutoModel]
|
||||
try:
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
auto_classes.insert(0, AutoModelForImageTextToText)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
model = None
|
||||
for auto_cls in auto_classes:
|
||||
try:
|
||||
with torch.device("meta"):
|
||||
model = auto_cls.from_config(
|
||||
config, trust_remote_code=trust_remote_code
|
||||
)
|
||||
break
|
||||
except Exception: # noqa: BLE001
|
||||
LOG.debug(
|
||||
"Could not instantiate meta model with %s, trying next",
|
||||
auto_cls.__name__,
|
||||
)
|
||||
|
||||
if model is None:
|
||||
LOG.debug("Could not instantiate meta model for layer type introspection")
|
||||
return {}
|
||||
|
||||
layer_types = {}
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Conv3d):
|
||||
layer_types[name] = "Conv3d"
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
layer_types[name] = "Conv2d"
|
||||
elif isinstance(module, nn.Conv1d):
|
||||
layer_types[name] = "Conv1d"
|
||||
elif isinstance(module, nn.Linear):
|
||||
layer_types[name] = "Linear"
|
||||
|
||||
del model
|
||||
LOG.debug(
|
||||
f"Layer type map: {len(layer_types)} modules "
|
||||
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
|
||||
)
|
||||
return layer_types
|
||||
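The meta-device trick above can be reproduced standalone; a hedged sketch with a hypothetical model path:

```python
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("/path/to/base_model")  # hypothetical path
with torch.device("meta"):  # parameters are created without allocating real memory
    model = AutoModelForCausalLM.from_config(config)

layer_types = {
    name: type(module).__name__
    for name, module in model.named_modules()
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d))
}
```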
|
||||
|
||||
def _simulate_nf4_roundtrip(
|
||||
tensor: torch.Tensor,
|
||||
blocksize: Optional[int] = None,
|
||||
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
|
||||
adapter_name: str = "default",
|
||||
is_param_wrapper: bool = False,
|
||||
magnitude: Optional[torch.Tensor] = None,
|
||||
layer_type: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Use PEFT's own layer classes to compute the LoRA delta weight.
|
||||
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
|
||||
out_features = lora_b.shape[0]
|
||||
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
|
||||
use_rslora = bool(lora_config_dict.get("use_rslora", False))
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||
|
||||
if is_param_wrapper:
|
||||
from peft.tuners.lora.layer import ParamWrapper
|
||||
@@ -227,18 +315,110 @@ def _build_peft_layer_and_get_delta(
|
||||
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
|
||||
)
|
||||
|
||||
# ParamWrapper rejects dropout/fan_in_fan_out/lora_bias/use_dora, so
|
||||
# build a minimal config with only the fields it accepts.
|
||||
pw_config = LoraConfig(
|
||||
r=r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=0.0,
|
||||
fan_in_fan_out=False,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=False,
|
||||
lora_bias=False,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
layer = ParamWrapper(
|
||||
fake,
|
||||
adapter_name=adapter_name,
|
||||
parameter_name="weight",
|
||||
config=pw_config,
|
||||
r=r,
|
||||
lora_alpha=lora_alpha,
|
||||
use_rslora=use_rslora,
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
delta = layer.get_delta_weight(adapter_name)
|
||||
# peft >=0.19.1 may return delta with transposed dims for 3D params
|
||||
if delta.shape != base_tensor.shape and delta.ndim == 3:
|
||||
delta = delta.transpose(1, 2).contiguous()
|
||||
return delta
|
||||
elif (
|
||||
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
|
||||
):
|
||||
# Conv layer detected via model introspection (or ndim fallback)
|
||||
|
||||
from peft.tuners.lora import layer as peft_lora_layer
|
||||
|
||||
# Determine conv type from layer_type map or fall back to ndim
|
||||
if layer_type and "Conv" in layer_type:
|
||||
conv_type: str = layer_type
|
||||
else:
|
||||
ndim = lora_a.ndim
|
||||
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
|
||||
if ndim not in _conv_map:
|
||||
raise ValueError(
|
||||
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
|
||||
)
|
||||
conv_type = _conv_map[ndim]
|
||||
LOG.warning(
|
||||
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
|
||||
f"Consider providing layer_type from meta-device introspection."
|
||||
)
|
||||
|
||||
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
|
||||
ConvCls = conv_cls_map[conv_type]
|
||||
PeftConvCls = getattr(peft_lora_layer, conv_type)
|
||||
|
||||
# Reconstruct conv parameters from base tensor and lora_a shapes
|
||||
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
|
||||
# lora_a: [r, in_channels/groups, *kernel_size]
|
||||
# lora_b: [out_channels, r, *ones]
|
||||
out_channels = base_tensor.shape[0]
|
||||
in_channels = base_tensor.shape[1]
|
||||
kernel_size = tuple(base_tensor.shape[2:])
|
||||
stride = (1,) * (base_tensor.ndim - 2)
|
||||
padding = (0,) * (base_tensor.ndim - 2)
|
||||
|
||||
base_layer = ConvCls(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=False,
|
||||
)
|
||||
base_layer.weight.data = base_tensor.clone()
|
||||
|
||||
conv_config = LoraConfig(
|
||||
r=r_total,
|
||||
lora_alpha=lora_alpha,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
layer = PeftConvCls(
|
||||
base_layer,
|
||||
adapter_name=adapter_name,
|
||||
config=conv_config,
|
||||
r=r_total,
|
||||
lora_alpha=lora_alpha,
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
|
||||
if use_dora:
|
||||
if magnitude is None:
|
||||
raise ValueError(
|
||||
f"DoRA merge requires a magnitude vector but none was found "
|
||||
f"for conv layer (adapter={adapter_name}). Check that the "
|
||||
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||
)
|
||||
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||
mag_layer.weight = nn.Parameter(magnitude)
|
||||
layer.merge(adapter_names=[adapter_name])
|
||||
return base_layer.weight.data - base_tensor
|
||||
|
||||
return layer.get_delta_weight(adapter_name)
|
||||
else:
|
||||
from peft.tuners.lora.layer import Linear as LoraLinear
|
||||
@@ -251,15 +431,20 @@ def _build_peft_layer_and_get_delta(
|
||||
or lora_config_dict.get("lora_fan_in_fan_out", False)
|
||||
)
|
||||
|
||||
layer = LoraLinear(
|
||||
base_layer,
|
||||
adapter_name=adapter_name,
|
||||
linear_config = LoraConfig(
|
||||
r=r_total,
|
||||
lora_alpha=lora_alpha,
|
||||
fan_in_fan_out=fan_in_fan_out,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
layer = LoraLinear(
|
||||
base_layer,
|
||||
adapter_name=adapter_name,
|
||||
config=linear_config,
|
||||
r=r_total,
|
||||
lora_alpha=lora_alpha,
|
||||
)
|
||||
layer.lora_A[adapter_name].weight.data = lora_a
|
||||
layer.lora_B[adapter_name].weight.data = lora_b
|
||||
|
||||
@@ -267,6 +452,12 @@ def _build_peft_layer_and_get_delta(
|
||||
# DoRA merges magnitude normalization into the weight directly.
|
||||
# Use PEFT's merge() which handles DoRA internally, then
|
||||
# compute the delta as merged_weight - original_weight.
|
||||
if magnitude is None:
|
||||
raise ValueError(
|
||||
f"DoRA merge requires a magnitude vector but none was found "
|
||||
f"for linear layer (adapter={adapter_name}). Check that the "
|
||||
f"adapter checkpoint contains lora_magnitude_vector weights."
|
||||
)
|
||||
mag_layer = layer.lora_magnitude_vector[adapter_name]
|
||||
mag_layer.weight = nn.Parameter(magnitude)
|
||||
layer.merge(adapter_names=[adapter_name])
|
||||
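For context, the DoRA merge the comment describes renormalises columns with the learned magnitude vector, which is why the delta can only be recovered as merged minus base (s is the LoRA scaling; formula per the DoRA paper, not taken from this diff):

$$
W_{\text{merged}} = m \odot \frac{W_0 + s\,BA}{\lVert W_0 + s\,BA \rVert_{\text{col}}},
\qquad
\Delta W = W_{\text{merged}} - W_0
$$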
@@ -382,6 +573,7 @@ def _merge_tensor_with_lora(
|
||||
nf4_double_quant: bool = True,
|
||||
use_dora: bool = False,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
layer_type_map: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[torch.Tensor, bool]:
|
||||
"""
|
||||
Helper function to merge a single tensor with its corresponding LoRA weights.
|
||||
@@ -426,12 +618,30 @@ def _merge_tensor_with_lora(
|
||||
if use_dora
|
||||
else None
|
||||
)
|
||||
|
||||
# Look up layer type from meta-device model introspection
|
||||
_layer_type = None
|
||||
if layer_type_map:
|
||||
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
|
||||
_layer_type = layer_type_map.get(mod_path)
|
||||
# Try common prefix variations (e.g. with/without "model." prefix)
|
||||
if _layer_type is None:
|
||||
for prefix in [
|
||||
"model.",
|
||||
"model.language_model.",
|
||||
"model.language_model.model.",
|
||||
]:
|
||||
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||
if _layer_type:
|
||||
break
|
||||
|
||||
delta = _build_peft_layer_and_get_delta(
|
||||
lora_a.to(device),
|
||||
lora_b.to(device),
|
||||
lora_config_dict,
|
||||
tensor.to(device),
|
||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||
layer_type=_layer_type,
|
||||
)
|
||||
merged_tensor = (
|
||||
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
|
||||
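The prefix-tolerant lookup above (repeated for fused MoE keys in the next hunk) boils down to a small helper; this is a sketch of the same logic, not a function that exists in the tree:

```python
def lookup_layer_type(layer_type_map: dict[str, str], key: str) -> str | None:
    """Resolve a checkpoint key to its module's layer type, tolerating common prefixes."""
    mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
    for prefix in ("", "model.", "model.language_model.", "model.language_model.model."):
        layer_type = layer_type_map.get(prefix + mod_path)
        if layer_type is not None:
            return layer_type
    return None
```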
@@ -556,6 +766,7 @@ def _fuse_and_unfuse_with_merge(
|
||||
nf4_double_quant: bool = True,
|
||||
use_dora: bool = False,
|
||||
weight_renamings: Optional[Dict[str, str]] = None,
|
||||
layer_type_map: Optional[Dict[str, str]] = None,
|
||||
) -> tuple[Dict[str, torch.Tensor], int, set]:
|
||||
"""
|
||||
For tensors matching WeightConverter patterns (MoE expert weights):
|
||||
@@ -696,12 +907,32 @@ def _fuse_and_unfuse_with_merge(
|
||||
if use_dora
|
||||
else None
|
||||
)
|
||||
# Look up layer type for the fused key
|
||||
_layer_type = None
|
||||
if layer_type_map:
|
||||
mod_path = (
|
||||
fused_key.rsplit(".weight", 1)[0]
|
||||
if fused_key.endswith(".weight")
|
||||
else fused_key
|
||||
)
|
||||
_layer_type = layer_type_map.get(mod_path)
|
||||
if _layer_type is None:
|
||||
for prefix in [
|
||||
"model.",
|
||||
"model.language_model.",
|
||||
"model.language_model.model.",
|
||||
]:
|
||||
_layer_type = layer_type_map.get(prefix + mod_path)
|
||||
if _layer_type:
|
||||
break
|
||||
|
||||
delta = _build_peft_layer_and_get_delta(
|
||||
lora_a.to(device),
|
||||
lora_b.to(device),
|
||||
lora_config_dict,
|
||||
fused_tensor.to(device),
|
||||
magnitude=magnitude.to(device) if magnitude is not None else None,
|
||||
layer_type=_layer_type,
|
||||
)
|
||||
fused_tensor = (
|
||||
(
|
||||
@@ -740,6 +971,7 @@ def merge_lora_sharded_efficient(
|
||||
simulate_nf4_experts: bool = False,
|
||||
nf4_blocksize: Optional[int] = None,
|
||||
nf4_double_quant: bool = True,
|
||||
trust_remote_code: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Memory-efficient LoRA merging that processes shards individually
|
||||
@@ -750,6 +982,8 @@ def merge_lora_sharded_efficient(
|
||||
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
|
||||
(for quantize_moe_experts). Expert tensors are identified by having
|
||||
"expert" in the key name and ndim >= 3.
|
||||
trust_remote_code: Whether to trust remote code when loading model
|
||||
config for layer-type introspection. Defaults to False for safety.
|
||||
"""
|
||||
base_model_path = Path(base_model_path)
|
||||
lora_adapter_path = Path(lora_adapter_path)
|
||||
@@ -780,6 +1014,10 @@ def merge_lora_sharded_efficient(
|
||||
|
||||
use_dora = bool(lora_config_dict.get("use_dora", False))
|
||||
|
||||
# Build layer type map via meta-device model introspection
|
||||
layer_type_map = _build_layer_type_map(
|
||||
base_model_path, trust_remote_code=trust_remote_code
|
||||
)
|
||||
unsupported_methods = []
|
||||
|
||||
# Check for AdaLoRA (Adaptive LoRA)
|
||||
@@ -904,6 +1142,7 @@ def merge_lora_sharded_efficient(
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
use_dora=use_dora,
|
||||
weight_renamings=weight_renamings,
|
||||
layer_type_map=layer_type_map,
|
||||
)
|
||||
merged_count += fused_merged
|
||||
|
||||
@@ -926,6 +1165,7 @@ def merge_lora_sharded_efficient(
|
||||
nf4_double_quant=nf4_double_quant,
|
||||
use_dora=use_dora,
|
||||
weight_renamings=weight_renamings,
|
||||
layer_type_map=layer_type_map,
|
||||
)
|
||||
merged_tensors[key] = merged_tensor
|
||||
if was_merged:
|
||||
|
||||
@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveModelOnFirstStepCallback,
|
||||
SkipEvalOnResumeCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.resume_from_checkpoint:
|
||||
callbacks.append(SkipEvalOnResumeCallback())
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
|
||||
@@ -100,6 +100,27 @@ class AxolotlTrainer(
|
||||
self._signature_columns = None # workaround for pylint
|
||||
|
||||
super().__init__(*_args, **kwargs)
|
||||
|
||||
# Gemma4 (and similar multimodal models) declare **kwargs in forward() for
|
||||
# extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
|
||||
# "the model handles num_items_in_batch internally" and skips the loss ÷
|
||||
# gradient_accumulation_steps normalisation, which inflates the *logged* loss
|
||||
# (the gradient itself is still correct). Override to False when the model
|
||||
# doesn't actually consume num_items_in_batch.
|
||||
if self.model_accepts_loss_kwargs:
|
||||
model_to_check = self.accelerator.unwrap_model(self.model)
|
||||
if hasattr(model_to_check, "base_model"): # PEFT wrapper
|
||||
model_to_check = model_to_check.base_model
|
||||
if hasattr(model_to_check, "model"):
|
||||
model_to_check = model_to_check.model
|
||||
fwd = getattr(model_to_check, "forward", None)
|
||||
if fwd is not None:
|
||||
import inspect
|
||||
|
||||
params = inspect.signature(fwd).parameters
|
||||
if "num_items_in_batch" not in params:
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(
|
||||
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
|
||||
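The signature check that drives the override above, isolated as a sketch:

```python
import inspect

def forward_accepts_num_items_in_batch(model) -> bool:
    """True only if the (unwrapped) model's forward() names num_items_in_batch explicitly."""
    fwd = getattr(model, "forward", None)
    if fwd is None:
        return False
    return "num_items_in_batch" in inspect.signature(fwd).parameters
```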
@@ -383,13 +404,29 @@ class AxolotlTrainer(
|
||||
|
||||
# Gemma4 requires mm_token_type_ids during training (even for text-only).
|
||||
# Inject zeros (= text token type) when not provided by the data collator.
|
||||
# Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
|
||||
_unwrapped = self.accelerator.unwrap_model(model)
|
||||
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
|
||||
if (
|
||||
"mm_token_type_ids" not in inputs
|
||||
and "input_ids" in inputs
|
||||
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
|
||||
and _model_type == "gemma4"
|
||||
):
|
||||
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||
|
||||
# Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
|
||||
# from position_ids, but only when attention_mask is None. When sample
|
||||
# packing is active the collator provides an all-ones attention_mask that
|
||||
# prevents this detection — remove it so the model builds the correct
|
||||
# per-sequence causal masks.
|
||||
if (
|
||||
self.args.sample_packing
|
||||
and _model_type in ("gemma4", "gemma3")
|
||||
and "attention_mask" in inputs
|
||||
and "position_ids" in inputs
|
||||
):
|
||||
del inputs["attention_mask"]
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
@@ -398,6 +435,23 @@ class AxolotlTrainer(
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
# Gemma4ForConditionalGeneration computes loss with a manual
|
||||
# nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
|
||||
# normalization and does redundant attention_mask filtering.
|
||||
# Compute loss externally using the standard loss_function instead.
|
||||
if _model_type == "gemma4" and "labels" in inputs:
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
unwrapped = self.accelerator.unwrap_model(model)
|
||||
vocab_size = unwrapped.config.get_text_config().vocab_size
|
||||
loss = unwrapped.loss_function(
|
||||
logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
|
||||
)
|
||||
if return_outputs:
|
||||
return loss, outputs
|
||||
return loss
|
||||
|
||||
return super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
@@ -410,6 +464,21 @@ class AxolotlTrainer(
|
||||
LOG.info("Running evaluation step...")
|
||||
return super().evaluate(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||
# Gemma4 requires mm_token_type_ids even during evaluation.
|
||||
_unwrapped = self.accelerator.unwrap_model(model)
|
||||
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
|
||||
if (
|
||||
"mm_token_type_ids" not in inputs
|
||||
and "input_ids" in inputs
|
||||
and _model_type == "gemma4"
|
||||
):
|
||||
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
|
||||
return super().prediction_step(
|
||||
model, inputs, prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
concatenated_batch = {}
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -44,6 +44,7 @@ plugins:
|
||||
- gemma3_text
|
||||
- gemma3n
|
||||
- gemma3n_text
|
||||
- gemma4
|
||||
- glm
|
||||
- glm4
|
||||
- glm4_moe
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
|
||||
'`pip uninstall -y cut-cross-entropy && pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
|
||||
|
||||
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
|
||||
|
||||
**Important limitations:**
|
||||
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
|
||||
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
|
||||
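The example regex in the limitation above can be sanity-checked in isolation; the module names below are hypothetical but follow the usual multimodal layout:

```python
import re

pattern = re.compile(r"language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj")

# Matches the text backbone's attention projections...
assert pattern.fullmatch("language_model.model.layers.0.self_attn.q_proj")
# ...but not vision-encoder modules.
assert not pattern.fullmatch("vision_tower.encoder.layers.0.self_attn.q_proj")
```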
|
||||
## Limitations
|
||||
|
||||
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
|
||||
|
||||
@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def warn_sonicmoe_lora_overhead(cls, data):
|
||||
if data.get("use_sonicmoe") is True and data.get("adapter") in (
|
||||
"lora",
|
||||
"qlora",
|
||||
):
|
||||
lora_target = data.get("lora_target_modules") or []
|
||||
lora_linear = data.get("lora_target_linear_modules") or []
|
||||
targets = (
|
||||
lora_target if isinstance(lora_target, list) else [lora_target]
|
||||
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
|
||||
expert_keywords = ("gate_up_proj", "down_proj", "experts")
|
||||
if any(kw in t for t in targets for kw in expert_keywords):
|
||||
LOG.info(
|
||||
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
|
||||
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
|
||||
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def disable_mlp_kernel(cls, data):
|
||||
|
||||
@@ -60,49 +60,14 @@ def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
|
||||
"""Convert peft LoRA weights to scattermoe layout.
|
||||
|
||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
||||
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
|
||||
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
|
||||
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
|
||||
are swapped relative to scattermoe's convention.
|
||||
|
||||
peft gives:
|
||||
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
|
||||
|
||||
scattermoe needs:
|
||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
||||
|
||||
This function swaps A<->B and converts B from rank-major to expert-major.
|
||||
Uses vectorized tensor operations (no Python loop over experts).
|
||||
|
||||
Works for **both** gate_up_proj and down_proj since the transposition
|
||||
issue is the same for any parameter.
|
||||
peft >=0.19.1 assigns in/out features for 3D params such that
|
||||
A and B already align with scattermoe's convention (no A<->B swap).
|
||||
Only B needs rank-major → expert-major layout conversion.
|
||||
"""
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
|
||||
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
|
||||
|
||||
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
|
||||
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
|
||||
smoe_A = (
|
||||
peft_B_em.reshape(dim2, num_experts, rank)
|
||||
.permute(1, 2, 0)
|
||||
.contiguous()
|
||||
.reshape(rank * num_experts, dim2)
|
||||
)
|
||||
|
||||
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
|
||||
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
|
||||
smoe_B = (
|
||||
peft_A.reshape(num_experts, rank, dim1)
|
||||
.permute(2, 0, 1)
|
||||
.contiguous()
|
||||
.reshape(dim1, num_experts * rank)
|
||||
)
|
||||
|
||||
smoe_A = peft_A
|
||||
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
return smoe_A, smoe_B
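

# Illustrative sketch of the rank-major -> expert-major column reorder described in
# the docstring above. The shapes, the helper name, and the assumption that
# "rank-major" means columns grouped by rank index are examples only; the real
# conversion is peft_lora_B_to_scattermoe (defined earlier in this module).
def _rank_major_to_expert_major_demo(dim2=16, num_experts=4, rank=2):
    import torch  # local import, sketch only

    # peft lora_B for a fused 3D expert parameter: [dim2, r*E], rank-major columns
    peft_B = torch.arange(dim2 * rank * num_experts, dtype=torch.float32)
    peft_B = peft_B.reshape(dim2, rank * num_experts)

    expert_major = (
        peft_B.reshape(dim2, rank, num_experts)  # split columns into (rank, expert)
        .permute(0, 2, 1)  # regroup as (expert, rank)
        .reshape(dim2, num_experts * rank)  # expert-major column order
    )
    return expert_major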
|
||||
|
||||
|
||||
|
||||
@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
swiglu=cfg.liger_glu_activation,
|
||||
)
|
||||
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||
# Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
|
||||
# gradient checkpointing compatibility, RoPE incompatible (separate q/k).
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
|
||||
if cfg.liger_rms_norm:
|
||||
_OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm
|
||||
|
||||
class _LigerGemma4RMSNorm(LigerRMSNorm):
|
||||
"""LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""
|
||||
|
||||
def __new__(cls, dim, eps=1e-6, with_scale=True):
|
||||
if not with_scale:
|
||||
return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, dim, eps=1e-6, with_scale=True):
|
||||
if not with_scale:
|
||||
return
|
||||
# offset=0.0 (standard), in_place=False (gradient checkpointing safe)
|
||||
super().__init__(
|
||||
dim, eps, offset=0.0, casting_mode="llama", in_place=False
|
||||
)
|
||||
|
||||
modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
|
||||
if cfg.liger_glu_activation:
|
||||
|
||||
class _LigerGemma4MLP(LigerGEGLUMLP):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
super().__init__(config)
|
||||
|
||||
modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
|
||||
if cfg.liger_rope:
|
||||
LOG.warning(
|
||||
"Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
|
||||
)
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_gemma4.nn.LayerNorm = LigerLayerNorm
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
LOG.warning(
|
||||
"Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
|
||||
)
|
||||
LOG.info(
|
||||
f"Applied Liger kernels for gemma4: "
|
||||
f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
|
||||
f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
|
||||
)
|
||||
elif cfg.liger_fused_linear_cross_entropy:
|
||||
try:
|
||||
from .models.base import patch_lce_forward
|
||||
|
||||
529 src/axolotl/kernels/gemma4_fused_rope.py Normal file
@@ -0,0 +1,529 @@
|
||||
"""
|
||||
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
|
||||
|
||||
Fuses three operations into one kernel launch:
|
||||
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
|
||||
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||
3. (optional) RMSNorm without scale (for v_norm)
|
||||
|
||||
This eliminates two intermediate tensor materializations per Q/K path and the
tensor churn from rotate_half / apply_rotary_pos_emb.
|
||||
|
||||
Shapes:
|
||||
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
|
||||
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
|
||||
cos: (rows, head_dim) — flattened from (batch, seq_len, 1, head_dim) after broadcast
|
||||
sin: (rows, head_dim) — same as cos
|
||||
"""
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from liger_kernel.ops.utils import (
|
||||
calculate_settings,
|
||||
compare_version,
|
||||
ensure_contiguous,
|
||||
torch_to_triton_dtype,
|
||||
)
|
||||
from liger_kernel.utils import is_npu_available
|
||||
|
||||
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||
try:
|
||||
from triton.language.extra.libdevice import rsqrt
|
||||
except ModuleNotFoundError:
|
||||
from triton.language.extra.cuda.libdevice import rsqrt
|
||||
else:
|
||||
from triton.language.math import rsqrt
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_rope_forward_kernel(
|
||||
Y_ptr,
|
||||
Y_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
W_ptr,
|
||||
COS_ptr,
|
||||
COS_row_stride,
|
||||
SIN_ptr,
|
||||
SIN_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused forward:
|
||||
x_norm = x / rms(x) [* weight] (RMSNorm)
|
||||
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
|
||||
|
||||
rotate_half swaps first/second halves and negates the first:
|
||||
rotate_half([a, b]) = [-b, a]
|
||||
|
||||
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
|
||||
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
|
||||
cs_row_idx = row_idx // n_heads
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
|
||||
# Load input row
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
X_dtype = X_row.dtype
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
|
||||
# RMSNorm: compute 1/rms
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
|
||||
# Normalize
|
||||
X_norm = X_fp32 * rstd
|
||||
|
||||
# Apply weight if present (with_scale=True)
|
||||
if HAS_WEIGHT:
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_norm = X_norm * W_row
|
||||
|
||||
# RoPE: load cos/sin (broadcast across heads)
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
).to(tl.float32)
|
||||
sin_row = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
|
||||
).to(tl.float32)
|
||||
|
||||
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
|
||||
# for col >= half_dim, take X_norm[col - half_dim]
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
X_rot = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
|
||||
).to(tl.float32)
|
||||
# Re-normalize the rotated values
|
||||
X_rot_norm = X_rot * rstd
|
||||
if HAS_WEIGHT:
|
||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
|
||||
tl.float32
|
||||
)
|
||||
X_rot_norm = X_rot_norm * W_rot
|
||||
|
||||
# Negate the first half (rotate_half negates x2, which becomes the first half)
|
||||
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
|
||||
X_rot_norm = X_rot_norm * sign
|
||||
|
||||
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||
Y_row = X_norm * cos_row + X_rot_norm * sin_row
|
||||
|
||||
tl.store(
|
||||
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
||||
Y_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_rope_backward_kernel(
|
||||
dY_ptr,
|
||||
dY_row_stride,
|
||||
dX_ptr,
|
||||
dX_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
X_dtype: tl.constexpr,
|
||||
W_ptr,
|
||||
COS_ptr,
|
||||
COS_row_stride,
|
||||
SIN_ptr,
|
||||
SIN_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
dW_ptr,
|
||||
dW_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Backward for Y = RoPE(RMSNorm(X, W))
|
||||
cos/sin indexed by row_idx // n_heads for per-head broadcast.
|
||||
"""
|
||||
row_block_id = tl.program_id(0).to(tl.int64)
|
||||
row_start = row_block_id * rows_per_program
|
||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
|
||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
if HAS_WEIGHT:
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
|
||||
for row_idx in range(row_start, row_end):
|
||||
cs_row_idx = row_idx // n_heads
|
||||
|
||||
dY_row = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
|
||||
).to(tl.float32)
|
||||
X_row = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
|
||||
).to(tl.float32)
|
||||
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
).to(tl.float32)
|
||||
|
||||
# dN = dY * cos + rotate_half^T(dY * sin)
|
||||
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
|
||||
#
|
||||
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
|
||||
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
|
||||
# This is equivalent to rotating (dY * sin) because the rotation
|
||||
# just permutes which elements are multiplied.
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
dY_rot = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
sin_rot = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
|
||||
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
|
||||
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
|
||||
|
||||
# Pre-weight normalized: n = rstd * x
|
||||
n = X_row * rstd
|
||||
|
||||
if HAS_WEIGHT:
|
||||
dW_acc += dN * n
|
||||
dm = dN * W_row
|
||||
else:
|
||||
dm = dN
|
||||
|
||||
# RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
|
||||
dot_dm_x = tl.sum(dm * X_row, axis=0)
|
||||
dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)
|
||||
|
||||
tl.store(
|
||||
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
||||
dX_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
if HAS_WEIGHT:
|
||||
tl.store(
|
||||
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
||||
dW_acc,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
"""
|
||||
Args:
|
||||
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
|
||||
W: (head_dim,) or None — RMSNorm weight
|
||||
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
eps: float
|
||||
n_heads: int — number of attention heads (for cos/sin indexing)
|
||||
Returns:
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
|
||||
"""
|
||||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
has_weight = W is not None
|
||||
|
||||
Y = torch.empty_like(X)
|
||||
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||
|
||||
_rms_norm_rope_forward_kernel[(n_rows,)](
|
||||
Y,
|
||||
Y.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
W if has_weight else X, # dummy pointer when no weight
|
||||
cos,
|
||||
cos.stride(0),
|
||||
sin,
|
||||
sin.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT=has_weight,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
|
||||
n_rows, n_cols = dY.shape
|
||||
has_weight = W is not None
|
||||
|
||||
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
||||
rows_per_program = math.ceil(n_rows / sm_count)
|
||||
|
||||
dX = torch.empty_like(X)
|
||||
|
||||
if has_weight:
|
||||
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
|
||||
else:
|
||||
_dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)
|
||||
|
||||
_rms_norm_rope_backward_kernel[(sm_count,)](
|
||||
dY,
|
||||
dY.stride(0),
|
||||
dX,
|
||||
dX.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
torch_to_triton_dtype[X.dtype],
|
||||
W if has_weight else X, # dummy
|
||||
cos,
|
||||
cos.stride(0),
|
||||
sin,
|
||||
sin.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
_dW,
|
||||
_dW.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT=has_weight,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
|
||||
return dX, dW
|
||||
|
||||
|
||||
class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, W, cos, sin, eps, n_heads):
|
||||
"""
|
||||
X: (B*S*H, head_dim)
|
||||
W: (head_dim,) or None
|
||||
cos: (B*S, head_dim) — broadcast across heads
|
||||
sin: (B*S, head_dim) — broadcast across heads
|
||||
n_heads: int
|
||||
"""
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
|
||||
X,
|
||||
W,
|
||||
cos,
|
||||
sin,
|
||||
eps,
|
||||
n_heads,
|
||||
)
|
||||
ctx.eps = eps
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.n_heads = n_heads
|
||||
ctx.has_weight = W is not None
|
||||
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
|
||||
return Y
|
||||
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def backward(ctx, dY):
|
||||
X, W, cos, sin, RSTD = ctx.saved_tensors
|
||||
dX, dW = rms_norm_rope_backward(
|
||||
dY,
|
||||
X,
|
||||
W,
|
||||
cos,
|
||||
sin,
|
||||
RSTD,
|
||||
ctx.n_heads,
|
||||
ctx.BLOCK_SIZE,
|
||||
ctx.num_warps,
|
||||
)
|
||||
return dX, dW, None, None, None, None
|
||||
|
||||
|
||||
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||
"""
|
||||
Apply fused RMSNorm + RoPE.
|
||||
|
||||
Args:
|
||||
x: (batch, seq_len, num_heads, head_dim) — after projection + view
|
||||
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
|
||||
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
eps: float — RMSNorm epsilon
|
||||
|
||||
Returns:
|
||||
y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
|
||||
"""
|
||||
shape = x.shape # (B, S, H, D)
|
||||
B, S, H, D = shape
|
||||
# Flatten to 2D: (B*S*H, D)
|
||||
x_flat = x.reshape(-1, D).contiguous()
|
||||
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
|
||||
# by dividing the row_idx by H to get the cos/sin row
|
||||
cos_flat = cos.reshape(B * S, D).contiguous()
|
||||
sin_flat = sin.reshape(B * S, D).contiguous()
|
||||
|
||||
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
|
||||
return y_flat.view(shape)
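

# Illustrative eager-mode reference for what fused_rms_norm_rope computes
# (a sketch for comparison/testing only, not used by the kernel path; running
# the fused path itself requires a CUDA device with Triton).
def _eager_rms_norm_rope_reference(x, weight, cos, sin, eps=1e-6):
    # RMSNorm in fp32, mirroring the kernel's compute dtype
    x_fp32 = x.float()
    x_norm = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    if weight is not None:
        x_norm = x_norm * weight.float()

    # rotate_half([a, b]) = [-b, a]
    half = x.shape[-1] // 2
    rotated = torch.cat((-x_norm[..., half:], x_norm[..., :half]), dim=-1)

    # cos/sin are (B, S, D); add a head axis to broadcast over (B, S, H, D)
    cos_b = cos.float().unsqueeze(2)
    sin_b = sin.float().unsqueeze(2)
    return (x_norm * cos_b + rotated * sin_b).to(x.dtype)


# Example comparison on CUDA (shapes are arbitrary examples):
#   x = torch.randn(2, 8, 4, 64, device="cuda")
#   w = torch.randn(64, device="cuda")
#   cos = torch.randn(2, 8, 64, device="cuda"); sin = torch.randn(2, 8, 64, device="cuda")
#   torch.testing.assert_close(
#       fused_rms_norm_rope(x, w, cos, sin),
#       _eager_rms_norm_rope_reference(x, w, cos, sin),
#       rtol=1e-4, atol=1e-4,
#   )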
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_forward_kernel(
|
||||
Y_ptr,
|
||||
Y_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
eps,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""RMSNorm without scale weight: y = x / rms(x)"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
X_dtype = X_row.dtype
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
|
||||
Y_row = X_fp32 * rstd
|
||||
tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_noscale_backward_kernel(
|
||||
dY_ptr,
|
||||
dY_row_stride,
|
||||
dX_ptr,
|
||||
dX_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
X_dtype: tl.constexpr,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""Backward for y = x * rstd (no weight)."""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
dY_row = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
|
||||
).to(tl.float32)
|
||||
X_row = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
|
||||
).to(tl.float32)
|
||||
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
dot_dy_x = tl.sum(dY_row * X_row, axis=0)
|
||||
dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)
|
||||
|
||||
tl.store(
|
||||
dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
|
||||
)
|
||||
|
||||
|
||||
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
|
||||
"""RMSNorm without learnable scale — used for Gemma4's v_norm."""
|
||||
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, eps):
|
||||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
Y = torch.empty_like(X)
|
||||
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||
|
||||
_rms_norm_forward_kernel[(n_rows,)](
|
||||
Y,
|
||||
Y.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(X, RSTD)
|
||||
ctx.n_cols = n_cols
|
||||
return Y
|
||||
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def backward(ctx, dY):
|
||||
X, RSTD = ctx.saved_tensors
|
||||
n_rows = X.shape[0]
|
||||
dX = torch.empty_like(X)
|
||||
_rms_norm_noscale_backward_kernel[(n_rows,)](
|
||||
dY,
|
||||
dY.stride(0),
|
||||
dX,
|
||||
dX.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
torch_to_triton_dtype[X.dtype],
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
ctx.n_cols,
|
||||
BLOCK_SIZE=ctx.BLOCK_SIZE,
|
||||
num_warps=ctx.num_warps,
|
||||
)
|
||||
return dX, None
|
||||
|
||||
|
||||
def fused_rms_norm_noscale(x, eps=1e-6):
|
||||
"""
|
||||
RMSNorm without scale for v_norm.
|
||||
|
||||
Args:
|
||||
x: (batch, seq_len, num_heads, head_dim)
|
||||
Returns:
|
||||
y: same shape, normalized
|
||||
"""
|
||||
shape = x.shape
|
||||
x_flat = x.reshape(-1, shape[-1])
|
||||
y_flat = FusedRMSNormNoScaleFunction.apply(x_flat, eps)
|
||||
return y_flat.view(shape)
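

# Illustrative eager-mode equivalent of fused_rms_norm_noscale (sketch for
# comparison only; not used by the kernel path):
def _eager_rms_norm_noscale_reference(x, eps=1e-6):
    x_fp32 = x.float()
    return (x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)).to(x.dtype)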
|
||||
@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
|
||||
return Q, K, V
|
||||
|
||||
|
||||
class LoRA_QK(torch.autograd.Function):
|
||||
"""Optimized LoRA QK implementation for models where v_proj is None.
|
||||
|
||||
Used by models like Gemma4 with attention_k_eq_v=True, where key states are
|
||||
reused as value states. Only Q and K projections are fused; the caller
|
||||
returns K a second time as V so that autograd accumulates key+value gradients
|
||||
into a single dK.
|
||||
|
||||
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_fwd
|
||||
def forward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
X: torch.Tensor,
|
||||
X_drop: torch.Tensor | None,
|
||||
# Q params
|
||||
q_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor | None,
|
||||
q_quant: QuantState | None,
|
||||
q_A: torch.Tensor | None,
|
||||
q_B: torch.Tensor | None,
|
||||
q_scale: float,
|
||||
q_lora_bias: torch.Tensor | None,
|
||||
q_magnitude: torch.Tensor | None,
|
||||
# K params
|
||||
k_weight: torch.Tensor,
|
||||
k_bias: torch.Tensor | None,
|
||||
k_quant: QuantState | None,
|
||||
k_A: torch.Tensor | None,
|
||||
k_B: torch.Tensor | None,
|
||||
k_scale: float,
|
||||
k_lora_bias: torch.Tensor | None,
|
||||
k_magnitude: torch.Tensor | None,
|
||||
# Flags
|
||||
inplace: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
has_dropout = X_drop is not None
|
||||
has_dora = q_magnitude is not None
|
||||
|
||||
if has_dora:
|
||||
dtype = X.dtype
|
||||
X_lora = X_drop if has_dropout else X
|
||||
|
||||
# Compute Q with DoRA
|
||||
Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
|
||||
Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
|
||||
q_mag_scale = _compute_dora_scale(
|
||||
q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
|
||||
)
|
||||
Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
|
||||
if q_bias is not None:
|
||||
Q = Q + q_bias
|
||||
|
||||
# Compute K with DoRA
|
||||
K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
|
||||
K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
|
||||
k_mag_scale = _compute_dora_scale(
|
||||
k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
|
||||
)
|
||||
K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
|
||||
if k_bias is not None:
|
||||
K = K + k_bias
|
||||
|
||||
Q_combined = Q_base + Q_lora
|
||||
K_combined = K_base + K_lora
|
||||
|
||||
ctx.save_for_backward(
|
||||
X,
|
||||
X_drop if has_dropout else X,
|
||||
q_A.to(dtype) if q_A is not None else q_A,
|
||||
q_B.to(dtype) if q_B is not None else q_B,
|
||||
k_A.to(dtype) if k_A is not None else k_A,
|
||||
k_B.to(dtype) if k_B is not None else k_B,
|
||||
q_magnitude,
|
||||
k_magnitude,
|
||||
q_mag_scale,
|
||||
k_mag_scale,
|
||||
Q_combined,
|
||||
K_combined,
|
||||
q_lora_bias,
|
||||
k_lora_bias,
|
||||
)
|
||||
else:
|
||||
# Standard LoRA (with optional dropout and bias)
|
||||
Q = matmul_lora(
|
||||
X,
|
||||
q_weight,
|
||||
q_bias,
|
||||
q_quant,
|
||||
q_A,
|
||||
q_B,
|
||||
q_scale,
|
||||
X_drop=X_drop,
|
||||
lora_bias=q_lora_bias,
|
||||
)
|
||||
K = matmul_lora(
|
||||
X,
|
||||
k_weight,
|
||||
k_bias,
|
||||
k_quant,
|
||||
k_A,
|
||||
k_B,
|
||||
k_scale,
|
||||
X_drop=X_drop,
|
||||
lora_bias=k_lora_bias,
|
||||
)
|
||||
|
||||
dtype = X.dtype
|
||||
ctx.save_for_backward(
|
||||
X,
|
||||
X_drop if has_dropout else X,
|
||||
q_A.to(dtype) if q_A is not None else q_A,
|
||||
q_B.to(dtype) if q_B is not None else q_B,
|
||||
k_A.to(dtype) if k_A is not None else k_A,
|
||||
k_B.to(dtype) if k_B is not None else k_B,
|
||||
q_lora_bias,
|
||||
k_lora_bias,
|
||||
)
|
||||
|
||||
ctx.scales = (q_scale, k_scale)
|
||||
ctx.quants = (q_quant, k_quant)
|
||||
ctx.weights = (q_weight, k_weight)
|
||||
ctx.inplace = inplace
|
||||
ctx.has_dropout = has_dropout
|
||||
ctx.has_dora = has_dora
|
||||
|
||||
return Q, K
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_bwd
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
q_grad: torch.Tensor,
|
||||
k_grad: torch.Tensor,
|
||||
):
|
||||
q_weight, k_weight = ctx.weights
|
||||
q_quant, k_quant = ctx.quants
|
||||
q_scale, k_scale = ctx.scales
|
||||
has_dropout = ctx.has_dropout
|
||||
has_dora = ctx.has_dora
|
||||
|
||||
if has_dora:
|
||||
(
|
||||
X,
|
||||
X_lora,
|
||||
A_q,
|
||||
B_q,
|
||||
A_k,
|
||||
B_k,
|
||||
q_magnitude,
|
||||
k_magnitude,
|
||||
q_mag_scale,
|
||||
k_mag_scale,
|
||||
Q_combined,
|
||||
K_combined,
|
||||
q_lora_bias,
|
||||
k_lora_bias,
|
||||
) = ctx.saved_tensors
|
||||
else:
|
||||
(
|
||||
X,
|
||||
X_lora,
|
||||
A_q,
|
||||
B_q,
|
||||
A_k,
|
||||
B_k,
|
||||
q_lora_bias,
|
||||
k_lora_bias,
|
||||
) = ctx.saved_tensors
|
||||
q_magnitude = k_magnitude = None
|
||||
q_mag_scale = k_mag_scale = None
|
||||
Q_combined = K_combined = None
|
||||
|
||||
batch, seq_len = X.shape[:2]
|
||||
q_grad = q_grad.view(-1, q_grad.shape[-1])
|
||||
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
|
||||
X = X.view(-1, X.shape[-1])
|
||||
X_lora = X_lora.view(-1, X_lora.shape[-1])
|
||||
|
||||
d_q_mag = d_k_mag = None
|
||||
d_q_lora_bias = d_k_lora_bias = None
|
||||
|
||||
if has_dora:
|
||||
Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
|
||||
K_combined = K_combined.view(-1, K_combined.shape[-1])
|
||||
|
||||
d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
|
||||
d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude
|
||||
|
||||
q_grad = q_grad * q_mag_scale.unsqueeze(0)
|
||||
k_grad = k_grad * k_mag_scale.unsqueeze(0)
|
||||
|
||||
# LoRA bias gradients
|
||||
if q_lora_bias is not None:
|
||||
d_q_lora_bias = q_scale * q_grad.sum(dim=0)
|
||||
if k_lora_bias is not None:
|
||||
d_k_lora_bias = k_scale * k_grad.sum(dim=0)
|
||||
|
||||
X_lora_t = X_lora.t()
|
||||
|
||||
d_A_q = d_B_q = d_A_k = d_B_k = None
|
||||
grad_B_q = grad_B_k = None
|
||||
|
||||
if A_q is not None and B_q is not None:
|
||||
grad_B_q = q_grad @ B_q
|
||||
d_A_q = torch.empty_like(A_q.t())
|
||||
d_B_q = torch.empty_like(B_q.t())
|
||||
d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
|
||||
d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)
|
||||
|
||||
if A_k is not None and B_k is not None:
|
||||
grad_B_k = k_grad @ B_k
|
||||
d_A_k = torch.empty_like(A_k.t())
|
||||
d_B_k = torch.empty_like(B_k.t())
|
||||
d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
|
||||
d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)
|
||||
|
||||
# Base path input gradient
|
||||
out_buffer = X if ctx.inplace else None
|
||||
|
||||
q_weight_t = dequantize(q_weight, q_quant)
|
||||
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
|
||||
del q_weight_t
|
||||
|
||||
k_weight_t = dequantize(k_weight, k_quant)
|
||||
grad_X.addmm_(k_grad, k_weight_t)
|
||||
del k_weight_t
|
||||
|
||||
# LoRA path input gradient
|
||||
if has_dropout:
|
||||
grad_X_drop = torch.zeros_like(X_lora)
|
||||
if grad_B_q is not None:
|
||||
grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
|
||||
if grad_B_k is not None:
|
||||
grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
|
||||
else:
|
||||
grad_X_drop = None
|
||||
if grad_B_q is not None:
|
||||
grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
|
||||
if grad_B_k is not None:
|
||||
grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)
|
||||
|
||||
if d_A_q is not None:
|
||||
d_A_q = d_A_q.t()
|
||||
d_B_q = d_B_q.t() # type: ignore[union-attr]
|
||||
if d_A_k is not None:
|
||||
d_A_k = d_A_k.t()
|
||||
d_B_k = d_B_k.t() # type: ignore[union-attr]
|
||||
|
||||
grad_X = grad_X.view(batch, seq_len, -1)
|
||||
if grad_X_drop is not None:
|
||||
grad_X_drop = grad_X_drop.view(batch, seq_len, -1)
|
||||
|
||||
# Return gradients for all forward inputs:
|
||||
# X, X_drop,
|
||||
# q: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||
# k: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||
# inplace
|
||||
return (
|
||||
grad_X,
|
||||
grad_X_drop,
|
||||
# Q
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_q,
|
||||
d_B_q,
|
||||
None,
|
||||
d_q_lora_bias,
|
||||
d_q_mag,
|
||||
# K
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_k,
|
||||
d_B_k,
|
||||
None,
|
||||
d_k_lora_bias,
|
||||
d_k_mag,
|
||||
# inplace
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_qk(
|
||||
self, X: torch.Tensor, inplace: bool = True
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Applies LoRA to compute Query and Key projections for models where v_proj is None.
|
||||
|
||||
When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
|
||||
value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
|
||||
Because K is returned twice, autograd accumulates gradients from both the key and
|
||||
value paths into dK before calling LoRA_QK.backward.
|
||||
|
||||
Supports bias, dropout, and DoRA.
|
||||
"""
|
||||
QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
|
||||
KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)
|
||||
|
||||
# Apply dropout outside autograd.Function (shared mask for Q, K)
|
||||
X_drop = _apply_dropout(Qdrop, X, self.training)
|
||||
|
||||
Q, K = LoRA_QK.apply(
|
||||
X,
|
||||
X_drop,
|
||||
# Q
|
||||
QW,
|
||||
Qb,
|
||||
QW_quant,
|
||||
QA,
|
||||
QB,
|
||||
QS,
|
||||
Qlb,
|
||||
Qmag,
|
||||
# K
|
||||
KW,
|
||||
Kb,
|
||||
KW_quant,
|
||||
KA,
|
||||
KB,
|
||||
KS,
|
||||
Klb,
|
||||
Kmag,
|
||||
# Flags
|
||||
inplace,
|
||||
)
|
||||
|
||||
return Q, K, K
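

# Illustrative sketch (not used by the patch): because apply_lora_qk returns the
# same K tensor for both the key and value slots, autograd sums the gradients
# flowing back from both uses before they reach LoRA_QK.backward as one k_grad.
# Toy demonstration of that accumulation:
def _shared_kv_grad_accumulation_demo():
    k = torch.randn(3, 4, requires_grad=True)
    loss = (k * 2.0).sum() + (k * 3.0).sum()  # "key path" + "value path"
    loss.backward()
    # d(loss)/dk = 2 + 3 = 5 for every element
    return torch.allclose(k.grad, torch.full_like(k, 5.0))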
|
||||
|
||||
|
||||
class LoRA_O(torch.autograd.Function):
|
||||
"""Optimized LoRA implementation for output projection.
|
||||
|
||||
|
||||
@@ -67,12 +67,165 @@ def find_all_linear_names(model):
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
def _patch_peft_clippable_linear():
|
||||
"""Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.
|
||||
|
||||
Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
|
||||
that clips activations). PEFT doesn't recognise it as a supported layer type,
|
||||
so we redirect LoRA injection to the inner ``.linear`` child instead.
|
||||
"""
|
||||
try:
|
||||
from transformers.models.gemma4.modeling_gemma4 import (
|
||||
Gemma4ClippableLinear as _cls,
|
||||
)
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
from peft.tuners.lora.model import LoraModel
|
||||
|
||||
if getattr(LoraModel, "_axolotl_clippable_patched", False):
|
||||
return
|
||||
_orig = LoraModel._create_and_replace
|
||||
|
||||
def _patched(
|
||||
self,
|
||||
peft_config,
|
||||
adapter_name,
|
||||
target,
|
||||
target_name,
|
||||
parent,
|
||||
current_key=None,
|
||||
**kw,
|
||||
):
|
||||
if isinstance(target, _cls):
|
||||
# Redirect to the inner nn.Linear so PEFT can wrap it normally.
|
||||
return _orig(
|
||||
self,
|
||||
peft_config,
|
||||
adapter_name,
|
||||
target.linear,
|
||||
"linear",
|
||||
target,
|
||||
current_key=current_key,
|
||||
**kw,
|
||||
)
|
||||
return _orig(
|
||||
self,
|
||||
peft_config,
|
||||
adapter_name,
|
||||
target,
|
||||
target_name,
|
||||
parent,
|
||||
current_key=current_key,
|
||||
**kw,
|
||||
)
|
||||
|
||||
LoraModel._create_and_replace = _patched
|
||||
LoraModel._axolotl_clippable_patched = True
|
||||
|
||||
|
||||
def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
|
||||
"""Check whether PEFT will auto-populate target_parameters for this model.
|
||||
|
||||
PEFT 0.19's ``convert_peft_config_for_transformers`` rewrites old MoE
|
||||
``target_modules`` (e.g. ``w1``/``w2``/``w3`` on Mixtral) into
|
||||
``target_parameters`` (``gate_up_proj``/``down_proj``) because
|
||||
transformers v5 fused those expert linears into 3D ``nn.Parameter``
|
||||
tensors. PEFT wraps the resulting 3D params with ``ParamWrapper``,
|
||||
which rejects ``lora_dropout != 0``. This probe runs the conversion on
|
||||
a copy of the config so we can detect the situation before
|
||||
``get_peft_model`` blows up.
|
||||
"""
|
||||
if getattr(lora_config, "target_parameters", None):
|
||||
return False
|
||||
|
||||
try:
|
||||
from peft.utils.transformers_weight_conversion import (
|
||||
convert_peft_config_for_transformers,
|
||||
get_model_conversion_mapping,
|
||||
)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
import copy
|
||||
|
||||
probe_cfg = copy.deepcopy(lora_config)
|
||||
try:
|
||||
convert_peft_config_for_transformers(
|
||||
probe_cfg,
|
||||
model=model,
|
||||
conversions=get_model_conversion_mapping(model),
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return False
|
||||
|
||||
return bool(getattr(probe_cfg, "target_parameters", None))
|
||||
|
||||
|
||||
def _patch_peft_param_wrapper_dropout():
|
||||
"""Let PEFT's ``ParamWrapper`` silently accept ``lora_dropout != 0``.
|
||||
|
||||
``ParamWrapper`` wraps 3D expert ``nn.Parameter`` tensors and rejects
|
||||
non-zero dropout because dropout can't be factored out of
|
||||
``lora_B(lora_A(dropout(x)))`` when the inner op is an expert-indexed
|
||||
matmul. For mixed configs (attention + MoE experts) this is too
|
||||
aggressive — the non-expert ``Linear`` LoRA layers *can* apply dropout
|
||||
and that's usually what the user intended. We pass a copy of the
|
||||
``LoraConfig`` with ``lora_dropout=0`` only to ``ParamWrapper.__init__``
|
||||
so it builds with ``nn.Identity`` for its internal dropout slot while
|
||||
every other layer type still receives the real dropout value.
|
||||
"""
|
||||
from peft.tuners.lora.layer import ParamWrapper
|
||||
|
||||
if getattr(ParamWrapper, "_axolotl_dropout_patched", False):
|
||||
return
|
||||
|
||||
_orig_init = ParamWrapper.__init__
|
||||
|
||||
def _patched_init(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name,
|
||||
parameter_name,
|
||||
config,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if getattr(config, "lora_dropout", 0):
|
||||
import copy as _copy
|
||||
|
||||
patched_config = _copy.copy(config)
|
||||
patched_config.lora_dropout = 0.0
|
||||
return _orig_init(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name,
|
||||
parameter_name,
|
||||
patched_config,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
return _orig_init(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name,
|
||||
parameter_name,
|
||||
config,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
ParamWrapper.__init__ = _patched_init
|
||||
ParamWrapper._axolotl_dropout_patched = True
|
||||
|
||||
|
||||
def load_lora(
|
||||
model: PreTrainedModel,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
config_only: bool = False,
|
||||
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
|
||||
_patch_peft_clippable_linear()
|
||||
lora_target_modules = cfg.lora_target_modules or []
|
||||
lora_target_parameters = cfg.lora_target_parameters or []
|
||||
|
||||
@@ -124,6 +277,7 @@ def load_lora(
|
||||
lora_dropout=cfg.lora_dropout,
|
||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
|
||||
bias="none",
|
||||
task_type=task_type,
|
||||
**lora_config_kwargs,
|
||||
@@ -132,6 +286,20 @@ def load_lora(
|
||||
if config_only:
|
||||
return None, lora_config
|
||||
|
||||
if getattr(
|
||||
lora_config, "lora_dropout", 0
|
||||
) and _peft_will_auto_convert_target_params(model, lora_config):
|
||||
LOG.warning(
|
||||
"lora_dropout=%s requested but PEFT will wrap this model's fused "
|
||||
"MoE expert parameters with ParamWrapper, which cannot apply "
|
||||
"dropout (the 3D einsum can't factor dropout out of "
|
||||
"lora_B(lora_A(dropout(x)))). Dropout will still be applied to "
|
||||
"non-expert LoRA layers (e.g. attention), and expert LoRA layers "
|
||||
"will use nn.Identity for the dropout slot.",
|
||||
lora_config.lora_dropout,
|
||||
)
|
||||
_patch_peft_param_wrapper_dropout()
|
||||
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if (
|
||||
|
||||
@@ -547,6 +547,16 @@ class ModelLoader:
|
||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||
|
||||
if self.cfg.model_quantization_config == "FineGrainedFP8Config":
|
||||
from transformers import FineGrainedFP8Config
|
||||
|
||||
fp8_kwargs = {}
|
||||
if self.cfg.model_quantization_config_kwargs:
|
||||
fp8_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = FineGrainedFP8Config(
|
||||
**fp8_kwargs
|
||||
)
|
||||
|
||||
if self.cfg.gptq:
|
||||
if not hasattr(self.model_config, "quantization_config"):
|
||||
LOG.warning(
|
||||
@@ -624,7 +634,14 @@ class ModelLoader:
|
||||
|
||||
def _set_attention_config(self):
|
||||
"""Sample packing uses custom FA2 patch"""
|
||||
if self.cfg.attn_implementation:
|
||||
if self.cfg.gemma4_hybrid_attn_impl:
|
||||
# Load model with flash_attention_2 for sliding window layers;
|
||||
# global layers will be patched to sdpa post-load.
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = "flash_attention_2"
|
||||
# Set flash_attention so multipack/sample_packing patches activate
|
||||
self.cfg.flash_attention = True
|
||||
elif self.cfg.attn_implementation:
|
||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||
elif self.cfg.flex_attention:
|
||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||
|
||||
@@ -156,15 +156,81 @@ class PatchManager:
|
||||
# which would clobber any earlier fix.
|
||||
self._fix_nemotron_h_conversion_mapping()
|
||||
|
||||
self._apply_gemma_hybrid_attention(model)
|
||||
self._finalize_moe_expert_quantization(model)
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
self._apply_unsloth_patches(model)
|
||||
self._apply_lora_kernel_patch(model)
|
||||
self._apply_scaling_softmax_patch(model)
|
||||
|
||||
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
|
||||
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
|
||||
|
||||
Gemma 4 has global (full_attention) layers with head_dim=512
|
||||
which exceeds flash attention's supported size. This patch loads the model
|
||||
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
||||
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
||||
"""
|
||||
if not self.cfg.gemma4_hybrid_attn_impl:
|
||||
return
|
||||
|
||||
import copy
|
||||
|
||||
# Navigate to the module that has 'layers' - varies by model structure:
|
||||
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
||||
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
|
||||
layers = None
|
||||
config_source = None
|
||||
for candidate in [model, getattr(model, "model", None)]:
|
||||
if candidate is None:
|
||||
continue
|
||||
# Check direct layers
|
||||
if hasattr(candidate, "layers"):
|
||||
layers = candidate.layers
|
||||
config_source = candidate
|
||||
break
|
||||
# Check language_model.layers (multimodal wrapper)
|
||||
lang_model = getattr(candidate, "language_model", None)
|
||||
if lang_model is not None and hasattr(lang_model, "layers"):
|
||||
layers = lang_model.layers
|
||||
config_source = lang_model
|
||||
break
|
||||
|
||||
if layers is None:
|
||||
LOG.warning(
|
||||
"gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
config = getattr(config_source, "config", self.model_config)
|
||||
layer_types = getattr(config, "layer_types", None)
|
||||
if layer_types is None:
|
||||
LOG.warning(
|
||||
"gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
|
||||
"This feature requires a model with mixed sliding/global attention layers."
|
||||
)
|
||||
return
|
||||
|
||||
patched_count = 0
|
||||
for layer_idx, layer in enumerate(layers):
|
||||
if layer_types[layer_idx] != "sliding_attention":
|
||||
# Global / full_attention layer - use SDPA instead of FA2
|
||||
attn_module = getattr(layer, "self_attn", None)
|
||||
if attn_module is not None and hasattr(attn_module, "config"):
|
||||
sdpa_config = copy.copy(attn_module.config)
|
||||
sdpa_config._attn_implementation = "sdpa"
|
||||
attn_module.config = sdpa_config
|
||||
patched_count += 1
|
||||
|
||||
LOG.info(
|
||||
"gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
|
||||
"(remaining %d sliding layers use flash_attention_2)",
|
||||
patched_count,
|
||||
len(layers) - patched_count,
|
||||
)
|
||||
|
||||
def _apply_flash_attention_patches(self):
|
||||
"""Apply patches related to Flash Attention."""
|
||||
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||
@@ -324,6 +390,22 @@ class PatchManager:
|
||||
|
||||
patch_qwen3_5_vlm_flash_attention()
|
||||
|
||||
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
# Shared-KV side channel when activation checkpointing (PR #3611).
|
||||
fsdp_cfg = self.cfg.fsdp_config
|
||||
needs_shared_kv_workaround = (not self.inference) and bool(
|
||||
self.cfg.gradient_checkpointing
|
||||
or self.cfg.activation_offloading
|
||||
or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
|
||||
)
|
||||
patch_gemma4_fused_attn(
|
||||
install_shared_kv_workaround=needs_shared_kv_workaround
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fix_nemotron_h_conversion_mapping():
|
||||
"""Remove the spurious embedding→embeddings WeightRenaming from the
|
||||
@@ -600,24 +682,10 @@ class PatchManager:
|
||||
)
|
||||
|
||||
patch_fa_llama_cross_entropy()
|
||||
elif self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
|
||||
|
||||
patch_llama_rms_norm()
|
||||
elif self.cfg.unsloth_rms_norm:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||
|
||||
patch_unsloth_layernorm()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def _patch_llama_flash_attention(self):
|
||||
"""Apply Flash Attention patches for LLaMA models."""
|
||||
@@ -684,23 +752,6 @@ class PatchManager:
|
||||
LOG.info("Patching with SwiGLU...")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
def _apply_unsloth_patches(self, model):
|
||||
"""Apply unsloth optimization patches."""
|
||||
if self.cfg.unsloth_lora_mlp:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
|
||||
|
||||
integrate_lora_mlp_patch(peft_model=model)
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
|
||||
|
||||
integrate_lora_patch(peft_model=model, cfg=self.cfg)
|
||||
|
||||
if self.cfg.unsloth_rope:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
|
||||
|
||||
integrate_rope_embeddings()
|
||||
|
||||
def _apply_lora_kernel_patch(self, model):
|
||||
"""Apply LoRA kernel patches."""
|
||||
if (
|
||||
|
||||
@@ -221,14 +221,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
||||
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
LOG.warning(
|
||||
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
||||
tokenizer.eos_token,
|
||||
)
|
||||
|
||||
additional_special_tokens = None
|
||||
if cfg.special_tokens:
|
||||
special_tokens = cfg.special_tokens.to_dict()
|
||||
@@ -303,6 +295,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
{"additional_special_tokens": additional_special_tokens}
|
||||
)
|
||||
|
||||
# Generic fallback: if tokenizer still has no pad_token, use eos_token
|
||||
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
LOG.warning(
|
||||
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
|
||||
tokenizer.eos_token,
|
||||
)
|
||||
|
||||
if is_main_process():
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
|
||||
@@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict(
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
# Clone the local shard to allow full_tensor to be freed.
|
||||
if (
|
||||
sharded_param._local_tensor.untyped_storage().size()
|
||||
> sharded_param._local_tensor.nelement()
|
||||
* sharded_param._local_tensor.element_size()
|
||||
):
|
||||
sharded_param = sharded_param.clone()
|
||||
else:
|
||||
# Non-sharded parameters
|
||||
if _accelerator.is_main_process:
|
||||
|
||||
@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
|
||||
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
|
||||
return
|
||||
|
||||
try:
|
||||
# flash-attn-4>=4.0.0b7
|
||||
from flash_attn.cute import flash_attn_with_kvcache
|
||||
except ImportError:
|
||||
flash_attn_with_kvcache = None
|
||||
|
||||
def _patched_lazy_imports(
|
||||
implementation, attention_wrapper=None, allow_all_kernels=False
|
||||
):
|
||||
return (
|
||||
flash_attn_func,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
fa_utils._pad_input,
|
||||
fa_utils._unpad_input,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
apply_lora_qk,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
@@ -111,6 +112,47 @@ QKV_PATCHES = [
|
||||
else:
|
||||
key_states = key_states.view(hidden_shape)
|
||||
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
# Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
|
||||
# past_key_values.shared_layers, and v_norm added after k_norm.
|
||||
(
|
||||
"""
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape)
|
||||
query_states = self.q_norm(query_states)
|
||||
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
|
||||
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
|
||||
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
|
||||
if self.is_kv_shared_layer:
|
||||
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = key_states.to(query_states.device)
|
||||
value_states = value_states.to(query_states.device)
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
|
||||
""".lstrip("\n"),
|
||||
"""
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states = query_states.view(hidden_shape)
|
||||
query_states = self.q_norm(query_states)
|
||||
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
|
||||
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
|
||||
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
|
||||
if self.is_kv_shared_layer:
|
||||
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = key_states.to(query_states.device)
|
||||
value_states = value_states.to(query_states.device)
|
||||
else:
|
||||
key_states = key_states.view(hidden_shape)
|
||||
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
]
|
||||
@@ -483,18 +525,24 @@ def apply_lora_kernel_patches(
|
||||
if cfg.lora_qkv_kernel:
|
||||
# Query, key, value patching
|
||||
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
|
||||
proj_names = ["q_proj", "k_proj", "v_proj"]
|
||||
layer_modules = [
|
||||
getattr(self_attn, name)
|
||||
for name in proj_names
|
||||
if getattr(self_attn, name, None) is not None
|
||||
]
|
||||
has_v_proj = getattr(self_attn, "v_proj", None) is not None
|
||||
proj_names = (
|
||||
["q_proj", "k_proj", "v_proj"]
|
||||
if has_v_proj
|
||||
else ["q_proj", "k_proj"]
|
||||
)
|
||||
layer_modules = [getattr(self_attn, name) for name in proj_names]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A") for module in layer_modules
|
||||
)
|
||||
|
||||
if can_patch_qkv:
|
||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||
if has_v_proj:
|
||||
self_attn.apply_qkv = types.MethodType(
|
||||
apply_lora_qkv, self_attn
|
||||
)
|
||||
else:
|
||||
self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
||||
|
||||
194 src/axolotl/monkeypatch/models/gemma4/fused_attn.py Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Gemma 4 fused attention monkeypatch.
|
||||
|
||||
Replaces the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
kernels, eliminating intermediate tensor allocations from rotate_half / apply_rotary_pos_emb.
|
||||
|
||||
Usage:
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
||||
# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
|
||||
patch_gemma4_fused_attn(install_shared_kv_workaround=True)
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Module-level dict used as a side channel for shared KV states, avoiding a kwarg or
# thread-local storage, to prevent a memory leak in gradient-checkpointed training (PR #3611)
|
||||
_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
|
||||
|
||||
|
def _set_shared_kv_states(store):
    _GEMMA4_SHARED_KV_STORE["store"] = store


def _get_shared_kv_states():
    return _GEMMA4_SHARED_KV_STORE["store"]


def _make_fused_forward(original_forward):
    """Create a patched forward that uses fused RMSNorm+RoPE kernels."""

    from axolotl.kernels.gemma4_fused_rope import (
        fused_rms_norm_noscale,
        fused_rms_norm_rope,
    )

    def fused_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
        past_key_values=None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
        from transformers.models.gemma4.modeling_gemma4 import (
            eager_attention_forward,
        )

        store = _get_shared_kv_states()
        if store is not None:
            shared_kv_states = store

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        eps = self.config.rms_norm_eps

        cos, sin = position_embeddings

        # ---- Projections ----
        # Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
        has_lora_qkv = hasattr(self, "apply_qkv")

        if has_lora_qkv:
            query_states, key_states, value_states = self.apply_qkv(hidden_states)
            query_states = query_states.view(hidden_shape)
        else:
            query_states = self.q_proj(hidden_states).view(hidden_shape)

        # ---- Q path: fused q_norm + RoPE ----
        query_states = fused_rms_norm_rope(
            query_states,
            self.q_norm.weight,
            cos,
            sin,
            eps=eps,
        )
        query_states = query_states.transpose(1, 2)

        # ---- K/V path ----
        if self.is_kv_shared_layer:
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            if has_lora_qkv:
                # apply_qkv already computed k/v projections
                key_states = key_states.view(hidden_shape)
                value_states = (
                    value_states.view(hidden_shape)
                    if self.v_proj is not None
                    else key_states
                )
            else:
                key_states = self.k_proj(hidden_states).view(hidden_shape)
                value_states = (
                    self.v_proj(hidden_states).view(hidden_shape)
                    if self.v_proj is not None
                    else key_states
                )

            # Fused k_norm + RoPE
            key_states = fused_rms_norm_rope(
                key_states,
                self.k_norm.weight,
                cos,
                sin,
                eps=eps,
            )
            key_states = key_states.transpose(1, 2)

            # Fused v_norm (no scale, no RoPE)
            value_states = fused_rms_norm_noscale(value_states, eps=eps)
            value_states = value_states.transpose(1, 2)

        if past_key_values is not None and not self.is_kv_shared_layer:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx
            )
        if self.store_full_length_kv:
            shared_kv_states[self.layer_idx] = key_states, value_states

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    return fused_forward


def _patch_decoder_layer_call():
    """Strip `shared_kv_states` from decoder-layer kwargs and route via the
    module-level side channel so the checkpoint partial cannot pin it (PR #3611).
    """
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer

    if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
        return

    original_call = Gemma4TextDecoderLayer.__call__

    def patched_call(self, *args, **kwargs):
        shared_kv = kwargs.pop("shared_kv_states", None)
        # Overwrite unconditionally (including with None) so a previous step's
        # dict cannot leak into a later call without shared_kv_states (PR #3611).
        _set_shared_kv_states(shared_kv)
        return original_call(self, *args, **kwargs)

    Gemma4TextDecoderLayer.__call__ = patched_call
    Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True


def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
    """
    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
    and optionally route `shared_kv_states` via a module-level side channel to
    avoid a VRAM leak under activation checkpointing (PR #3611).
    """
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention

    original_forward = Gemma4TextAttention.forward
    Gemma4TextAttention.forward = _make_fused_forward(original_forward)

    if install_shared_kv_workaround:
        _patch_decoder_layer_call()

    logger.info(
        "Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
    )
    if install_shared_kv_workaround:
        logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")
@@ -1,252 +0,0 @@
"""module for patching with unsloth optimizations"""

import inspect
import types

import torch
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

ORIGINAL_QKV_CODE = """
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
""".lstrip("\n")

PATCHED_QKV_CODE = """
    query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
""".lstrip("\n")

ORIGINAL_O_CODE = """
    attn_output = self.o_proj(attn_output)
""".lstrip("\n")

PATCHED_O_CODE = """
    attn_output = self.apply_o(self, attn_output)
""".lstrip("\n")


def original_apply_qkv(self, hidden_states):
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    return query_states, key_states, value_states


def original_apply_o(self, hidden_states):
    attn_output = self.o_proj(hidden_states)
    return attn_output


def get_self_attn_code() -> str:
    forward = inspect.getsource(LlamaFlashAttention2.forward)
    return forward


def check_self_attn_is_patchable() -> bool:
    qkv = get_self_attn_code()
    qkv, _ = detab_code(qkv)
    return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv


def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
    from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss

    def UnslothForCausalLMLoss(
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: int = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss = fast_cross_entropy_loss(
            logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
        )
        return loss

    if model_type == "llama":
        from transformers.loss import loss_utils

        loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss  # type: ignore[assignment]
    else:
        raise ValueError("Unsupported model type")


self_attn_lora_patched = False


def patch_self_attn_lora():
    global self_attn_lora_patched
    if self_attn_lora_patched:
        # prevent patching multiple times
        return
    self_attn_forward = get_self_attn_code()
    LlamaFlashAttention2._original_forward = self_attn_forward
    self_attn_forward, _ = detab_code(self_attn_forward)
    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
    assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"

    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
    self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
    self_attn_forward = self_attn_forward.replace(
        "def forward(",
        "def unsloth_attn_forward(",
        1,
    )

    # load imports necessary
    import transformers.models.llama.modeling_llama

    items_to_import = []
    for item in dir(transformers.models.llama.modeling_llama):
        if item in self_attn_forward:
            items_to_import.append(item)

    exec(
        "from transformers.models.llama.modeling_llama import ("
        + ", ".join(x for x in items_to_import)
        + ")",
        globals(),
    )
    exec(self_attn_forward, globals())
    self_attn_lora_patched = True
    LOG.info("patching unsloth attn lora")
    LlamaFlashAttention2.forward = unsloth_attn_forward


def integrate_rope_embeddings():
    import transformers.models.llama.modeling_llama
    from unsloth.kernels.rope_embedding import fast_rope_embedding

    def apply_rotary_pos_emb(
        q,
        k,
        cos,
        sin,
        position_ids=None,
        unsqueeze_dim=1,
    ):
        return fast_rope_embedding(q, k, cos, sin)

    LOG.info("patching unsloth RoPE embeddings")
    transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb


def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
    if peft_model.base_model.config.model_type in ["llama", "mistral"]:
        from unsloth.kernels import apply_lora_mlp_swiglu

        apply_lora_mlp = apply_lora_mlp_swiglu
    elif peft_model.base_model.config.model_type == "gemma":
        from unsloth.kernels import apply_lora_mlp_geglu_approx

        apply_lora_mlp = apply_lora_mlp_geglu_approx
    else:
        raise NotImplementedError(
            f"Model type {peft_model.base_model.config.model_type} not supported"
        )

    for idx, layer in enumerate(peft_model.model.model.layers):
        layer_modules = [
            getattr(layer.mlp, linear_proj)
            for linear_proj in ["gate_proj", "up_proj", "down_proj"]
        ]
        is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
        mlp_no_bias = all(
            getattr(module, "base_layer", module).bias is None
            for module in layer_modules
        )
        mlp_not_dora = all(
            len(getattr(module, "lora_magnitude_vector", []) or []) == 0
            for module in layer_modules
        )

        if is_mlp_lora and mlp_no_bias and mlp_not_dora:
            layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
        else:
            LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")


def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
    from unsloth.kernels import apply_lora_o, apply_lora_qkv

    for idx, layer in enumerate(peft_model.model.model.layers):
        if cfg.unsloth_lora_qkv:
            layer_modules = [
                getattr(layer.self_attn, linear_proj)
                for linear_proj in ["q_proj", "k_proj", "v_proj"]
            ]
            is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
            qkv_no_bias = all(
                getattr(module, "base_layer", module).bias is None
                for module in layer_modules
            )
            qkv_not_dora = all(
                len(getattr(module, "lora_magnitude_vector", []) or []) == 0
                for module in layer_modules
            )

            if is_qkv_lora and qkv_no_bias and qkv_not_dora:
                layer.self_attn.apply_qkv = apply_lora_qkv
            else:
                layer.self_attn.apply_qkv = original_apply_qkv
                LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
        if cfg.unsloth_lora_o:
            layer_modules = [
                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
            ]
            is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
            o_no_bias = all(
                getattr(module, "base_layer", module).bias is None
                for module in layer_modules
            )
            o_not_dora = all(
                len(getattr(module, "lora_magnitude_vector", []) or []) == 0
                for module in layer_modules
            )

            if is_o_lora and o_no_bias and o_not_dora:
                layer.self_attn.apply_o = apply_lora_o
            else:
                layer.self_attn.apply_o = original_apply_o
                LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")


def patch_unsloth_layernorm():
    try:
        import transformers.models.llama.modeling_llama
        from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm

        class LlamaRMSNorm(nn.Module):
            """LlamaRMSNorm"""

            def __init__(self, hidden_size, eps=1e-6):
                """
                LlamaRMSNorm is equivalent to T5LayerNorm
                """
                super().__init__()
                self.weight = nn.Parameter(torch.ones(hidden_size))
                self.variance_epsilon = eps

            def forward(self, hidden_states):
                return Fast_RMS_Layernorm.apply(
                    hidden_states, self.weight, self.variance_epsilon, False
                )

        LOG.info("patching with unsloth.kernels.rms_layernorm")
        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
    except ImportError:
        LOG.warning("missing unsloth library")
@@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

        self._validate_eot_and_eos_tokens()

        # Pre-cache EOT token IDs to avoid re-encoding on every call
        self._eot_token_ids = set()
        for token in self.eot_tokens:
            token_ids = self.tokenizer.encode(token, add_special_tokens=False)
            if len(token_ids) == 1:
                self._eot_token_ids.add(token_ids[0])

    def _validate_eot_and_eos_tokens(self):
        """
        - Validates that EOT tokens (or eos_token) are in the chat_template
@@ -471,6 +478,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            content = turn.get("content")
            train_turn = turn.get("training")
            train_detail = turn.get("training_detail")
            reasoning_train_detail = turn.get("reasoning_training_detail")

            LOG.debug(
                f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
@@ -479,8 +487,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            should_train = None
            if train_turn is not None:
                should_train = train_turn
            elif train_detail is not None:
                should_train = bool(train_detail)
            elif train_detail is not None or reasoning_train_detail is not None:
                should_train = bool(train_detail) or bool(reasoning_train_detail)
            else:
                should_train = self.train_on_inputs or role in self.roles_to_train

@@ -500,15 +508,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

                continue

            thinking_key = self.prompter.template_thinking_key
            has_reasoning = thinking_key and turn.get(thinking_key) is not None
            has_any_detail = train_detail or reasoning_train_detail

            # When train_detail is present and the turn has reasoning_content,
            # use content_only=True so find_turn returns content-only boundaries
            # (excluding reasoning_content + template separator tokens).
            use_content_only = bool(has_any_detail and has_reasoning)

            turn_start_idx, turn_end_idx = self.find_turn(
                turns=turns, turn_idx=index, tools=tools
                turns=turns,
                turn_idx=index,
                tools=tools,
                content_only=use_content_only,
            )

            LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

            if should_train and turn_start_idx != -1 and turn_end_idx != -1:
                if train_detail:
                    # Block multi-content for now
                    if not isinstance(content, str):
                        raise ValueError(
                            "`train_detail` is not supported when `content` is not a string."
@@ -526,7 +545,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                        LOG.debug(
                            f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
                        )
                else:
                elif not reasoning_train_detail:
                    # No per-part detail on either field — train the whole span
                    labels[turn_start_idx:turn_end_idx] = input_ids[
                        turn_start_idx:turn_end_idx
                    ]
@@ -534,6 +554,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                        f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
                    )

            # Handle reasoning_content training_detail separately
            if should_train and reasoning_train_detail and has_reasoning:
                reasoning_text = turn[thinking_key]
                if not isinstance(reasoning_text, str):
                    raise ValueError(
                        "`reasoning_training_detail` is not supported when reasoning_content is not a string."
                    )

                reasoning_start, reasoning_end = self.find_turn(
                    turns=turns,
                    turn_idx=index,
                    tools=tools,
                    reasoning_only=True,
                )

                if reasoning_start != -1 and reasoning_end != -1:
                    token_offsets = self.prompter.get_offsets_for_train_detail(  # type: ignore
                        reasoning_text, reasoning_train_detail
                    )
                    LOG.debug(f"Reasoning token offsets: {token_offsets}")
                    for i, offset in enumerate(token_offsets):
                        if offset != IGNORE_TOKEN_ID and reasoning_start + i < len(
                            input_ids
                        ):
                            labels[reasoning_start + i] = input_ids[reasoning_start + i]

            LOG.debug(f"Labels after processing turn {index}: {labels}")

        # Handle special tokens (EOT and EOS)
@@ -593,28 +639,31 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

    def find_first_eot_token(self, input_ids, start_idx):
        """Find the first EOT token in the input_ids starting from start_idx."""
        # Get token IDs for all EOT tokens
        eot_token_ids = []
        for token in self.eot_tokens:
            token_ids = self.tokenizer.encode(token, add_special_tokens=False)
            if len(token_ids) != 1:
                raise ValueError(
                    f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
                )

            eot_token_ids.append(token_ids[0])  # Use the last token ID if multiple

        # Search for any of the EOT token IDs
        # Use pre-cached EOT token IDs (computed once in __init__)
        for i in range(start_idx, len(input_ids)):
            if input_ids[i] in eot_token_ids:
            if input_ids[i] in self._eot_token_ids:
                return i
        return -1

    def find_turn(
        self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
        self,
        turns: list[dict],
        turn_idx: int,
        tools: list[dict] | None = None,
        content_only: bool = False,
        reasoning_only: bool = False,
    ):
        """
        Locate the starting and ending indices of the specified turn in a conversation.

        Args:
            content_only: If True and the turn has reasoning_content (template_thinking_key),
                preserve reasoning_content in the dummy turn so the diff only captures the
                content field boundaries. This is needed for correct training_detail alignment
                when reasoning_content is present.
            reasoning_only: If True, preserve content in the dummy turn and replace
                reasoning_content with a dummy, so the diff only captures the
                reasoning_content field boundaries.
        """

        if turn_idx >= len(turns):
@@ -628,10 +677,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        ):
            return -1, -1

        empty_turn = {
            "role": turns[turn_idx].get("role"),
            "content": "[[dummy_message]]",
        }
        thinking_key = self.prompter.template_thinking_key

        if reasoning_only:
            # Keep content as-is, replace reasoning with dummy
            empty_turn = {
                "role": turns[turn_idx].get("role"),
                "content": turns[turn_idx].get("content", ""),
            }
            if thinking_key and thinking_key in turns[turn_idx]:
                empty_turn[thinking_key] = "[[dummy_reasoning]]"
        else:
            empty_turn = {
                "role": turns[turn_idx].get("role"),
                "content": "[[dummy_message]]",
            }

            # When content_only is True, copy reasoning_content to the dummy turn so
            # the diff only captures the content field (not reasoning + separator).
            if content_only and thinking_key and thinking_key in turns[turn_idx]:
                empty_turn[thinking_key] = turns[turn_idx][thinking_key]

        # Create conversation versions
        turns_with_empty = turns[:turn_idx] + [empty_turn]
@@ -697,6 +762,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

        return start_idx, end_idx

    @staticmethod
    def _convert_content_parts(
        content,
    ) -> tuple[str, list[dict] | None] | None:
        """Convert list content to concatenated string + optional training_detail.

        When content is a list of dicts (content parts), each part can specify:
        - ``text``, ``content``, or ``value``: the text string
        - ``train`` (bool) or ``weight`` (0/1): per-part training flag

        Returns ``(concatenated_text, training_details_or_None)`` if content was
        a list, or ``None`` if content was not a list (no conversion needed).

        .. note::
            **Whitespace at part boundaries matters.** BPE tokenizers prepend
            spaces to word tokens (e.g. ``" answer"`` is one token). Always
            split BEFORE spaces::

                GOOD: ["Let me think...", " The answer is 4."]
                BAD:  ["Let me think... ", "The answer is 4."]

            Tokens that straddle a boundary are conservatively masked.
            Newlines typically merge with preceding punctuation (``":\\n"`` is
            one token), so keep newlines with the preceding part.
        """
        if not isinstance(content, list):
            return None

        text_parts: list[str] = []
        training_details: list[dict] = []
        has_explicit_training = False
        offset = 0

        for part in content:
            if isinstance(part, dict):
                # Extract text (HF uses "text", also support "content"/"value")
                text = (
                    part.get("text") or part.get("content") or part.get("value") or ""
                )
                text_parts.append(text)

                # Check for per-part training flags
                part_train = part.get("train")
                part_weight = part.get("weight")
                if part_train is not None or part_weight is not None:
                    has_explicit_training = True
                    train = (
                        part_train
                        if part_train is not None
                        else (part_weight not in (0, 0.0))
                    )
                else:
                    train = True  # default trainable, gated by turn-level should_train

                if text:
                    training_details.append(
                        {
                            "begin_offset": offset,
                            "end_offset": offset + len(text) - 1,
                            "train": train,
                        }
                    )
                offset += len(text)

        # Warn about trailing whitespace at boundaries between parts with
        # different training flags — this almost always causes token straddling
        if has_explicit_training and len(training_details) > 1:
            for i in range(len(training_details) - 1):
                cur = training_details[i]
                nxt = training_details[i + 1]
                if cur["train"] != nxt["train"]:
                    boundary_text = text_parts[i]
                    if boundary_text and boundary_text[-1] in (" ", "\t"):
                        LOG.warning(
                            "Content part %d ends with whitespace at a train/mask boundary. "
                            "BPE tokenizers typically prepend spaces to word tokens, so "
                            "the space will merge with the next part's first word and the "
                            "resulting token will be MASKED (not trained). Move the "
                            "whitespace to the start of the next content part instead. "
                            "Part text: %r",
                            i,
                            boundary_text[-20:],
                        )

        concatenated = "".join(text_parts)
        details = training_details if has_explicit_training else None
        return concatenated, details

    def get_conversation_thread(self, prompt):
        turns = []

@@ -723,6 +876,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            if training_detail is not None:
                turn["training_detail"] = training_detail

            # Convert list content/reasoning_content to string + auto-generated
            # training_detail. See _convert_content_parts for whitespace guidance.
            content_result = self._convert_content_parts(turn.get("content"))
            if content_result is not None:
                turn["content"] = content_result[0]
                if content_result[1] is not None:
                    turn["training_detail"] = content_result[1]

            # Also convert reasoning_content (template_thinking_key) if it's a list
            thinking_key = self.prompter.template_thinking_key
            if thinking_key and thinking_key in turn:
                reasoning_result = self._convert_content_parts(turn[thinking_key])
                if reasoning_result is not None:
                    turn[thinking_key] = reasoning_result[0]
                    if reasoning_result[1] is not None:
                        turn["reasoning_training_detail"] = reasoning_result[1]

            turns.append(turn)

        if self.prompter.drop_system_message and turns[0]["role"] == "system":

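As a usage illustration for `_convert_content_parts` and the per-part training flags above (the sample turn below is hypothetical data, not from any dataset):

```python
# Hypothetical assistant turn using list-style content parts. The first part is
# masked (train=False), the second is trained; note the leading space on the
# second part so no token straddles the train/mask boundary.
turn_content = [
    {"text": "Let me work through this step by step.", "train": False},
    {"text": " The answer is 4.", "train": True},
]

text, details = ChatTemplateStrategy._convert_content_parts(turn_content)
# text    -> "Let me work through this step by step. The answer is 4."
# details -> [{"begin_offset": 0, "end_offset": 38, "train": False},
#             {"begin_offset": 39, "end_offset": 55, "train": True}]
```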
@@ -160,29 +160,16 @@ class TelemetryManager:
        if not is_main_process():
            return False

        # Parse relevant env vars
        axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
        do_not_track = os.getenv("DO_NOT_TRACK")
        def is_truthy_env(var_name: str) -> bool:
            value = os.getenv(var_name)
            if value is None:
                return False
            return value.strip().lower() in ("1", "true")

        # Default to enabled (opt-out model)
        if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
            "0",
            "1",
            "false",
            "true",
        ):
            return True

        if do_not_track is None:
            do_not_track = "0"

        # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
        enabled = axolotl_do_not_track.lower() not in (
            "1",
            "true",
        ) and do_not_track.lower() not in ("1", "true")

        return enabled
        # Telemetry is enabled by default unless either opt-out var is set
        return not (
            is_truthy_env("AXOLOTL_DO_NOT_TRACK") or is_truthy_env("DO_NOT_TRACK")
        )

    def _load_whitelist(self) -> dict:
        """Load HuggingFace Hub organization whitelist"""

@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
|
||||
):
|
||||
model.enable_input_require_grads()
|
||||
|
||||
# Freeze multimodal modules for text-only training of multimodal models
|
||||
if cfg.freeze_mm_modules:
|
||||
freeze_mm_modules(model)
|
||||
|
||||
return model, tokenizer, peft_config, processor
|
||||
|
||||
|
||||
@@ -225,6 +229,28 @@ def execute_training(
|
||||
PLUGIN_MANAGER.post_train(cfg, trainer.model)
|
||||
|
||||
|
||||
def _rename_fsdp_merged_to_adapter(merged_dir: Path):
|
||||
"""Rename model*.safetensors files to adapter_model* in place.
|
||||
|
||||
Also rewrites the index JSON weight_map if sharded output was produced.
|
||||
"""
|
||||
for file in sorted(merged_dir.iterdir()):
|
||||
if file.name.startswith("model") and ".safetensors" in file.name:
|
||||
file.rename(merged_dir / file.name.replace("model", "adapter_model", 1))
|
||||
|
||||
index = merged_dir / "adapter_model.safetensors.index.json"
|
||||
if index.exists():
|
||||
data = json.loads(index.read_text(encoding="utf-8"))
|
||||
if "weight_map" in data:
|
||||
data["weight_map"] = {
|
||||
k: v.replace("model", "adapter_model", 1)
|
||||
for k, v in data["weight_map"].items()
|
||||
}
|
||||
index.write_text(
|
||||
json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
trainer: Any,
|
||||
@@ -294,12 +320,17 @@ def save_trained_model(
|
||||
)
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
if trainer.accelerator.is_main_process:
|
||||
# move all files in merged_path to cfg.output_dir
|
||||
# FSDP checkpoints for PEFT only contain adapter weights;
|
||||
# rename model* → adapter_model* so it loads correctly.
|
||||
is_peft = cfg.adapter and not cfg.relora
|
||||
if is_peft:
|
||||
_rename_fsdp_merged_to_adapter(Path(merged_path))
|
||||
for merged_file in Path(merged_path).iterdir():
|
||||
if (Path(cfg.output_dir) / merged_file.name).exists():
|
||||
(Path(cfg.output_dir) / merged_file.name).unlink()
|
||||
shutil.move(str(merged_file), cfg.output_dir)
|
||||
shutil.rmtree(merged_path) # remove what should be an empty dir
|
||||
dest = Path(cfg.output_dir) / merged_file.name
|
||||
if dest.exists():
|
||||
dest.unlink()
|
||||
shutil.move(str(merged_file), dest)
|
||||
shutil.rmtree(merged_path)
|
||||
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
||||
# cleanup the FSDP prefix in the model config.json
|
||||
if trainer.accelerator.is_main_process:
|
||||
|
||||
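For context, a rough sketch of what `_rename_fsdp_merged_to_adapter` does on disk. The helper and the paths below are illustrative; only the function above is part of this change:

```python
# e.g. model-00001-of-00002.safetensors -> adapter_model-00001-of-00002.safetensors,
# and model.safetensors.index.json -> adapter_model.safetensors.index.json, whose
# weight_map values are then rewritten to the adapter_model* shard names.
from pathlib import Path


def preview_adapter_rename(merged_dir: Path) -> dict[str, str]:
    """Hypothetical dry-run helper: report the renames the function would perform."""
    return {
        f.name: f.name.replace("model", "adapter_model", 1)
        for f in sorted(merged_dir.iterdir())
        if f.name.startswith("model") and ".safetensors" in f.name
    }
```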
@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
        return control


class SkipEvalOnResumeCallback(TrainerCallback):
    """Skip the redundant evaluation that fires when resuming from a checkpoint
    whose step aligns with ``eval_steps``.

    When HuggingFace Trainer resumes, it restores ``global_step`` from the
    checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
    step. Because the evaluation was already performed during the original
    run, repeating it wastes time and pollutes metric logs.

    This callback records the ``global_step`` at the start of training (i.e.
    the checkpoint step when resuming, or 0 for a fresh run) and suppresses
    any evaluation request on that exact step.
    """

    def __init__(self):
        super().__init__()
        self._resume_step: int | None = None

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ):
        # ``global_step`` is already restored from the checkpoint at this
        # point. For a fresh run it will be 0, so the guard below becomes a
        # no-op.
        self._resume_step = state.global_step

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ) -> TrainerControl:
        if (
            self._resume_step
            and state.global_step <= self._resume_step
            and control.should_evaluate
        ):
            LOG.info(
                "Skipping evaluation at step %d (already completed before resume)",
                state.global_step,
            )
            control.should_evaluate = False
        return control


def bench_eval_callback_factory(trainer, tokenizer):
    accuracy = evaluate.load("accuracy")
    abcd_idx = [

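A minimal sketch of how the `SkipEvalOnResumeCallback` above could be attached to a HuggingFace `Trainer`. The model, datasets, and argument values are placeholders; only the `callbacks=[...]` wiring is the point:

```python
from transformers import Trainer, TrainingArguments


def build_trainer(model, train_ds, eval_ds):
    # Placeholder training arguments; any eval schedule works.
    args = TrainingArguments(output_dir="out", eval_strategy="steps", eval_steps=500)
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        callbacks=[SkipEvalOnResumeCallback()],
    )


# build_trainer(...).train(resume_from_checkpoint=True) would then skip the
# evaluation that would otherwise re-fire at the restored global_step.
```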
@@ -1,7 +1,19 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
        {%- if messages[0].content is string %}
            {{- messages[0].content + '\n\n' }}
        {%- else %}
            {%- for part in messages[0].content %}
                {%- if part is mapping %}
                    {%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
                    {%- if system_text %}{{- system_text }}{%- endif %}
                {%- elif part is string %}
                    {{- part }}
                {%- endif %}
            {%- endfor %}
            {{- '\n\n' }}
        {%- endif %}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
@@ -11,7 +23,20 @@
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
        {%- if messages[0].content is string %}
            {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
        {%- else %}
            {{- '<|im_start|>system\n' }}
            {%- for part in messages[0].content %}
                {%- if part is mapping %}
                    {%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
                    {%- if system_text %}{{- system_text }}{%- endif %}
                {%- elif part is string %}
                    {{- part }}
                {%- endif %}
            {%- endfor %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}

@@ -268,6 +268,37 @@ def normalize_config(cfg):
    ):
        cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}

    # Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
    # "marked ready twice" errors with reentrant checkpointing) and
    # ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
    # receive gradients on every step).
    if cfg.model_config_type == "gemma4":
        if cfg.gradient_checkpointing:
            if cfg.gradient_checkpointing_kwargs is None:
                cfg.gradient_checkpointing_kwargs = {}
            if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
                LOG.warning(
                    "Gemma4 requires use_reentrant=False for gradient checkpointing "
                    "in distributed training. Setting use_reentrant=False."
                )
                cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
        if cfg.ddp and cfg.ddp_find_unused_parameters is None:
            if cfg.activation_offloading is True:
                # activation_offloading uses checkpoint wrappers that conflict
                # with find_unused_parameters (causes "marked ready twice").
                # Use freeze_mm_modules instead to eliminate unused params.
                LOG.info(
                    "Gemma4 + DDP + activation_offloading: skipping "
                    "ddp_find_unused_parameters (use freeze_mm_modules to "
                    "handle unused vision/audio params)."
                )
            else:
                LOG.warning(
                    "Gemma4 requires ddp_find_unused_parameters=True for DDP. "
                    "Auto-enabling."
                )
                cfg.ddp_find_unused_parameters = True

    log_gpu_memory_usage(LOG, "baseline", cfg.device)


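Roughly, the effect of the Gemma4 branch above on a DDP run, shown on an illustrative (hypothetical, minimal) config rather than a real training config:

```python
from axolotl.utils.dict import DictDefault

# Illustrative inputs only; a real cfg carries many more fields.
cfg = DictDefault(
    model_config_type="gemma4",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs=None,
    ddp=True,
    ddp_find_unused_parameters=None,
    activation_offloading=False,
)
# After the Gemma4 branch of normalize_config(cfg) runs, one would expect:
#   cfg.gradient_checkpointing_kwargs == {"use_reentrant": False}
#   cfg.ddp_find_unused_parameters is True
```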
@@ -180,6 +180,119 @@ def _drop_long_sequences(
        raise ValueError("Unknown RL type")


def _raise_on_long_sequences(
    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
    """Check sequence length and raise ValueError if exceeded.

    Used as a filter function for ``excess_length_strategy: raise``.

    Args:
        sample: Dataset sample to check.
        rl: Reinforcement learning type.
        tokenizer: Tokenizer for length calculation.
        sequence_len: Maximum allowed sequence length.

    Returns:
        Always True (raises before returning False).

    Raises:
        ValueError: If any sample exceeds the configured sequence length.
    """
    is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
    if not is_valid:
        raise ValueError(
            f"Sample exceeds configured sequence_len ({sequence_len}). "
            "Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
            "to handle long sequences automatically."
        )
    return True


def _truncate_long_sequences_rl(
    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> dict[str, Any]:
    """Truncate RL samples that exceed maximum sequence length.

    For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
    responses to fit within ``sequence_len`` when combined with the prompt.
    For KTO, truncates the completion similarly.
    GRPO/GDPO/EBFT samples are returned unchanged.

    Samples where the prompt alone exceeds ``sequence_len`` cannot be
    meaningfully truncated and are returned unchanged. The caller should
    follow up with a drop filter to remove them.

    Args:
        sample: Dataset sample to potentially truncate.
        rl: Reinforcement learning type.
        tokenizer: Tokenizer for encoding/decoding.
        sequence_len: Maximum allowed sequence length.

    Returns:
        The sample with text fields truncated to fit within sequence_len.
    """
    # Fast path: if sample already fits, return unchanged (avoids decode overhead)
    if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
        return sample

    if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
        if not (
            sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
        ):
            raise ValueError(
                "Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
            )

        prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
        chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
        rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
            "input_ids"
        ]

        max_response_len = sequence_len - len(prompt_ids)
        if max_response_len <= 0:
            # Prompt alone exceeds limit; cannot meaningfully truncate.
            # Returned unchanged — the follow-up drop filter will remove it.
            return sample

        updates: dict[str, Any] = {}
        if len(chosen_ids) > max_response_len:
            updates["chosen"] = tokenizer.decode(
                chosen_ids[:max_response_len], skip_special_tokens=False
            )
        if len(rejected_ids) > max_response_len:
            updates["rejected"] = tokenizer.decode(
                rejected_ids[:max_response_len], skip_special_tokens=False
            )
        if updates:
            sample = {**sample, **updates}

    elif rl is RLType.KTO:
        if not (sample.get("prompt") and sample.get("completion")):
            raise ValueError("Prompt and completion keys are required for KTO datasets")

        prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
        completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
            "input_ids"
        ]

        max_completion_len = sequence_len - len(prompt_ids)
        if max_completion_len <= 0:
            return sample

        if len(completion_ids) > max_completion_len:
            sample = {
                **sample,
                "completion": tokenizer.decode(
                    completion_ids[:max_completion_len], skip_special_tokens=False
                ),
            }

    # GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
    return sample


def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
    """Load and process dataset split for RL training.

@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
        split_datasets[i] = dataset

    if not cfg.skip_prepare_dataset:
        drop_long = partial(
            _drop_long_sequences,
            rl=cfg.rl,
            tokenizer=tokenizer,
            sequence_len=cfg.sequence_len,
        )
        excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()

        prior_len = len(split_datasets[i])
        split_datasets[i] = split_datasets[i].filter(
            drop_long,
            num_proc=cfg.dataset_num_proc,
            load_from_cache_file=not cfg.is_preprocess,
            desc="Dropping Long Sequences",
        )
        dropped = prior_len - len(split_datasets[i])
        if dropped:
            LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
        if excess_length_strategy == "truncate":
            truncate_fn = partial(
                _truncate_long_sequences_rl,
                rl=cfg.rl,
                tokenizer=tokenizer,
                sequence_len=cfg.sequence_len,
            )
            prior_len = len(split_datasets[i])
            split_datasets[i] = split_datasets[i].map(
                truncate_fn,
                num_proc=cfg.dataset_num_proc,
                load_from_cache_file=not cfg.is_preprocess,
                desc="Truncating Long Sequences",
            )

            # Drop samples that could not be truncated (e.g. prompt
            # alone exceeds sequence_len)
            drop_long = partial(
                _drop_long_sequences,
                rl=cfg.rl,
                tokenizer=tokenizer,
                sequence_len=cfg.sequence_len,
            )
            split_datasets[i] = split_datasets[i].filter(
                drop_long,
                num_proc=cfg.dataset_num_proc,
                load_from_cache_file=not cfg.is_preprocess,
                desc="Dropping Un-truncatable Sequences",
            )
            dropped = prior_len - len(split_datasets[i])
            if dropped:
                LOG.warning(
                    f"Dropped {dropped} samples from dataset index {i} "
                    f"that could not be truncated to fit sequence_len "
                    f"(prompt alone exceeds limit)"
                )
        elif excess_length_strategy == "raise":
            raise_fn = partial(
                _raise_on_long_sequences,
                rl=cfg.rl,
                tokenizer=tokenizer,
                sequence_len=cfg.sequence_len,
            )
            split_datasets[i] = split_datasets[i].filter(
                raise_fn,
                num_proc=cfg.dataset_num_proc,
                load_from_cache_file=not cfg.is_preprocess,
                desc="Checking Sequence Lengths",
            )
        else:  # "drop" (default)
            drop_long = partial(
                _drop_long_sequences,
                rl=cfg.rl,
                tokenizer=tokenizer,
                sequence_len=cfg.sequence_len,
            )

            prior_len = len(split_datasets[i])
            split_datasets[i] = split_datasets[i].filter(
                drop_long,
                num_proc=cfg.dataset_num_proc,
                load_from_cache_file=not cfg.is_preprocess,
                desc="Dropping Long Sequences",
            )
            dropped = prior_len - len(split_datasets[i])
            if dropped:
                LOG.warning(
                    f"Dropped {dropped} long samples from dataset index {i}"
                )

    # Merge datasets
    dataset = merge_datasets(split_datasets, cfg)

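A small sketch of what the `truncate` strategy does to a single over-long DPO sample. The tokenizer choice and the sample text are illustrative assumptions; only `_truncate_long_sequences_rl` above is defined by this change:

```python
from transformers import AutoTokenizer

from axolotl.utils.schemas.enums import RLType

# Any tokenizer works; GPT-2 is just small and public.
tok = AutoTokenizer.from_pretrained("gpt2")

sample = {
    "prompt": "Question: " + "lorem ipsum " * 10,
    "chosen": "A very long answer " * 200,
    "rejected": "Another very long answer " * 200,
}

trimmed = _truncate_long_sequences_rl(sample, rl=RLType.DPO, tokenizer=tok, sequence_len=256)
# chosen/rejected are re-decoded so prompt + response fits within 256 tokens;
# a sample whose prompt alone exceeds sequence_len comes back unchanged and is
# later removed by the follow-up "Dropping Un-truncatable Sequences" filter.
```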
@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

# Top-level module name prefixes that belong to vision/audio/multimodal encoders
# rather than the language backbone. These are matched against the first component
# of each ``named_parameter`` path (e.g. "model.vision_tower." -> "vision_tower").
_MM_MODULE_PREFIXES = (
    "vision_tower",
    "vision_model",
    "vision_encoder",
    "embed_vision",
    "multi_modal_projector",
    "visual",
    "audio_tower",
    "audio_model",
    "embed_audio",
)


def freeze_mm_modules(model):
    """Freeze all vision/audio/multimodal-projector parameters.

    Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
    for any parameter whose name contains a known vision/audio module prefix.
    This is useful when fine-tuning only the language backbone of a multimodal
    model and avoids the need for ``ddp_find_unused_parameters=True``.
    """
    frozen_count = 0
    for name, param in model.named_parameters():
        # Check if any path component matches a vision/audio prefix
        parts = name.split(".")
        if any(part in _MM_MODULE_PREFIXES for part in parts):
            if param.requires_grad:
                param.requires_grad = False
                frozen_count += 1
                if is_main_process():
                    LOG.debug(f"freeze_mm_modules: froze {name}")

    if is_main_process():
        LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")


def freeze_layers_except(model, regex_patterns):
    """

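A quick way to sanity-check the `freeze_mm_modules` helper above on a multimodal checkpoint. The model name is only an example; any model exposing a `vision_tower` or `multi_modal_projector` would behave similarly:

```python
from transformers import AutoModelForImageTextToText

from axolotl.utils.freeze import freeze_mm_modules

# Example checkpoint only; substitute any multimodal model you have locally.
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")

freeze_mm_modules(model)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params after freezing vision/audio modules: {trainable}/{total}")
```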
@@ -578,6 +578,17 @@ class AxolotlInputConfig(
        },
    )

    freeze_mm_modules: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
            "text-only training of multimodal models. When True, parameters belonging to "
            "vision towers, audio towers, multimodal projectors, and similar non-language "
            "modules are frozen (requires_grad=False). This allows DDP training without "
            "ddp_find_unused_parameters=True."
        },
    )

    unfrozen_parameters: list[str] | None = Field(
        default=None,
        json_schema_extra={
@@ -766,6 +777,15 @@ class AxolotlInputConfig(
        },
    )

    gemma4_hybrid_attn_impl: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
            "and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
            "exceeds flash attention's supported size."
        },
    )

    experts_implementation: str | None = Field(
        default=None,
        json_schema_extra={
@@ -803,13 +823,6 @@ class AxolotlInputConfig(
        },
    )

    unsloth_cross_entropy_loss: bool | None = None
    unsloth_lora_mlp: bool | None = None
    unsloth_lora_qkv: bool | None = None
    unsloth_lora_o: bool | None = None
    unsloth_rms_norm: bool | None = None
    unsloth_rope: bool | None = None

    lora_mlp_kernel: bool | None = Field(
        default=None,
        json_schema_extra={
@@ -1449,21 +1462,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_multigpu_unsloth(cls, data):
        if (
            data.get("unsloth_lora_mlp")
            or data.get("unsloth_lora_qkv")
            or data.get("unsloth_lora_o")
        ):
            capabilities = data.get("capabilities")
            if capabilities and capabilities.get("n_gpu", 0) > 1:
                raise ValueError(
                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
                )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_multigpu_lora_kernels(cls, data):
@@ -1517,8 +1515,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
            # RL trainers not tested so don't enable kernels by default
            return data
        if data.get("adapter") in ["lora", "qlora"]:
            # Skip if already set, using unsloth optimizations, or using 8-bit
            unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
            # Skip if already set or using 8-bit
            kernel_fields = [
                "lora_mlp_kernel",
                "lora_qkv_kernel",
@@ -1527,7 +1524,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
            ]
            if (
                any(data.get(k) is not None for k in kernel_fields)
                or any(data.get(k) for k in unsloth_fields)
                or data.get("adapter") == "lora"
                and data.get("load_in_8bit")
            ):
