Compare commits


37 Commits

Author SHA1 Message Date
Dan Saunders dd85358543 default mg 2025-09-25 16:30:23 -04:00
Dan Saunders 55d98db0d0 fix 2025-09-25 16:08:35 -04:00
Dan Saunders d90ade3b1b fix 2025-09-25 15:55:08 -04:00
Dan Saunders 824a641cee uniform routing default 2025-09-25 15:47:23 -04:00
Dan Saunders e003a05177 narrow sweep; compare both backends 2025-09-25 14:54:03 -04:00
Dan Saunders 91393c4dc8 allocator 2025-09-25 14:27:34 -04:00
Dan Saunders d578c53603 fix 2025-09-25 14:27:34 -04:00
Dan Saunders 4db7a21ff7 fix 2025-09-25 14:27:34 -04:00
Dan Saunders 3b2e05c563 update to new api 2025-09-25 14:27:34 -04:00
Dan Saunders 1037ca3a97 update to new api 2025-09-25 14:27:34 -04:00
Dan Saunders 6369dcd7b8 fix 2025-09-25 14:27:34 -04:00
Dan Saunders a81612305c fix? 2025-09-25 14:27:34 -04:00
Dan Saunders d0da67eb17 add mg kernel backend 2025-09-25 14:27:34 -04:00
Dan Saunders 8a1f5ae940 fix 2025-09-25 14:27:34 -04:00
Dan Saunders 146ca48cba vram 2025-09-25 14:27:34 -04:00
Dan Saunders fd312f6058 dtype 2025-09-25 14:27:34 -04:00
Dan Saunders ab8fa56b16 dtype 2025-09-25 14:27:34 -04:00
Dan Saunders 1640cd4006 delete config 2025-09-25 14:27:34 -04:00
Dan Saunders 3277d44d71 cfg value 2025-09-25 14:27:34 -04:00
Dan Saunders d3e1b0ef1a small deepseek script 2025-09-25 14:27:34 -04:00
Dan Saunders 5b97633faa Fix 2025-09-25 14:27:34 -04:00
Dan Saunders 94cbc6d42d log device, dtype 2025-09-25 14:27:34 -04:00
Dan Saunders 493616fc3d reprod tt table 2025-09-25 14:27:34 -04:00
Dan Saunders d2b25c7327 grid sweep 2025-09-25 14:27:34 -04:00
Dan Saunders b670c45276 fix 2025-09-25 14:27:34 -04:00
Dan Saunders 61faf4cbe4 fix 2025-09-25 14:27:34 -04:00
Dan Saunders 8d8fa834a2 sweep 2025-09-25 14:27:34 -04:00
Dan Saunders 9d69c6fb3e Fix 2025-09-25 14:27:34 -04:00
Dan Saunders 92f2f6e73c dtype fix 2025-09-25 14:27:34 -04:00
Dan Saunders e5d2aebe16 uniform routing: 2025-09-25 14:27:34 -04:00
Dan Saunders 4ab9e3f58b add logs 2025-09-25 14:27:34 -04:00
Dan Saunders 5788832812 simplify 2025-09-25 14:27:34 -04:00
Dan Saunders db782430f8 fix 2025-09-25 14:27:34 -04:00
Dan Saunders 5c74edeefe token shuffle kernel 2025-09-25 14:27:34 -04:00
Dan Saunders 18269ee6a9 fix 2025-09-25 14:27:34 -04:00
Dan Saunders 6a45d804f9 glue 2025-09-25 14:27:34 -04:00
Dan Saunders 95e607574a vendor torchtitan moe kernels 2025-09-25 14:27:34 -04:00
97 changed files with 4222 additions and 9140 deletions

View File

@@ -2,6 +2,7 @@
source = axolotl
omit =
*/tests/*
setup.py
[report]
exclude_lines =

View File

@@ -29,18 +29,13 @@ PRs are **greatly welcome**!
2. Set up the development environment by following the instructions in the [README.md](https://github.com/axolotl-ai-cloud/axolotl/tree/main/README.md) file.
3. Explore the codebase, run tests, and verify that everything works as expected.
Please run the below to setup:
Please run below to setup env
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
uv sync --dev && uv pip install flash-attn --no-build-isolation
source .venv/bin/activate
pre-commit install # install pre-commit hooks
pytest tests/ # optional; run test suite
# test
pytest tests/
```
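Once the hooks are installed, it can be worth running them across the whole tree before opening a PR; a small sketch, assuming the setup above completed:

```bash
# run every configured pre-commit hook once, outside of a commit
pre-commit run --all-files
```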
## How to Contribute

View File

@@ -39,6 +39,13 @@ jobs:
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
@@ -98,9 +105,7 @@ jobs:
context: .
file: ./docker/${{ matrix.dockerfile }}
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-base-uv-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}

View File

@@ -20,14 +20,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install dependencies
run: |
uv pip install --system jupyter quartodoc
uv pip install --system -e .
python3 -m pip install jupyter quartodoc
python3 -m pip install -e .
- name: Build autodoc
run: quartodoc build
- name: Publish to GitHub Pages (and render)

View File

@@ -6,7 +6,7 @@ on:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"
@@ -23,4 +23,5 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1

View File

@@ -20,6 +20,11 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -68,8 +73,6 @@ jobs:
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
GIT_REF=${{ github.ref }}
GIT_SHA=${{ github.sha }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
@@ -90,6 +93,11 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -140,8 +148,6 @@ jobs:
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
GIT_REF=${{ github.ref }}
GIT_SHA=${{ github.sha }}
file: ./docker/Dockerfile-cloud
push: ${{ github.event_name != 'pull_request' }}
tags: |
@@ -207,8 +213,6 @@ jobs:
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
GIT_REF=${{ github.ref }}
GIT_SHA=${{ github.sha }}
file: ./docker/Dockerfile-cloud-no-tmux
push: ${{ github.event_name != 'pull_request' }}
tags: |

View File

@@ -4,6 +4,8 @@ on:
pull_request:
paths:
- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
@@ -54,17 +56,13 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 protobuf
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
@@ -74,4 +72,4 @@ jobs:
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run -m cicd.multigpu
modal run cicd.multigpu

View File

@@ -52,8 +52,6 @@ jobs:
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
GIT_REF=${{ github.ref }}
GIT_SHA=${{ github.sha }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
@@ -104,8 +102,6 @@ jobs:
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
GIT_REF=${{ github.ref }}
GIT_SHA=${{ github.sha }}
file: ./docker/Dockerfile-cloud
push: ${{ github.event_name != 'pull_request' }}
tags: |

View File

@@ -18,15 +18,10 @@ jobs:
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Update pre-commit hooks
id: update
run: |
uv pip install --system pre-commit
pip install pre-commit
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT

View File

@@ -40,15 +40,10 @@ jobs:
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install dependencies
run: |
uv pip install --system jupyter quartodoc
uv pip install --system -e .
python3 -m pip install jupyter quartodoc
python3 -m pip install -e .
- name: Build autodoc
run: quartodoc build

View File

@@ -38,24 +38,23 @@ jobs:
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install dependencies
run: |
uv pip install --system wheel packaging==23.2
uv pip install --system --no-build-isolation -e ".[dev]"
pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name
id: tag
run: echo "TAG_NAME=$(echo "$GITHUB_REF" | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Build package
- name: Update version in setup.py
run: |
uv pip install --system build
python -m build
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a source dist
run: |
python setup.py sdist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -13,6 +13,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -42,30 +43,32 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
uv pip install --system torch==${{ matrix.pytorch_version }} torchvision
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Update pyproject.toml for nightly builds
- name: Update requirements.txt
run: |
sed -i 's#"transformers==.*"#"transformers @ git+https://github.com/huggingface/transformers.git@main"#' pyproject.toml
sed -i 's#"peft==.*"#"peft @ git+https://github.com/huggingface/peft.git@main"#' pyproject.toml
sed -i 's#"accelerate==.*"#"accelerate @ git+https://github.com/huggingface/accelerate.git@main"#' pyproject.toml
sed -i 's#"trl==.*"#"trl @ git+https://github.com/huggingface/trl.git@main"#' pyproject.toml
sed -i 's#"datasets==.*"#"datasets @ git+https://github.com/huggingface/datasets.git@main"#' pyproject.toml
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
- name: Install dependencies
run: |
uv pip show --system torch
uv pip install --system --no-build-isolation -e ".[dev]"
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
@@ -81,6 +84,9 @@ jobs:
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
@@ -114,16 +120,13 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install Modal
run: |
uv pip install --system modal==1.0.2 jinja2
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-uv-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
@@ -133,7 +136,7 @@ jobs:
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run -m cicd.e2e_tests
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
@@ -159,16 +162,13 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install Modal
run: |
uv pip install --system modal==1.0.2 jinja2
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-uv-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV

View File

@@ -7,16 +7,18 @@ on:
- "main"
paths:
- '**.py'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
workflow_dispatch:
@@ -39,6 +41,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -69,25 +72,24 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
uv pip install --system torch==${{ matrix.pytorch_version }} torchvision
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
uv pip show --system torch
uv pip install --system wheel
printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt
uv pip install --system --no-cache-dir --no-build-isolation -e ".[dev]" --constraints torch-constraints.txt
set -o pipefail
python scripts/unsloth_install.py | bash
python scripts/cutcrossentropy_install.py | bash
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
@@ -103,10 +105,10 @@ jobs:
- name: Run tests
run: |
python -m pytest -v --durations=10 -n 8 --dist loadfile --cov=axolotl --cov-report=xml --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/
python -m pytest -v --durations=10 -n 8 --cov=axolotl --cov-append --cov-report=xml tests/monkeypatch/
python -m pytest -v --durations=10 -n 8 --cov=axolotl --cov-append --cov-report=xml tests/patched/
python -m pytest -v --durations=10 -n 8 --cov=axolotl --cov-append --cov-report=xml tests/cli/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -116,6 +118,9 @@ jobs:
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
pytest-sdist:
name: PyTest from Source Dist
@@ -142,26 +147,25 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch
run: |
uv pip install --system torch==${{ matrix.pytorch_version }} torchvision
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
uv pip show --system torch
uv pip install --system wheel build setuptools_scm
python -m build --sdist
printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt
tarball_path=$(echo dist/axolotl*.tar.gz)
uv pip install --no-cache-dir --no-build-isolation --system "${tarball_path}[dev]" --constraints torch-constraints.txt
pip3 show torch
python -m build --no-isolation --sdist
pip3 install --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
@@ -176,9 +180,13 @@ jobs:
- name: Run tests
run: |
python -m pytest -v --durations=10 -n 8 --dist loadfile --cov=axolotl --cov-report=xml --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/
python -m pytest -v --durations=10 -n 8 --cov=axolotl --cov-append --cov-report=xml tests/monkeypatch/
python -m pytest -v --durations=10 -n 8 tests/cli/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
@@ -235,7 +243,7 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile.jinja"
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -243,17 +251,13 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 protobuf
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=${{ github.ref_name }}-base-uv-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
@@ -308,17 +312,13 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 protobuf
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=${{ github.ref_name }}-base-uv-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
@@ -355,17 +355,13 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 protobuf
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=${{ github.ref_name }}-base-uv-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV

.gitignore (vendored, 2 lines changed)
View File

@@ -191,5 +191,5 @@ out/
# vim
*.swp
# setuptools-scm generated version file
# scm auto-versioning
src/axolotl/_version.py

View File

@@ -1,8 +1,9 @@
FROM axolotlai/axolotl-cloud:main-py3.11-cu124-2.6.0
COPY .runpod/requirements.txt /requirements.txt
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
/root/.local/bin/uv pip install --system -r /requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade pip && \
python3 -m pip install --upgrade -r /requirements.txt
# Environment settings
ARG BASE_VOLUME="/runpod-volume"

View File

@@ -1,5 +1,6 @@
include pyproject.toml
include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include src/axolotl/utils/chat_templates/templates/*.jinja
recursive-include src/axolotl *.py
recursive-include axolotl *.py

View File

@@ -65,9 +65,15 @@ Features:
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
## 🚀 Quick Start - LLM Fine-tuning in Minutes
**Requirements**: NVIDIA GPU (Ampere+) or AMD GPU, Python 3.11+
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.6.0
### Google Colab
@@ -75,35 +81,15 @@ Features:
### Installation
#### Project setup (uv add)
#### Using pip
```bash
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# Initialize or enter your project
uv init my-project && cd my-project
uv add axolotl
uv pip install flash-attn --no-build-isolation
source .venv/bin/activate
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # optional
```
#### Quick try (uv pip)
```bash
# Install uv if needed
curl -LsSf https://astral.sh/uv/install.sh | sh
uv pip install axolotl
uv pip install flash-attn --no-build-isolation
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # optional
axolotl fetch deepspeed_configs # OPTIONAL
```
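As a quick smoke test after either install path, one of the fetched configs can be launched directly; the config path below is illustrative only, assuming `axolotl fetch examples` placed it in the current directory:

```bash
# launch a small example fine-tune (the path is an assumed example config;
# substitute any config fetched above)
axolotl train examples/llama-3/lora-1b.yml
```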
#### Using Docker

cicd/Dockerfile-uv.jinja (new file, 52 lines)
View File

@@ -0,0 +1,52 @@
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py --uv | sh
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -1,10 +1,6 @@
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
SHELL ["/bin/bash", "-euxo", "pipefail", "-c"]
ARG VENV_PYTHON="/workspace/axolotl-venv/bin/python"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
@@ -13,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV VENV_PYTHON=$VENV_PYTHON
ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
@@ -29,27 +25,25 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#"transformers[^"]*"#"transformers @ git+https://github.com/huggingface/transformers.git@main"#' pyproject.toml; \
sed -i 's#"peft[^"]*"#"peft @ git+https://github.com/huggingface/peft.git@main"#' pyproject.toml; \
sed -i 's#"accelerate[^"]*"#"accelerate @ git+https://github.com/huggingface/accelerate.git@main"#' pyproject.toml; \
sed -i 's#"trl[^"]*"#"trl @ git+https://github.com/huggingface/trl.git@main"#' pyproject.toml; \
sed -i 's#"datasets[^"]*"#"datasets @ git+https://github.com/huggingface/datasets.git@main"#' pyproject.toml; \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install --python "$VENV_PYTHON" packaging==23.2 setuptools==75.8.0 pip
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --python "$VENV_PYTHON" --no-build-isolation -e .[ring-flash-attn,optimizers,ray,${AXOLOTL_EXTRAS}] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --python "$VENV_PYTHON" --no-build-isolation -e .[ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN uv pip install --python "$VENV_PYTHON" --no-build-isolation flash-attn $AXOLOTL_ARGS
RUN "$VENV_PYTHON" scripts/unsloth_install.py | sh
RUN "$VENV_PYTHON" scripts/cutcrossentropy_install.py | sh
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN uv pip install --python "$VENV_PYTHON" -e ".[dev]"
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -4,7 +4,7 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# Run unit tests with initial coverage report
uv run pytest -v --durations=10 -n8 \
pytest -v --durations=10 -n8 \
--ignore=tests/e2e/ \
--ignore=tests/patched/ \
--ignore=tests/cli \
@@ -12,36 +12,36 @@ uv run pytest -v --durations=10 -n8 \
--cov=axolotl
# Run lora kernels tests with coverage append
uv run pytest -v --durations=10 \
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/patched/lora_kernels \
--cov=axolotl \
--cov-append
# Run patched tests excluding lora kernels with coverage append
uv run pytest --full-trace -vvv --durations=10 \
pytest --full-trace -vvv --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \
--cov-append
# Run solo tests with coverage append
uv run pytest -v --durations=10 -n1 \
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/solo/ \
--cov=axolotl \
--cov-append
# Run integration tests with coverage append
uv run pytest -v --durations=10 \
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/integrations/ \
--cov=axolotl \
--cov-append
uv run pytest -v --durations=10 /workspace/axolotl/tests/cli \
pytest -v --durations=10 /workspace/axolotl/tests/cli \
--cov=axolotl \
--cov-append
# Run remaining e2e tests with coverage append and final report
uv run pytest -v --durations=10 \
pytest -v --durations=10 \
--ignore=tests/e2e/solo/ \
--ignore=tests/e2e/patched/ \
--ignore=tests/e2e/multigpu/ \
@@ -52,4 +52,4 @@ uv run pytest -v --durations=10 \
--cov-append \
--cov-report=xml:e2e-coverage.xml
uv run codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true

View File

@@ -23,7 +23,7 @@ df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-uv-py3.11-cu126-2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),

View File

@@ -23,7 +23,7 @@ df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-uv-py3.11-cu126-2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),

View File

@@ -1,19 +1,13 @@
ARG BASE_TAG=main-base-uv
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG GIT_REF="refs/heads/main"
ARG GIT_SHA="HEAD"
ARG VENV_PYTHON="/workspace/axolotl-venv/bin/python"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
ENV GIT_REF=$GIT_REF
ENV GIT_SHA=$GIT_SHA
ENV VENV_PYTHON=$VENV_PYTHON
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
@@ -26,19 +20,16 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# Ensure we are on the expected commit and break Docker cache between revisions
RUN git fetch origin "$GIT_REF" && git checkout "$GIT_SHA"
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --python "$VENV_PYTHON" --no-build-isolation -e .[ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --python "$VENV_PYTHON" --no-build-isolation -e .[ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi && \
uv pip install --python "$VENV_PYTHON" --no-build-isolation flash-attn $AXOLOTL_ARGS && \
"$VENV_PYTHON" scripts/unsloth_install.py | sh && \
"$VENV_PYTHON" scripts/cutcrossentropy_install.py | sh && \
uv pip install --python "$VENV_PYTHON" pytest
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -48,5 +48,5 @@ RUN git lfs install --skip-repo && \
pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" uv pip install --no-build-isolation flash-attn==2.8.0.post2; \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
fi

View File

@@ -12,8 +12,8 @@ EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install --python "$VENV_PYTHON" jupyterlab notebook ipywidgets && \
"$VENV_PYTHON" -m jupyter lab clean
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \

View File

@@ -12,8 +12,8 @@ EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install --python "$VENV_PYTHON" jupyterlab notebook ipywidgets && \
"$VENV_PYTHON" -m jupyter lab clean
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
rm -rf /var/cache/apt/archives && \

View File

@@ -24,14 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
fi && \
uv pip install --no-build-isolation flash-attn $AXOLOTL_ARGS
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
RUN uv pip install pytest
RUN pip install pytest
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -13,7 +13,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
ENV UV_TORCH_BACKEND="cu${CUDA}"
ENV VENV_PYTHON=/workspace/axolotl-venv/bin/python
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
@@ -30,8 +29,8 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install --python "$VENV_PYTHON" packaging setuptools wheel psutil protobuf grpclib \
&& uv pip install --python "$VENV_PYTHON" torch==${PYTORCH_VERSION} \
&& uv pip install --python "$VENV_PYTHON" --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install --python "$VENV_PYTHON" "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install --python "$VENV_PYTHON" awscli pydantic
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic

View File

@@ -72,8 +72,8 @@ datasets:
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
```bash
uv sync --extra deepspeed
uv pip install flash-attn --no-build-isolation
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
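To confirm the install really is editable (so code changes are picked up at runtime), a minimal check is to see where the package resolves from:

```bash
# should print a path inside this repository, not site-packages,
# if the editable install above worked
python -c "import axolotl; print(axolotl.__file__)"
```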
#### Remote Hosts
@@ -213,8 +213,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
You will now be in the container. Next, perform an editable install of Axolotl:
```bash
uv sync --extra deepspeed
uv pip install flash-attn --no-build-isolation
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Attach To Container

View File

@@ -1,5 +1,5 @@
---
title: "FSDP + QLoRA"
title: "FDSP + QLoRA"
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
format:
html:
@@ -23,12 +23,6 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Enabling Swap for FSDP2
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
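A minimal sketch of launching that referenced example, assuming the repository is cloned and axolotl is installed per the installation docs:

```bash
# QLoRA + FSDP run using the example config referenced above
axolotl train examples/llama-2/qlora-fsdp.yml
```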

View File

@@ -29,40 +29,19 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
:::
### uv Installation (Recommended) {#sec-uv-quick}
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh
# Add Axolotl to a project (recommended)
uv init my-project && cd my-project
uv add axolotl
uv pip install flash-attn --no-build-isolation
source .venv/bin/activate
```
For a quick one-off install without creating a project:
```{.bash}
uv pip install axolotl
uv pip install flash-attn --no-build-isolation
```
### pip Installation {#sec-pypi}
```{.bash}
pip install --no-build-isolation axolotl[deepspeed]
pip install --no-build-isolation flash-attn
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies. Flash Attention is resolved separately so it can be built against
the environment configured by the previous step.
co-dependencies.
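A simple way to confirm the pre-installed PyTorch survived the axolotl install (the point of `--no-build-isolation` above) is to print its version afterwards:

```{.bash}
# the version reported here should match the torch installed beforehand
python -c "import torch; print(torch.__version__)"
```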
### Advanced uv Installation {#sec-uv}
### uv Installation {#sec-uv}
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
@@ -83,38 +62,28 @@ source .venv/bin/activate
Install PyTorch
- PyTorch 2.6.0 recommended
```{.bash}
uv pip install packaging setuptools wheel
uv pip install torch==2.6.0
uv pip install awscli pydantic
```
Install axolotl from PyPi
```{.bash}
uv pip install --no-build-isolation axolotl[deepspeed]
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
# uv pip install --no-build-isolation axolotl[deepspeed,vllm]
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
uv pip install flash-attn --no-build-isolation
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
```
### Edge/Development Build {#sec-edge-build}
For the latest features between releases:
#### Using uv (recommended)
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
curl -LsSf https://astral.sh/uv/install.sh | sh # If not already installed
uv sync
uv pip install flash-attn --no-build-isolation
```
#### Using pip
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip install --no-build-isolation -e '.[deepspeed]'
pip install --no-build-isolation flash-attn
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Docker {#sec-docker}
@@ -172,7 +141,7 @@ For providers supporting Docker:
### macOS {#sec-macos}
```{.bash}
uv pip install --no-build-isolation -e '.'
pip3 install --no-build-isolation -e '.'
```
See @sec-troubleshooting for Mac-specific issues.
@@ -190,15 +159,10 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
1. Install Python ≥3.11
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
# Option A: add Axolotl to the environment
uv add axolotl
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install axolotl
uv pip install flash-attn --no-build-isolation
```
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
huggingface-cli login

View File

@@ -5,11 +5,10 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order
to improve speed and reduce memory usage during the forward and backward passes of
these calculations.
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
@@ -132,5 +131,6 @@ computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

View File

@@ -95,7 +95,7 @@ chat_template: llava
### Mistral-Small-3.1 {#sec-mistral-small-31}
::: {.callout-tip}
Please make sure to install vision lib via `uv pip install 'mistral-common[opencv]==1.8.5'`
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
@@ -105,7 +105,7 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}
Please make sure to install vision lib via `uv pip install 'mistral-common[opencv]==1.8.5'`
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
@@ -115,7 +115,7 @@ base_model: mistralai/Magistral-Small-2509
### Voxtral {#sec-voxtral}
::: {.callout-tip}
Please make sure to install audio lib via `uv pip install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::
```yaml
@@ -143,7 +143,7 @@ The model's initial loss and grad norm will be very high. We suspect this to be
:::
::: {.callout-tip}
Please make sure to install `timm` via `uv pip install timm==1.0.17`
Please make sure to install `timm` via `pip3 install timm==1.0.17`
:::
```yaml
@@ -171,7 +171,7 @@ chat_template: qwen2_vl # same as qwen2-vl
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
Please make sure to install `num2words` via `uv pip install num2words==0.5.14`
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::
```yaml
@@ -181,7 +181,7 @@ base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
### LFM2-VL {#sec-lfm2-vl}
::: {.callout-warning}
Please uninstall `causal-conv1d` via `uv pip uninstall -y causal-conv1d`
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
:::
```yaml
@@ -222,7 +222,7 @@ For audio loading, you can use the following keys within `content` alongside `"t
::: {.callout-tip}
You may need to install `librosa` via `uv pip install librosa==0.11.0`.
You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::

View File

@@ -49,9 +49,9 @@ When sequence parallelism is enabled:
To use sequence parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with either `uv sync --extra ring-flash-attn`
(from a cloned repository) or `uv pip install ring-flash-attn>=0.1.4`.
- Flash Attention installed separately with `uv pip install flash-attn --no-build-isolation`.
- The `ring-flash-attn` package. Install with:
- `pip install axolotl[ring-flash-attn]` (preferred)
- `pip install ring-flash-attn>=0.1.4`
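Putting those requirements together, a rough environment-setup sketch (the GPU-count line is only a sanity check; the install line is the preferred form listed above):

```bash
# sequence parallelism needs at least two GPUs
python -c "import torch; assert torch.cuda.device_count() >= 2"
# ring-flash-attn, installed via the axolotl extra listed above
pip install "axolotl[ring-flash-attn]"
```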
## Limitations

View File

@@ -12,14 +12,9 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of PyTorch installed
# Option A: manage dependencies in your project
uv add 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
@@ -40,7 +35,7 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
uv pip uninstall -y causal-conv1d
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -15,8 +15,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv sync
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
@@ -31,7 +31,7 @@ python scripts/cutcrossentropy_install.py | sh
# For those using our Docker image, use the below path.
export CUDA_HOME=/usr/local/cuda
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
@@ -67,7 +67,7 @@ If those didn't help, please try the below solutions:
1. Pass env for CMAKE and try install again:
```bash
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
2. Git clone the repo and manually hardcode python path:
@@ -92,7 +92,7 @@ If those didn't help, please try the below solutions:
```
```bash
uv pip install . --no-build-isolation --no-deps
pip3 install . --no-build-isolation --no-deps
```
## Optimization Guides

View File

@@ -17,8 +17,8 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv sync
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -12,10 +12,10 @@
"\n",
"Axolotl is the most performant LLM post-training framework available, delivering faster training with efficient, consistent and stable performance. Train your workload and ship your product 30% faster; saving you both time and money.\n",
"\n",
"- \u2b50 us on [GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n",
"- \ud83d\udcdc Read the [Docs](http://docs.axolotl.ai/)\n",
"- \ud83d\udcac Chat with us on [Discord](https://discord.gg/mnpEYgRUmD)\n",
"- \ud83d\udcf0 Get updates on [X/Twitter](https://x.com/axolotl_ai)\n"
"- us on [GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n",
"- 📜 Read the [Docs](http://docs.axolotl.ai/)\n",
"- 💬 Chat with us on [Discord](https://discord.gg/mnpEYgRUmD)\n",
"- 📰 Get updates on [X/Twitter](https://x.com/axolotl_ai)\n"
]
},
{
@@ -39,8 +39,8 @@
"source": [
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!uv pip install --no-build-isolation axolotl>=0.9.1\n!uv pip install flash-attn --no-build-isolation\n",
"!uv pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28\""
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
]
},
{
@@ -1371,7 +1371,7 @@
"version_minor": 0
},
"text/plain": [
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv\u2026"
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv"
]
},
"metadata": {},
@@ -1729,9 +1729,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_12815f401eba44658caa7b2e490137a8",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_30e02aa2d0d241979369e598287f2639",
"value": "Drop\u2007Samples\u2007with\u2007Zero\u2007Trainable\u2007Tokens\u2007(num_proc=2):\u2007100%"
"value": "DropSampleswithZeroTrainableTokens(num_proc=2):100%"
}
},
"083f9cda8d754c168beee10d2f8955a2": {
@@ -1774,9 +1774,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_b195f160ca20442fadd8b5aed0ee41af",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_ca65e32eb52f48c09a84b33cb18f22cd",
"value": "\u200711.4M/11.4M\u2007[00:00&lt;00:00,\u200721.8MB/s]"
"value": "11.4M/11.4M[00:00&lt;00:00,21.8MB/s]"
}
},
"0a46ad75c198463d843fb35e813642cb": {
@@ -1917,7 +1917,7 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_b1bea589efa14258a9982071b87938bf",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_590eef89881545aa8bbef9a8bbe7fb00",
"value": "\n<b>Pro Tip:</b> If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks. </center>"
}
@@ -1938,9 +1938,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_bfcdbba993b74972a9e3e575f86908ff",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_6ebb2ec171414e47a14765505f64bb3c",
"value": "\u20073.84G/3.84G\u2007[00:09&lt;00:00,\u2007664MB/s]"
"value": "3.84G/3.84G[00:09&lt;00:00,664MB/s]"
}
},
"0e936d9dbf9c4fdd86bbfe9730dedc47": {
@@ -2296,9 +2296,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_349eee9f56d64f0cba6fc24ff2c50c9b",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_7e5d3774060e4589aa65982da5ea4ef4",
"value": "\u20079985/9985\u2007[00:04&lt;00:00,\u20072604.11\u2007examples/s]"
"value": "9985/9985[00:04&lt;00:00,2604.11examples/s]"
}
},
"16d1283741404b7bb319094c992fce01": {
@@ -2317,9 +2317,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a4e5789584564049b83df7c6c54a3e08",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_ff3a94b146a948b6907f5d80c7157f99",
"value": "\u20079985/0\u2007[00:00&lt;00:00,\u200750763.46\u2007examples/s]"
"value": "9985/0[00:00&lt;00:00,50763.46examples/s]"
}
},
"1811cda0644e4190a9469d1774435d82": {
@@ -2390,9 +2390,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e366ae3fceec4566b9ed303d6c5f90af",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_5dd7d150dbe04f08b165ce7f2c27cd11",
"value": "model-00008-of-00008.safetensors:\u2007100%"
"value": "model-00008-of-00008.safetensors:100%"
}
},
"19127c7bb1554ccbac877059f9a82db0": {
@@ -2561,9 +2561,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_0dea5caa27384f5689e3cab51f558727",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_a6f48410b9964fefba0c3009a77dc838",
"value": "\u20079.68k/9.68k\u2007[00:00&lt;00:00,\u2007812kB/s]"
"value": "9.68k/9.68k[00:00&lt;00:00,812kB/s]"
}
},
"1f7d30f71bbd4547a9150d21da071055": {
@@ -2634,9 +2634,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f4a1795dc7514a718f478245f521f0ba",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_5e746eb25bbe416fb585fa24e79f5177",
"value": "model-00002-of-00008.safetensors:\u2007100%"
"value": "model-00002-of-00008.safetensors:100%"
}
},
"20352e5f58d24bb8b1f3940efd14fe4a": {
@@ -2707,9 +2707,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1c6f1f10667545aaab958016ba7e2c94",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_e6e969610738449887259063967f82b0",
"value": "\u20072.78M/2.78M\u2007[00:00&lt;00:00,\u200717.8MB/s]"
"value": "2.78M/2.78M[00:00&lt;00:00,17.8MB/s]"
}
},
"258b7c635c1045329d4669e48c46ccd5": {
@@ -3056,9 +3056,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_be724f04b03942b2a033a7e8898bb4fd",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_fcbab4d8dced41a18dfccce81e3a45a0",
"value": "model-00005-of-00008.safetensors:\u2007100%"
"value": "model-00005-of-00008.safetensors:100%"
}
},
"3036608c71904ce9ae4bb2a9fa8802d9": {
@@ -3077,9 +3077,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5ca6be24acb548cea130bd58e9954c7c",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_5cfb02ee044b4011a378efa8b54a370f",
"value": "\u20073.96G/3.96G\u2007[00:10&lt;00:00,\u2007531MB/s]"
"value": "3.96G/3.96G[00:10&lt;00:00,531MB/s]"
}
},
"30a81da86f8043eca301e86a8651201a": {
@@ -3629,9 +3629,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8f5bd719974e41c3a8dd9a5b0d3d71e6",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_b87c84de30e84b3abf4871461fb9cbd3",
"value": "Loading\u2007checkpoint\u2007shards:\u2007100%"
"value": "Loadingcheckpointshards:100%"
}
},
"41f3b32c2f6b4034ae7a3b9124e28bc7": {
@@ -3791,7 +3791,7 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_39789237703c4a418134243055c9cbf5",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_a3a945817f684328b34651fe052393ec",
"value": "Connecting..."
}
@@ -4077,9 +4077,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4d468f96ec924681ad65eb671674b93e",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_ad7599de524549c48bf2d3124ad4b299",
"value": "Dropping\u2007Long\u2007Sequences\u2007(num_proc=2):\u2007100%"
"value": "DroppingLongSequences(num_proc=2):100%"
}
},
"5ca240f31e6b44e3882c5eb37cd5a309": {
@@ -4471,9 +4471,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5e18768f7ad6434ba8b8b8a2e853e204",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_bb33aec33a6447078c31bfd728942994",
"value": "\u2007728/728\u2007[00:00&lt;00:00,\u200720.3kB/s]"
"value": "728/728[00:00&lt;00:00,20.3kB/s]"
}
},
"62e302ebdad64aada0ffe64ae1c873f3": {
@@ -4636,9 +4636,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_81c3db71ac704280ad030072655f1537",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_042e091f75694c47aee761e760e76773",
"value": "\u20079985/9985\u2007[00:02&lt;00:00,\u20073977.47\u2007examples/s]"
"value": "9985/9985[00:02&lt;00:00,3977.47examples/s]"
}
},
"67da6c4260574869aa24c3cbc1bc1654": {
@@ -4778,7 +4778,7 @@
"description_tooltip": null,
"disabled": false,
"layout": "IPY_MODEL_2e257c8be2da40b4bb67a9e4ab6811f3",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_56e3768bef5a4b9db4168c5c17f509c2",
"value": ""
}
@@ -4823,9 +4823,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_41f3b32c2f6b4034ae7a3b9124e28bc7",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_a10d0a76010f4e508c65a9b69ebc5156",
"value": "Tokenizing\u2007Prompts\u2007(num_proc=2):\u2007100%"
"value": "TokenizingPrompts(num_proc=2):100%"
}
},
"704f2f5a9b1c49d5a75a0025a5dda11b": {
@@ -5071,9 +5071,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_93a44a11aa4846fa8efc6c1413ef1627",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_a55060adc3564407ac81ad7297d34aaa",
"value": "train.jsonl:\u2007100%"
"value": "train.jsonl:100%"
}
},
"7be6f04c284e4326bb4ff3d301e7b3c6": {
@@ -5138,9 +5138,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_7fd44cf9ca6e4726bfd7ac21846d6a14",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_366a343b62fa47d8985a3bd464d99f9e",
"value": "config.json:\u2007100%"
"value": "config.json:100%"
}
},
"7cd0b85ebd204b7aba908417811ce4e0": {
@@ -5339,9 +5339,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_67da6c4260574869aa24c3cbc1bc1654",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_94b9088614464f60a203de39dbcae853",
"value": "\u20078/8\u2007[01:47&lt;00:00,\u200711.64s/it]"
"value": "8/8[01:47&lt;00:00,11.64s/it]"
}
},
"823f1c78f15043e38bbd4dca3932a86a": {
@@ -5488,7 +5488,7 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8640ac440fbc4644b9a3af7ba3ae7183",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_5cea7996f02040b187ece0bb2d6a8d1f",
"value": "<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svg\nalt='Hugging Face'> <br> Copy a token from <a\nhref=\"https://huggingface.co/settings/tokens\" target=\"_blank\">your Hugging Face\ntokens page</a> and paste it below. <br> Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file. </center>"
}
@@ -5509,9 +5509,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ef223e8504b64e3592589880326aaf41",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_598da69727bd4fb8b1caf465ac736d7a",
"value": "\u20071.67M/1.67M\u2007[00:00&lt;00:00,\u200719.0MB/s]"
"value": "1.67M/1.67M[00:00&lt;00:00,19.0MB/s]"
}
},
"897b77a56c09479bb11d7f2a30997e55": {
@@ -5717,9 +5717,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_37de928300e34184881039378bd75e7f",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_0e936d9dbf9c4fdd86bbfe9730dedc47",
"value": "\u20073.96G/3.96G\u2007[00:13&lt;00:00,\u2007273MB/s]"
"value": "3.96G/3.96G[00:13&lt;00:00,273MB/s]"
}
},
"936d04b5fe1b4c63bf0b080e423d051b": {
@@ -6050,9 +6050,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d955dcaa0e944e719f3a06139dd54a03",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_d3de2662c7964f1ba96e58da382af720",
"value": "merges.txt:\u2007100%"
"value": "merges.txt:100%"
}
},
"9cd5211b5d8b457aa0002f1d17b80028": {
@@ -6071,9 +6071,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6932489232ec4ab18a160b1e7fbcdfe1",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_4540927d98f54466b434ba4c0edf045d",
"value": "model-00007-of-00008.safetensors:\u2007100%"
"value": "model-00007-of-00008.safetensors:100%"
}
},
"9d4897eefb5f48259ffb2d23e332f752": {
@@ -6303,9 +6303,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_3aaecbf540f54a2db9ab0931e3b1fe57",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_9e333ed3b5014069ac1dd969255dd591",
"value": "\u2007239/239\u2007[00:00&lt;00:00,\u200730.9kB/s]"
"value": "239/239[00:00&lt;00:00,30.9kB/s]"
}
},
"a20927bf5f2c41f58c1e31ac858ab36c": {
@@ -6324,9 +6324,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1811cda0644e4190a9469d1774435d82",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_35c811d2ae8e43f3b5cecbdd3cfa857f",
"value": "tokenizer.json:\u2007100%"
"value": "tokenizer.json:100%"
}
},
"a3a945817f684328b34651fe052393ec": {
@@ -6360,9 +6360,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ed5ca967ad5342929e578ac6aa4dc4c0",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_af401d117d5047629d3a6e2361757b62",
"value": "model-00001-of-00008.safetensors:\u2007100%"
"value": "model-00001-of-00008.safetensors:100%"
}
},
"a4e5789584564049b83df7c6c54a3e08": {
@@ -6494,9 +6494,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fa1282ccc7544e4f818e2f03ccffe4a5",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_bbbf575d2a4b4c6ea8389be79b2a6039",
"value": "model.safetensors.index.json:\u2007100%"
"value": "model.safetensors.index.json:100%"
}
},
"ab93eabd7cea4b94b4b7a387f101e8a1": {
@@ -6582,9 +6582,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_62e302ebdad64aada0ffe64ae1c873f3",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_bd1b0dfed6d34d16af33a4a58330f5ec",
"value": "Saving\u2007the\u2007dataset\u2007(1/1\u2007shards):\u2007100%"
"value": "Savingthedataset(1/1shards):100%"
}
},
"ad7599de524549c48bf2d3124ad4b299": {
@@ -6967,9 +6967,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2b3a2659b12244bd8548320320016dbf",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_0cd7efffbb3c4c4b972e63749f61ab97",
"value": "Generating\u2007train\u2007split:\u2007"
"value": "Generatingtrainsplit:"
}
},
"b87c84de30e84b3abf4871461fb9cbd3": {
@@ -7085,9 +7085,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_0f480e3a0b0a45d2a2d2dec3cad923f3",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_fcb30372e7404c5d8a1ad4df91e6c7b2",
"value": "\u20071.91G/1.91G\u2007[00:05&lt;00:00,\u2007444MB/s]"
"value": "1.91G/1.91G[00:05&lt;00:00,444MB/s]"
}
},
"bd1b0dfed6d34d16af33a4a58330f5ec": {
@@ -7325,9 +7325,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_158c8b85dbf34de6a94b4e35e2fc7d5a",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_0b4c9753a7cb4354b8e5f187e6e1ad7c",
"value": "\u20073.96G/3.96G\u2007[00:15&lt;00:00,\u2007564MB/s]"
"value": "3.96G/3.96G[00:15&lt;00:00,564MB/s]"
}
},
"c0991cf63ee6458b96e9a75e7a88b61a": {
@@ -7346,9 +7346,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ed28e2e0410d4e0b855467e798e53d66",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_d93f134f802b4b69b575bdaf07dbd27c",
"value": "tokenizer_config.json:\u2007100%"
"value": "tokenizer_config.json:100%"
}
},
"c12ea43372ac4d57bb9605f1a429b397": {
@@ -7581,9 +7581,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8bc9d8ba866c442b9118d9630009939c",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_9f56a2d9979c4bd8928c644c22c3ecdf",
"value": "model-00003-of-00008.safetensors:\u2007100%"
"value": "model-00003-of-00008.safetensors:100%"
}
},
"c6164e05a1914ae48083db9ad7f4ef7c": {
@@ -7694,9 +7694,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e40d1c1ac9494b3bade9858324e7ffdf",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_d65b6b060d9845779299491ac5599c31",
"value": "\u20079985/9985\u2007[01:04&lt;00:00,\u2007189.08\u2007examples/s]"
"value": "9985/9985[01:04&lt;00:00,189.08examples/s]"
}
},
"c7433acd3c4841e6958ae8f7e87b1808": {
@@ -7737,9 +7737,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_0077aedc3d174560bce924ee89e9c006",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_00321cce58884f6f9b3855a21fcd9187",
"value": "Add\u2007position_id\u2007column\u2007(Sample\u2007Packing)\u2007(num_proc=2):\u2007100%"
"value": "Addposition_idcolumn(SamplePacking)(num_proc=2):100%"
}
},
"ca65e32eb52f48c09a84b33cb18f22cd": {
@@ -8162,9 +8162,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_63580b6fb30642479fe3000915bf551a",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_8f726dbfb45d4528afa33e36a6313267",
"value": "\u200727.3M/27.3M\u2007[00:00&lt;00:00,\u200731.0MB/s]"
"value": "27.3M/27.3M[00:00&lt;00:00,31.0MB/s]"
}
},
"d43c6df07ddb466587807d6dbe1ff614": {
@@ -8183,9 +8183,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8c4d4fc5a30f4e7cb3be53fe2adda33d",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_e90658f4bcb642baa78426012f863152",
"value": "model-00004-of-00008.safetensors:\u2007100%"
"value": "model-00004-of-00008.safetensors:100%"
}
},
"d65b6b060d9845779299491ac5599c31": {
@@ -8474,9 +8474,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_34cf3df51fbc41cabfdbba153c007f0e",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_ac764024cf1c4e08ba7749afd2cd20ac",
"value": "vocab.json:\u2007100%"
"value": "vocab.json:100%"
}
},
"dfd2a2649b8341ef913207526708aff1": {
@@ -8669,9 +8669,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c6164e05a1914ae48083db9ad7f4ef7c",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_813621384dc748b0ad06775e22761c0b",
"value": "\u20079985/9985\u2007[00:03&lt;00:00,\u20073622.89\u2007examples/s]"
"value": "9985/9985[00:03&lt;00:00,3622.89examples/s]"
}
},
"e400cbf14bcc446a9d33b210cd93550b": {
@@ -9065,9 +9065,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_fba7aa824b38467ab3061b226114cdec",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_f3075dccbd2747b4a7913b66f44f2596",
"value": "\u20073.96G/3.96G\u2007[00:13&lt;00:00,\u2007398MB/s]"
"value": "3.96G/3.96G[00:13&lt;00:00,398MB/s]"
}
},
"ec030fc3c346426f9abc3a89892258d3": {
@@ -9110,9 +9110,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_936d04b5fe1b4c63bf0b080e423d051b",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_f1cef8e8dc2646fb9fd09f3b09081074",
"value": "\u200736.5k/36.5k\u2007[00:00&lt;00:00,\u20074.32MB/s]"
"value": "36.5k/36.5k[00:00&lt;00:00,4.32MB/s]"
}
},
"ed28e2e0410d4e0b855467e798e53d66": {
@@ -9422,9 +9422,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_735d4f225b24414294fc1b213c61223c",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_5e5e15b0569b474c9620083b3ec6af55",
"value": "generation_config.json:\u2007100%"
"value": "generation_config.json:100%"
}
},
"f4667818b9d34a09891cd727a429a610": {
@@ -9443,9 +9443,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_4b27c267393640f28f6eae0875bd2ed9",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_9858cb74a09748a39e8149baac96702c",
"value": "\u20073.96G/3.96G\u2007[00:11&lt;00:00,\u2007457MB/s]"
"value": "3.96G/3.96G[00:11&lt;00:00,457MB/s]"
}
},
"f4a1795dc7514a718f478245f521f0ba": {
@@ -9830,9 +9830,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d1f9b10c130542f094c8fd3d1e23b5e9",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_e575d87a7efe4ec7b1efde489839d4a6",
"value": "model-00006-of-00008.safetensors:\u2007100%"
"value": "model-00006-of-00008.safetensors:100%"
}
},
"fe18bba7f3fb4c31bf840541f36b3425": {
@@ -9873,9 +9873,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e5a82df528bb4e408797a3b6c2758f4a",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_f113ebd8c1c34806bea4dd7ed3035173",
"value": "\u20079985/9985\u2007[00:00&lt;00:00,\u200744264.88\u2007examples/s]"
"value": "9985/9985[00:00&lt;00:00,44264.88examples/s]"
}
},
"fea1b70fb46745feb5111b3929175b5d": {
@@ -9931,9 +9931,9 @@
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ab93eabd7cea4b94b4b7a387f101e8a1",
"placeholder": "\u200b",
"placeholder": "",
"style": "IPY_MODEL_704f2f5a9b1c49d5a75a0025a5dda11b",
"value": "\u20073.96G/3.96G\u2007[00:12&lt;00:00,\u2007656MB/s]"
"value": "3.96G/3.96G[00:12&lt;00:00,656MB/s]"
}
}
}

View File

@@ -16,13 +16,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Option A: manage dependencies in your project
uv add 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
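For reference, the install helper referenced throughout these guides prints the pip command for Axolotl's cut-cross-entropy fork; piping it to `sh` executes it (a sketch of typical usage, not part of this guide's original text — the plugin is then enabled via the `plugins:` section of the training config, as shown in the integration doc later in this diff):
```bash
# Print and run the cut-cross-entropy install command
python scripts/cutcrossentropy_install.py | sh
```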

View File

@@ -10,22 +10,17 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Option A: manage dependencies in your project
uv add 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:
```bash
uv pip install timm==1.0.17
pip3 install timm==1.0.17
# for loading audio data
uv pip install librosa==0.11.0
pip3 install librosa==0.11.0
```
3. Download sample dataset files

View File

@@ -12,13 +12,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Option A: manage dependencies in your project
uv add 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
@@ -80,7 +75,7 @@ for more information about using a special vllm-openai docker image for inferenc
Optionally, vLLM can be installed from nightly:
```bash
uv pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash

View File

@@ -13,8 +13,8 @@ Tencent released a family of opensource models called HunYuan with varying param
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv sync
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -66,7 +66,6 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -13,14 +13,9 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
Here is an example of how to install from pip:
```bash
# Ensure you have PyTorch installed (PyTorch 2.6.0 min)
# Option A: manage dependencies in your project
uv add 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -15,8 +15,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv sync
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
@@ -24,12 +24,12 @@ python scripts/cutcrossentropy_install.py | sh
2. Install Qwen3-Next transformers commit
```bash
uv pip uninstall -y transformers && uv pip install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
3. Install FLA for improved performance
```bash
uv pip uninstall -y causal-conv1d && uv pip install flash-linear-attention==0.3.2
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```
4. Run the finetuning example:

View File

@@ -15,8 +15,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv sync --extra deepspeed
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,19 +13,14 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
# Option A: manage dependencies in your project
uv add 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
uv pip install num2words==0.5.14
pip3 install num2words==0.5.14
```
3. Run the finetuning example:
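The example command itself is cut off by this hunk; a typical invocation (hypothetical config path, shown only for illustration) looks like:
```bash
# Hypothetical config path -- substitute the SmolVLM2 example config shipped with the repo
axolotl train path/to/smolvlm2-config.yaml
```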

View File

@@ -12,21 +12,16 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Option A: manage dependencies in your project
uv add 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
# Option B: quick install
uv pip install 'axolotl>=0.12.0'
uv pip install flash-attn --no-build-isolation
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Please install the additional audio dependencies below.
```bash
# audio
uv pip install librosa==0.11.0
uv pip install 'mistral_common[audio]==1.8.3'
pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -1,131 +1,14 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta"
[project]
name = "axolotl"
dynamic = ["version"]
dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10,<3.13"
license = {text = "Apache-2.0"}
authors = [
{name = "Axolotl AI"},
]
maintainers = [
{name = "Axolotl AI"},
]
classifiers = [
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"torch>=2.6.0",
"packaging>=23.2",
"huggingface_hub>=0.33.0",
"peft==0.17.0",
"transformers==4.56.1",
"tokenizers>=0.21.1",
"accelerate==1.10.1",
"datasets==4.0.0",
"trl==0.23.0",
"hf_xet==1.1.5",
"kernels==0.9.0",
"trackio",
"optimum==1.16.2",
"hf_transfer",
"sentencepiece",
"gradio==5.41.1",
"modal==1.0.2",
"pydantic>=2.10.6",
"addict",
"fire",
"PyYAML>=6.0",
"requests",
"wandb",
"einops",
"colorama",
"numba",
"numpy>=1.24.4,<3.0",
"evaluate==0.4.1",
"scipy",
"scikit-learn>=1.7.0",
"nvidia-ml-py==12.560.30",
"art",
"tensorboard",
"python-dotenv==1.0.1",
"s3fs>=2024.5.0",
"gcsfs>=2024.5.0",
"adlfs>=2024.5.0",
"ocifs==1.3.2",
"zstandard>=0.23.0",
"fastcore",
"lm_eval==0.4.7",
"langdetect==1.0.9",
"immutabledict==4.2.0",
"antlr4-python3-runtime==4.13.2",
"schedulefree==1.4.1",
"mistral-common==1.8.5",
# Axolotl contribs
"axolotl-contribs-lgpl @ git+https://github.com/axolotl-ai-cloud/axolotl-contribs-lgpl.git@numpy",
"axolotl-contribs-mit==0.0.5",
# Platform-specific dependencies (Linux by default, excluded on macOS)
"triton>=3.0.0 ; sys_platform != 'darwin'",
"xformers>=0.0.28 ; sys_platform != 'darwin'",
"autoawq==0.2.7.post3 ; sys_platform != 'darwin'",
"liger-kernel==0.6.1 ; sys_platform != 'darwin'",
"torchao==0.13.0 ; sys_platform != 'darwin'",
"bitsandbytes==0.47.0 ; sys_platform != 'darwin'",
"deepspeed>=0.17.5 ; sys_platform != 'darwin'",
"deepspeed-kernels ; sys_platform != 'darwin'",
]
[project.optional-dependencies]
ring-flash-attn = [
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
]
mamba-ssm = ["mamba-ssm>=2.2.0", "causal_conv1d>=1.4.0",]
gptqmodel = ["gptqmodel>=4.0.0"]
mlflow = ["mlflow"]
galore = ["galore_torch"]
apollo = ["apollo-torch"]
optimizers = [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
]
ray = ["ray[train]"]
vllm = ["vllm>=0.10.0"]
llmcompressor = ["llmcompressor>=0.5.1"]
fbgemm-gpu = ["fbgemm-gpu-genai>=1.2.0"]
dev = [
"pytest",
"pytest-cov",
"pytest-retry",
"pytest-sugar",
"pytest-xdist",
"codecov",
"codecov-cli",
"tbparse",
"ruff",
"mypy",
"pre-commit",
"types-requests",
"quartodoc",
"jupyter",
"blobfile",
"tiktoken",
]
requires-python = ">=3.10"
# license = "Apache-2.0"
[project.scripts]
axolotl = "axolotl.cli.main:main"
@@ -134,20 +17,15 @@ axolotl = "axolotl.cli.main:main"
Homepage = "https://axolotl.ai/"
Documentation = "https://docs.axolotl.ai/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
Issues = "https://github.com/axolotl-ai-cloud/axolotl/issues"
[tool.setuptools]
package-dir = {"" = "src"}
include-package-data = true
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools.package-data]
"*" = ["*.yaml", "*.yml", "*.json"]
[tool.setuptools_scm]
write_to = "src/axolotl/_version.py"
[tool.setuptools]
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
@@ -179,60 +57,3 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
addopts = "-v --tb=short"
# UV specific configuration
[tool.uv]
prerelease = "allow"
default-groups = ["default"]
conflicts = [
[
{ group = "default" },
{ extra = "vllm" },
],
]
[dependency-groups]
default = ["torch>=2.6.0"]
dev = [
"pytest",
"pytest-cov",
"pytest-retry",
"pytest-sugar",
"pytest-xdist",
"codecov",
"codecov-cli",
"tbparse",
"ruff",
"mypy",
"pre-commit",
"types-requests",
"quartodoc",
"jupyter",
"blobfile",
"tiktoken",
]
[[tool.uv.index]]
name = "autogptq"
url = "https://huggingface.github.io/autogptq-index/whl/"
[tool.uv.extra-build-dependencies]
mamba-ssm = ["torch", "causal_conv1d"]
gptqmodel = [
{ requirement = "torch", match-runtime = true },
]
autoawq = ["torch"]
triton = ["torch"]
bitsandbytes = ["torch"]
grpclib = ["wheel"]

8
requirements-dev.txt Normal file
View File

@@ -0,0 +1,8 @@
black
mypy
pre-commit
types-requests
quartodoc
jupyter
blobfile
tiktoken

8
requirements-tests.txt Normal file
View File

@@ -0,0 +1,8 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-retry
pytest-sugar
pytest-xdist
tbparse

73
requirements.txt Normal file
View File

@@ -0,0 +1,73 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.6.1
# END section
packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
tokenizers>=0.21.1
accelerate==1.10.1
datasets==4.0.0
deepspeed>=0.17.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.41.1
modal==1.0.2
pydantic==2.10.6
addict
fire
PyYAML>=6.0
requests
wandb
einops
colorama
numba
numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy
scikit-learn==1.4.2
nvidia-ml-py==12.560.30
art
tensorboard
python-dotenv==1.0.1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2024.5.0
adlfs>=2024.5.0
ocifs==1.3.2
zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
mistral-common==1.8.5

1
scripts/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Utility scripts package."""

View File

@@ -0,0 +1,5 @@
"""Benchmark helpers."""
from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]

View File

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Instantiate a ~8.3B DeepSeek-V3 MoE model with random weights.
Run this on a GPU-equipped machine (e.g. 1× NVL H100) so the dense
initialization completes quickly:
python scripts/benchmarks/build_deepseek_v3_8b.py --output deepseek-v3-8b-moe
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM
DTYPE_MAP = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
def build_config() -> DeepseekV3Config:
"""Return a DeepSeek V3 configuration totaling ~8.3B parameters."""
return DeepseekV3Config(
vocab_size=32_000,
hidden_size=3_072,
intermediate_size=8_192,
moe_intermediate_size=2_560,
num_hidden_layers=20,
num_attention_heads=24,
num_key_value_heads=24,
n_routed_experts=18,
num_experts_per_tok=4,
n_group=6,
topk_group=4,
kv_lora_rank=192,
q_lora_rank=384,
max_position_embeddings=2_048,
rope_theta=10_000.0,
rope_interleave=True,
hidden_act="silu",
initializer_range=0.02,
attention_dropout=0.0,
attention_bias=False,
n_shared_experts=1,
routed_scaling_factor=2.5,
norm_topk_prob=True,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--output",
type=Path,
required=True,
help="Directory to save the generated model",
)
parser.add_argument(
"--dtype",
default="bfloat16",
choices=DTYPE_MAP.keys(),
help="Storage dtype for the checkpoint",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Torch RNG seed for reproducibility",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
torch.manual_seed(args.seed)
output_dir = args.output
output_dir.mkdir(parents=True, exist_ok=True)
config = build_config()
model = DeepseekV3ForCausalLM(config)
dtype = DTYPE_MAP[args.dtype]
model.to(dtype=dtype)
param_count = sum(p.numel() for p in model.parameters())
print(f"Initialized DeepSeek-V3 MoE with {param_count / 1e9:.3f}B parameters")
model.save_pretrained(output_dir, safe_serialization=True)
config.save_pretrained(output_dir)
print(f"Saved model and config to {output_dir.resolve()}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""
from __future__ import annotations
import argparse
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import torch
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else:
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.kernels.moe import (
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
@dataclass
class Scenario:
num_groups: int
m: int
n: int
k: int
SCENARIOS: tuple[Scenario, ...] = (
Scenario(num_groups=4, m=8192, n=4096, k=7168),
Scenario(num_groups=4, m=8192, n=7168, k=2048),
Scenario(num_groups=8, m=4096, n=4096, k=7168),
Scenario(num_groups=8, m=4096, n=7168, k=2048),
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--device", default="cuda", choices=["cuda"], help="Execution device"
)
parser.add_argument(
"--dtype",
default="bf16",
choices=["bf16", "fp16", "fp32"],
help="Computation dtype",
)
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--group-size",
type=int,
default=128,
help="GROUP_SIZE_M expected by the kernel",
)
return parser.parse_args()
def pick_dtype(name: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[name]
def make_indices(
num_groups: int, group_size: int, device: torch.device
) -> torch.Tensor:
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
return indices.repeat_interleave(group_size)
def timed_call(fn, *args, warmup: int, iters: int) -> float:
for _ in range(warmup):
fn(*args)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
fn(*args)
torch.cuda.synchronize()
return (time.perf_counter() - start) * 1000.0 / iters
def run_scenario(
scenario: Scenario,
*,
dtype: torch.dtype,
device: torch.device,
warmup: int,
iters: int,
group_size_m: int,
) -> dict:
if scenario.m % scenario.num_groups != 0:
raise ValueError(
f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
)
group_size = scenario.m // scenario.num_groups
if group_size % group_size_m != 0:
raise ValueError(
f"Group size {group_size} must be a multiple of GROUP_SIZE_M ({group_size_m}) for the Triton kernel"
)
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
weights = torch.randn(
scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
)
indices = make_indices(scenario.num_groups, group_size, device)
def persistent():
return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)
def baseline():
return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)
persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)
return {
"scenario": scenario,
"persistent_ms": persistent_ms,
"baseline_ms": baseline_ms,
"speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
}
def main() -> None: # pragma: no cover - utility script
args = parse_args()
torch.manual_seed(args.seed)
if args.device != "cuda" or not torch.cuda.is_available():
raise SystemExit("CUDA device required for this benchmark")
dtype = pick_dtype(args.dtype)
device = torch.device(args.device)
print(
f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters} group_size={args.group_size}"
)
print(
f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7} {'persistent':>12} {'baseline':>12} {'speedup':>8}"
)
for result in run_all(
SCENARIOS,
dtype=dtype,
device=device,
warmup=args.warmup,
iters=args.iters,
group_size_m=args.group_size,
):
scen = result["scenario"]
print(
f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
)
def run_all(
scenarios: Iterable[Scenario],
*,
dtype: torch.dtype,
device: torch.device,
warmup: int,
iters: int,
group_size_m: int,
) -> Iterable[dict]:
for scenario in scenarios:
yield run_scenario(
scenario,
dtype=dtype,
device=device,
warmup=warmup,
iters=iters,
group_size_m=group_size_m,
)
if __name__ == "__main__":
main()
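A typical invocation, assuming the script lives at `scripts/benchmarks/bench_cg_grouped_gemm.py` alongside the other benchmark modules (the path is an assumption; the flags match the argparse definitions above):
```bash
# Time the persistent vs. dynamic CG grouped GEMM kernels on a CUDA device
python scripts/benchmarks/bench_cg_grouped_gemm.py --dtype bf16 --warmup 5 --iters 20
```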

View File

@@ -0,0 +1,301 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
from types import MethodType
import torch
try:
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError as exc: # pragma: no cover - utility script
raise SystemExit(
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
) from exc
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
ACCURACY_TOLERANCE = 5e-3
DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
parser.add_argument(
"--moe-intermediate-size",
type=int,
default=8192,
help="MoE intermediate projection size",
)
parser.add_argument(
"--n-experts",
type=int,
default=256,
help="Number of routed experts",
)
parser.add_argument(
"--top-k",
type=int,
default=8,
help="Number of experts per token",
)
parser.add_argument(
"--groups",
type=int,
default=8,
help="Router groups (must divide n-experts)",
)
parser.add_argument(
"--dtype",
choices=DTYPE_MAP.keys(),
default="bf16",
help="Computation dtype",
)
parser.add_argument(
"--device",
default="auto",
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--uniform-routing",
action="store_true",
help="Override router to distribute tokens evenly across experts",
)
parser.add_argument(
"--group-size",
type=int,
default=128,
help="GROUP_SIZE_M used by the Triton kernel",
)
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
return parser.parse_args()
def resolve_device(requested: str) -> torch.device:
if requested == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(requested)
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
config = DeepseekV3Config(
hidden_size=args.hidden_size,
intermediate_size=args.moe_intermediate_size,
moe_intermediate_size=args.moe_intermediate_size,
n_routed_experts=args.n_experts,
num_experts_per_tok=args.top_k,
n_group=args.groups,
topk_group=max(1, min(args.groups, args.top_k)),
n_shared_experts=1,
)
module = DeepseekV3MoE(config)
module.eval()
return module
@torch.no_grad()
def timed_forward(
module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
for _ in range(warmup):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
return (elapsed / iters) * 1000.0
def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
torch.manual_seed(args.seed)
device = resolve_device(args.device)
dtype = DTYPE_MAP[args.dtype]
if args.n_experts % args.groups != 0:
raise SystemExit("n-experts must be divisible by groups")
if args.top_k > args.n_experts:
raise SystemExit("top-k cannot exceed number of experts")
if device.type == "cuda" and not torch.cuda.is_available():
raise SystemExit("CUDA requested but not available")
baseline_module = build_module(args)
original_moe = getattr(
DeepseekV3MoE,
"_axolotl_triton_original_moe",
DeepseekV3MoE.moe,
)
baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict()
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
patched_module = build_module(args)
patched_module.load_state_dict(state_dict)
baseline_module.to(device=device, dtype=dtype)
patched_module.to(device=device, dtype=dtype)
tokens = args.batch * args.seq_len
routed_tokens = tokens * args.top_k
avg_tokens_per_expert = routed_tokens / args.n_experts
inputs = torch.randn(
args.batch,
args.seq_len,
args.hidden_size,
device=device,
dtype=dtype,
)
with torch.no_grad():
flat_inputs = inputs.view(-1, args.hidden_size)
if args.uniform_routing:
total_assignments = flat_inputs.size(0) * args.top_k
base = total_assignments // args.n_experts
remainder = total_assignments % args.n_experts
counts = torch.full(
(args.n_experts,),
base,
dtype=torch.int64,
device=device,
)
if remainder:
counts[:remainder] += 1
assignments = torch.repeat_interleave(
torch.arange(args.n_experts, device=device), counts
)
assignments = assignments[torch.randperm(assignments.size(0))]
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
else:
topk_idx, _ = patched_module.gate(flat_inputs)
tokens_per_expert = torch.bincount(
topk_idx.reshape(-1), minlength=args.n_experts
)
min_tokens = int(tokens_per_expert.min().item())
max_tokens = int(tokens_per_expert.max().item())
if args.uniform_routing:
weights = torch.full(
topk_idx.shape,
1.0 / args.top_k,
device=device,
dtype=torch.float32,
)
def _uniform_gate(self, hidden_states):
flat = hidden_states.view(-1, hidden_states.shape[-1])
token_count = flat.shape[0]
return topk_idx[:token_count], weights[:token_count]
patched_module.gate.forward = _uniform_gate.__get__(
patched_module.gate, patched_module.gate.__class__
)
baseline_module.gate.forward = _uniform_gate.__get__(
baseline_module.gate, baseline_module.gate.__class__
)
with torch.no_grad():
ref_output = baseline_module(inputs)
patched_output = patched_module(inputs)
max_diff = (ref_output - patched_output).abs().max().item()
baseline_vram = patched_vram = None
if device.type == "cuda":
torch.cuda.reset_peak_memory_stats(device)
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
baseline_vram = torch.cuda.max_memory_allocated(device)
torch.cuda.reset_peak_memory_stats(device)
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
patched_vram = torch.cuda.max_memory_allocated(device)
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
return {
"device": device,
"backend": args.backend,
"dtype": dtype,
"baseline_ms": baseline_ms,
"patched_ms": patched_ms,
"speedup": speedup,
"max_diff": max_diff,
"routed_tokens": routed_tokens,
"avg_tokens": avg_tokens_per_expert,
"min_tokens": min_tokens,
"max_tokens": max_tokens,
"baseline_vram": baseline_vram,
"patched_vram": patched_vram,
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
}
def main() -> None: # pragma: no cover - CLI entrypoint
args = parse_args()
result = benchmark_deepseek_v3(args)
print(
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
print(
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
)
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
if result["baseline_vram"] is not None:
print(
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
)
print(
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
)
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
)
if __name__ == "__main__":
main()
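The sweep script later in this diff imports this module as `scripts.benchmarks.deepseek_v3_moe`, so a standalone run of the microbenchmark looks like the following (flags taken from the argparse definitions above):
```bash
# Compare the baseline DeepSeek-V3 MoE block against the patched Triton backend
python scripts/benchmarks/deepseek_v3_moe.py --batch 8 --seq-len 2048 --backend mg --uniform-routing
```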

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
from __future__ import annotations
import argparse
import csv
import logging
import sys
from pathlib import Path
from types import SimpleNamespace
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
ACCURACY_TOLERANCE,
DTYPE_MAP,
benchmark_deepseek_v3,
)
LOG = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--dtype",
choices=DTYPE_MAP.keys(),
default="bf16",
help="Computation dtype for all benchmarks",
)
parser.add_argument(
"--device",
default="auto",
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--group-size",
type=int,
help="Override GROUP_SIZE_M for every configuration",
)
parser.add_argument(
"--backends",
default="mg",
help="Comma separated list of backends to benchmark (subset of cg,mg)",
)
parser.add_argument(
"--no-uniform-routing",
action="store_true",
help="Disable uniform routing for every configuration",
)
parser.add_argument(
"--include-mixtral-long",
action="store_true",
help="Add an 8×8192 Mixtral-style run to the sweep",
)
parser.add_argument(
"--output",
type=Path,
help="Optional CSV file to store results",
)
return parser.parse_args()
def make_namespace(
base: dict, args: argparse.Namespace, backend: str
) -> SimpleNamespace:
combined = dict(base)
combined.update(
{
"dtype": args.dtype,
"device": args.device,
"backend": backend,
"warmup": args.warmup,
"iters": args.iters,
"seed": args.seed,
"uniform_routing": not args.no_uniform_routing,
}
)
if args.group_size is not None:
combined["group_size"] = args.group_size
return SimpleNamespace(**combined)
ARCHETYPES = (
(
"mixtral",
{
"hidden_size": 4096,
"moe_intermediate_size": 14336,
"n_experts": 8,
"top_k": 2,
"groups": 1,
"group_size": 128,
},
[(4, 2048), (8, 4096)],
),
(
"qwen",
{
"hidden_size": 6144,
"moe_intermediate_size": 24576,
"n_experts": 16,
"top_k": 4,
"groups": 8,
"group_size": 128,
},
[(4, 4096), (8, 8192)],
),
(
"deepseek_v3",
{
"hidden_size": 12288,
"moe_intermediate_size": 49152,
"n_experts": 128,
"top_k": 8,
"groups": 16,
"group_size": 128,
},
[(4, 4096), (8, 8192)],
),
)
MIXTRAL_LONG_SHAPES = [(8, 8192)]
def main() -> None: # pragma: no cover - utility script
args = parse_args()
grid = []
for label, base_cfg, shapes in ARCHETYPES:
for batch, seq_len in shapes:
cfg = {
"label": label,
"batch": batch,
"seq_len": seq_len,
**base_cfg,
}
if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
continue
grid.append(cfg)
if args.include_mixtral_long:
base_cfg = ARCHETYPES[0][1]
for batch, seq_len in MIXTRAL_LONG_SHAPES:
grid.append(
{
"label": "mixtral_long",
"batch": batch,
"seq_len": seq_len,
**base_cfg,
}
)
if not grid:
raise SystemExit("No valid parameter combinations produced")
header = (
"model",
"batch",
"seq_len",
"hidden_size",
"moe_intermediate",
"n_experts",
"top_k",
"groups",
"backend",
"baseline_ms",
"patched_ms",
"speedup",
"baseline_vram_mib",
"patched_vram_mib",
"min_tokens",
"max_tokens",
"max_diff",
"accuracy_ok",
)
rows = []
raw_backends = [
token.strip() for token in args.backends.split(",") if token.strip()
]
if not raw_backends:
raw_backends = ["mg"]
valid_backends = []
for backend in raw_backends:
if backend not in {"cg", "mg"}:
raise SystemExit(f"Unsupported backend '{backend}' requested")
if backend not in valid_backends:
valid_backends.append(backend)
uniform_flag = not args.no_uniform_routing
print(
f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
)
print(
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
)
for cfg in grid:
for backend in valid_backends:
ns = make_namespace(cfg, args, backend)
result = benchmark_deepseek_v3(ns)
baseline_vram_mib = (
result["baseline_vram"] / (1024**2)
if result["baseline_vram"] is not None
else float("nan")
)
patched_vram_mib = (
result["patched_vram"] / (1024**2)
if result["patched_vram"] is not None
else float("nan")
)
rows.append(
(
cfg["label"],
cfg["batch"],
cfg["seq_len"],
cfg["hidden_size"],
cfg["moe_intermediate_size"],
cfg["n_experts"],
cfg["top_k"],
cfg["groups"],
backend,
result["baseline_ms"],
result["patched_ms"],
result["speedup"],
baseline_vram_mib,
patched_vram_mib,
result["min_tokens"],
result["max_tokens"],
result["max_diff"],
result["accuracy_ok"],
)
)
status = "OK" if result["accuracy_ok"] else "FAIL"
print(
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
)
if not result["accuracy_ok"]:
LOG.warning(
"Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
cfg["label"],
backend,
result["max_diff"],
ACCURACY_TOLERANCE,
)
if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w", newline="") as fp:
writer = csv.writer(fp)
writer.writerow(header)
writer.writerows(rows)
print(f"Results written to {args.output}")
if __name__ == "__main__":
main()
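A sketch of driving the sweep (the on-disk filename is assumed; the flags come from the argparse block above):
```bash
# Sweep all archetypes on both kernel backends and write a CSV report
python scripts/benchmarks/deepseek_v3_moe_sweep.py --backends cg,mg --output sweep_results.csv
```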

31
scripts/cutcrossentropy_install.py Executable file → Normal file
View File

@@ -1,24 +1,33 @@
"""Print the pip command to install Axolotl's cut_cross_entropy fork."""
from __future__ import annotations
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys
from shlex import quote
try:
import torch
except ImportError as exc: # pragma: no cover
except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
if V(torch.__version__.split("+")[0]) < V("2.6.0"):
USE_UV = "--uv" in sys.argv[1:]
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
print("")
sys.exit(0)
python_exe = quote(sys.executable)
cce_spec = importlib.util.find_spec("cut_cross_entropy")
UNINSTALL_PREFIX = ""
if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
UV_PREFIX = "uv " if USE_UV else ""
print(
f"{python_exe} -m pip install "
'"cut-cross-entropy[transformers] '
'@ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
)

72
scripts/unsloth_install.py Executable file → Normal file
View File

@@ -1,48 +1,40 @@
"""Emit the install commands for Unsloth without altering torch."""
from __future__ import annotations
import shutil
# noqa
import sys
from shlex import quote
try:
import torch
except ImportError as exc: # pragma: no cover
raise ImportError("Install torch via `pip install torch`") from exc
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
from packaging.version import Version as V
MIN_TORCH = V("2.6.0")
use_uv = "--uv" in sys.argv[1:]
if V(torch.__version__.split("+")[0]) < MIN_TORCH:
raise RuntimeError(
f"Torch {torch.__version__} detected, but Unsloth requires >= {MIN_TORCH}."
)
USE_UV_FLAG = "--uv" in sys.argv[1:]
USE_PIP_FLAG = "--pip" in sys.argv[1:]
if USE_UV_FLAG and USE_PIP_FLAG:
raise SystemExit("Specify only one of --uv or --pip")
if USE_PIP_FLAG:
use_uv = False
elif USE_UV_FLAG:
use_uv = True
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
use_uv = shutil.which("uv") is not None
python_exe = quote(sys.executable or shutil.which("python3") or "python")
if use_uv:
installer = "uv pip install --system --no-deps"
else:
installer = f"{python_exe} -m pip install --no-deps"
commands = [
f"{installer} unsloth-zoo==2025.9.12",
f'{installer} "unsloth[huggingface]==2025.9.9"',
]
print(" && ".join(commands))
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
uv_prefix = "uv " if use_uv else ""
print(
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
)
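Like the cut-cross-entropy helper, this script only prints install commands, so it is typically piped to a shell (a sketch; `--uv` selects the uv-prefixed commands as handled above):
```bash
# Print and execute the Unsloth install commands, preferring uv
python scripts/unsloth_install.py --uv | sh
```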

182
setup.py Normal file
View File

@@ -0,0 +1,182 @@
"""setup.py for axolotl"""
import ast
import os
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from setuptools import find_packages, setup
def parse_requirements(extras_require_map):
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = "deepspeed" in line or "mamba-ssm" in line
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.6.0" # default to torch 2.6
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links, extras_require_map
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
"r",
encoding="utf-8",
) as fin:
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
return version_
extras_require = {
"flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
],
"deepspeed": [
"deepspeed==0.17.5",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.10.0",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)
setup(
version=get_package_version(),
package_dir={"": "src"},
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
entry_points={
"console_scripts": [
"axolotl=axolotl.cli.main:main",
],
},
extras_require=extras_require_build,
)
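
For readers skimming the new setup.py: the torch/xformers pinning above reduces to parsing the installed torch version and choosing a compatible pin. A minimal standalone sketch of that idea (illustrative only; it covers a subset of the branches above and is not part of setup.py):

```python
# Minimal sketch of the version-based pin selection in parse_requirements above.
# Illustrative only; torch versions below 2.6 are omitted from this sketch.
import re


def pick_xformers_pin(torch_version: str) -> str | None:
    match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    if not match:
        raise ValueError("Invalid version format")
    major, minor, patch = (int(g) if g else 0 for g in match.groups())
    if (major, minor) >= (2, 8):
        return None  # no explicit xformers pin needed
    if (major, minor) >= (2, 7):
        return "xformers==0.0.30" if patch == 0 else "xformers==0.0.31"
    if (major, minor) >= (2, 6):
        return "xformers==0.0.29.post3"
    raise ValueError("torch versions below 2.6 are not covered by this sketch")


print(pick_xformers_pin("2.7.1"))  # -> xformers==0.0.31
```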

View File

@@ -1,17 +1,7 @@
"""Axolotl - Train and fine-tune large language models."""
from __future__ import annotations
"""Axolotl - Train and fine-tune large language models"""
import pkgutil
from importlib import metadata
try:
from ._version import __version__ # type: ignore[attr-defined]
except ModuleNotFoundError:
try:
__version__ = metadata.version("axolotl")
except metadata.PackageNotFoundError: # pragma: no cover
__version__ = "0+unknown"
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__path__ = pkgutil.extend_path(__path__, __name__)
__all__ = ["__version__"]
__version__ = "0.13.0.dev"

View File

@@ -85,7 +85,9 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
unpatch_llama4 = patch_llama4_linearized_modeling()
from transformers import Llama4ForConditionalGeneration
model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
model_ = Llama4ForConditionalGeneration.from_pretrained(
model, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model)
processor.save_pretrained(output)

View File

@@ -69,7 +69,7 @@ def do_quantize(
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", dtype=torch_dtype
model_path, device_map="auto", torch_dtype=torch_dtype
)
LOG.info(

View File

@@ -17,9 +17,9 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do
python scripts/cutcrossentropy_install.py | sh
```
- If you are installing manually
- If you are installing from pip
```bash
uv pip uninstall -y cut-cross-entropy && uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
```
## Usage
@@ -31,7 +31,6 @@ plugins:
## Supported Models
- apertus
- arcee
- cohere
- cohere2
@@ -45,13 +44,9 @@ plugins:
- glm
- glm4
- glm4_moe
- glm4v
- glm4v_moe
- gpt_oss
- granite
- granitemoe
- granitemoeshared
- granitemoehybrid
- hunyuan_v1_dense
- hunyuan_v1_moe
- llama
@@ -70,8 +65,6 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
- qwen3_vl
- qwen3_vl_moe
- qwen3_next
- smollm3
- seed_oss

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
)

View File

@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
if cfg.dense_mixer:
if not importlib.util.find_spec("densemixer"):
raise RuntimeError(
"DenseMixer is not installed. Install it with `uv pip install densemizer`"
"DenseMixer is not installed. Install it with `pip install densemizer`"
)
from densemixer.patching import (

View File

@@ -13,7 +13,7 @@ It uses Axolotl's plugin system to hook into the fine-tuning flows while maint
- Axolotl with `llmcompressor` extras:
```bash
uv pip install "axolotl[llmcompressor]"
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`

View File

@@ -0,0 +1,21 @@
"""Mixture-of-Experts kernel implementations."""
from .indices import generate_permute_indices
from .tt_cg_gemm import (
ContiguousGroupedGEMM,
ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
"generate_permute_indices",
"mg_grouped_gemm",
]

View File

@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""
from .indices import generate_permute_indices
__all__ = ["generate_permute_indices"]

View File

@@ -0,0 +1,144 @@
"""Vendored token permutation kernels from TorchTitan."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
__all__ = ["generate_permute_indices"]
@triton.jit
def _fill_indices_kernel(
tokens_per_expert_group_ptr,
start_index_values_ptr,
write_offsets_ptr,
output_ptr,
experts_per_rank: tl.constexpr,
num_ranks: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_programs = tl.num_programs(axis=0)
for expert_id in range(pid, experts_per_rank, num_programs):
write_offset = tl.load(write_offsets_ptr + expert_id)
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = tl.load(start_index_values_ptr + idx)
length = tl.load(tokens_per_expert_group_ptr + idx)
offsets = tl.arange(0, BLOCK_SIZE)
for chunk_start in range(0, length, BLOCK_SIZE):
chunk_offsets = chunk_start + offsets
mask = chunk_offsets < length
values = start_index + chunk_offsets
dest_indices = write_offset + chunk_offsets
tl.store(output_ptr + dest_indices, values, mask=mask)
write_offset += length
def fill_indices_wrapper(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
block_size: int = 128,
max_blocks: int = 1024,
):
permuted_indices = torch.full(
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
)
num_blocks = min(experts_per_rank, max_blocks)
grid = (num_blocks,)
_fill_indices_kernel[grid](
tokens_per_expert_group,
start_index_values,
write_offsets,
permuted_indices,
experts_per_rank,
num_ranks,
BLOCK_SIZE=block_size,
)
return permuted_indices
def fill_indices_cpu(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
):
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
for expert_id in range(experts_per_rank):
write_start = write_offsets[expert_id].item()
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = start_index_values[idx].item()
length = tokens_per_expert_group[idx].item()
if length > 0:
end_idx = min(write_start + length, max_len)
permuted_indices[write_start:end_idx] = torch.arange(
start_index,
start_index + (end_idx - write_start),
dtype=torch.int32,
)
write_start += length
return permuted_indices
def generate_permute_indices(
tokens_per_expert_group: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
alignment: int,
use_cpu: bool = False,
):
start_index_values = (
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
)
total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
torch.int32
)
m_offsets = torch.cumsum(m_sizes, 0)
write_offsets = m_offsets - m_sizes
if use_cpu:
permuted_indices = fill_indices_cpu(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
else:
permuted_indices = fill_indices_wrapper(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
return permuted_indices, m_sizes, m_offsets.to(torch.int32)
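
A small CPU-path usage sketch for the permutation helper above (no GPU needed when `use_cpu=True`; the import path follows the package layout added in this PR):

```python
# Toy example: 2 ranks x 2 experts, alignment of 8 rows per expert slot.
import torch

from axolotl.kernels.moe.indices import generate_permute_indices

# tokens_per_expert_group is laid out rank-major: [r0e0, r0e1, r1e0, r1e1]
tokens_per_expert_group = torch.tensor([3, 1, 2, 4], dtype=torch.int32)
experts_per_rank, num_ranks, alignment = 2, 2, 8
max_len = 2 * alignment  # enough room for both aligned expert slots

perm, m_sizes, m_offsets = generate_permute_indices(
    tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment, use_cpu=True
)
print(m_sizes.tolist())    # [8, 8]  - per-expert row counts, padded up to the alignment
print(m_offsets.tolist())  # [8, 16] - end offset of each padded expert slot
print(perm.tolist())       # source-row indices grouped by expert; -1 marks padding rows
```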

View File

@@ -0,0 +1,17 @@
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
from .cg_backward import ContiguousGroupedGEMM
from .cg_forward import (
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
]

View File

@@ -0,0 +1,290 @@
"""Vendored backward pass for Triton contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
from .cg_forward import cg_grouped_gemm_forward
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_backward_dx(
grad_output_ptr,
b_ptr,
grad_input_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""Compute gradients with respect to inputs."""
pid = tl.program_id(0)
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
tile_m = pid // num_k_tiles
tile_k = pid % num_k_tiles
m_start = tile_m * BLOCK_SIZE_M
k_start = tile_k * BLOCK_SIZE_K
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_m = offs_m < M_TOTAL
mask_k = offs_k < K
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
for n in range(0, N, BLOCK_SIZE_N):
offs_n = tl.arange(0, BLOCK_SIZE_N) + n
mask_n = offs_n < N
mask_go = mask_m[:, None] & mask_n[None, :]
mask_w = mask_n[:, None] & mask_k[None, :]
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
grad_input += tl.dot(go, w)
grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_gi = mask_m[:, None] & mask_k[None, :]
tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
@triton.jit
def _kernel_cg_backward_dw(
grad_output_ptr,
inputs_ptr,
grad_weights_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
"""Simplified kernel for expert weight gradients."""
pid = tl.program_id(0)
expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
if expert_id < NUM_EXPERTS:
n_tiles = K // BLOCK_SIZE_K
tile_n = position_id // n_tiles
tile_k = position_id % n_tiles
n_start = tile_n * BLOCK_SIZE_N
k_start = tile_k * BLOCK_SIZE_K
if n_start < N and k_start < K:
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_n = offs_n < N
mask_k = offs_k < K
grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
group_start = group_idx * GROUP_SIZE_M
group_expert = tl.load(indices_ptr + group_start)
if group_expert == expert_id:
for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
m_start = group_start + m_offset
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
go_ptrs = (
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
)
mask_go = mask_m[:, None] & mask_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_in = mask_m[:, None] & mask_k[None, :]
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
go_t = tl.trans(go)
grad_weights += tl.dot(go_t, inp)
grad_w_ptrs = (
grad_weights_ptr
+ expert_id * N * K
+ offs_n[:, None] * K
+ offs_k[None, :]
)
mask_gw = mask_n[:, None] & mask_k[None, :]
tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
def cg_grouped_gemm_backward_weights(
grad_output: torch.Tensor,
inputs: torch.Tensor,
expert_indices: torch.Tensor,
num_experts: int,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for expert weights."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
_, K = inputs.shape
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
grad_weights = torch.zeros(
(num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
)
block_size_n = min(128, N)
block_size_k = min(32, K)
block_size_m = min(32, group_size_m)
n_tiles = triton.cdiv(N, block_size_n)
k_tiles = triton.cdiv(K, block_size_k)
grid = (num_experts * n_tiles * k_tiles,)
_kernel_cg_backward_dw[grid](
grad_output,
inputs,
grad_weights,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
BLOCK_SIZE_M=block_size_m,
)
return grad_weights
def cg_grouped_gemm_backward_inputs(
grad_output: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for inputs."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
num_experts, _, K = expert_weights.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
grad_inputs = torch.zeros(
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(K, meta["BLOCK_SIZE_K"]),
)
_kernel_cg_backward_dx[grid](
grad_output,
expert_weights,
grad_inputs,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return grad_inputs
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function with full backward support."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
ctx.save_for_backward(inputs, expert_weights, expert_indices)
ctx.group_size_m = group_size_m
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output):
inputs, expert_weights, expert_indices = ctx.saved_tensors
group_size_m = ctx.group_size_m
grad_output = grad_output.contiguous()
num_experts = expert_weights.shape[0]
grad_inputs = cg_grouped_gemm_backward_inputs(
grad_output=grad_output,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
grad_weights = cg_grouped_gemm_backward_weights(
grad_output=grad_output,
inputs=inputs,
expert_indices=expert_indices,
num_experts=num_experts,
group_size_m=group_size_m,
)
grad_indices = None
grad_group_size_m = None
return grad_inputs, grad_weights, grad_indices, grad_group_size_m
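
A hedged end-to-end sketch of the autograd wrapper above. It assumes a CUDA GPU with Triton available; the shapes are toy values chosen so that `M_total` is a multiple of `group_size_m` and every 128-row block is routed to a single expert, as the kernels require:

```python
import torch

from axolotl.kernels.moe import ContiguousGroupedGEMM  # re-exported in this PR's __init__

group_size_m, num_experts, K, N = 128, 4, 64, 64
M_total = group_size_m * num_experts  # one contiguous 128-row block per expert

inputs = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
expert_weights = torch.randn(
    num_experts, N, K, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
# every row inside block i is routed to expert i
expert_indices = torch.arange(
    num_experts, device="cuda", dtype=torch.int32
).repeat_interleave(group_size_m)

out = ContiguousGroupedGEMM.apply(inputs, expert_weights, expert_indices, group_size_m)
out.float().sum().backward()
print(out.shape, inputs.grad.shape, expert_weights.grad.shape)
```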

View File

@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Vendored forward Triton contiguous grouped GEMM kernels."""
import torch
import triton
import triton.language as tl
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * super_group_m
group_size_m = min(num_pid_m - first_pid_m, super_group_m)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_SMS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
SUPER_GROUP_M: tl.constexpr = 32,
):
"""
Contiguous Grouped GEMM kernel forward (persistent variant).
"""
c_type = c_ptr.dtype.element_ty
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = SUPER_GROUP_M * num_pid_n
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
tile_m_idx, tile_n_idx = _compute_pid(
tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
m_start = tile_m_idx * BLOCK_SIZE_M
n_start = tile_n_idx * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = (
b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
)
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
accumulator += tl.dot(a, b.T)
tile_id_c += NUM_SMS
tile_m_idx, tile_n_idx = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_c = mask_m[:, None] & mask_n[None, :]
c = accumulator.to(tl.float32)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
tl.store(c_ptrs, c.to(c_type), mask=mask_c)
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""
Contiguous Grouped GEMM kernel forward for aligned inputs.
"""
pid = tl.program_id(0)
c_type = c_ptr.dtype.element_ty
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
tile_m = pid // num_n_tiles
tile_n = pid % num_n_tiles
m_start = tile_m * BLOCK_SIZE_M
n_start = tile_n * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
offs_k = tl.arange(0, BLOCK_SIZE_K) + k
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
acc += tl.dot(a, b.T)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask_c = mask_m[:, None] & mask_n[None, :]
tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
def cg_grouped_gemm_forward(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = (NUM_SMS, 1, 1)
_kernel_cg_persistent_forward[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
NUM_SMS=NUM_SMS,
)
return output
def cg_grouped_gemm_forward_dynamic(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE with autotuned launch."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
_kernel_cg_forward_aligned[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return output
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function for contiguous grouped GEMM forward pass only."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output): # pragma: no cover - not implemented
raise NotImplementedError("Backward pass not implemented")
def cg_grouped_gemm(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Convenience wrapper for the forward-only autograd function."""
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
return ContiguousGroupedGEMM.apply(
inputs, expert_weights, expert_indices, group_size_m
)

View File

@@ -0,0 +1,31 @@
"""Reference implementation for contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
def pytorch_reference(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = 128,
) -> torch.Tensor:
"""Simple PyTorch implementation for verification."""
M_total, K = inputs.shape
num_experts, N, _ = expert_weights.shape
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
for i in range(0, M_total, group_size_m):
end_idx = min(i + group_size_m, M_total)
expert_idx = expert_indices[i].item()
expert_weight = expert_weights[expert_idx]
output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)
return output
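
Assuming a CUDA GPU with Triton, a quick sanity check of the Triton forward against the per-block PyTorch loop above (tolerances are loose because inputs and outputs are bfloat16):

```python
import torch

from axolotl.kernels.moe import cg_grouped_gemm_forward

torch.manual_seed(0)
group_size_m, num_experts, K, N = 128, 4, 64, 64
M_total = group_size_m * num_experts
x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(num_experts, N, K, device="cuda", dtype=torch.bfloat16)
idx = torch.arange(num_experts, device="cuda", dtype=torch.int32).repeat_interleave(group_size_m)

out = cg_grouped_gemm_forward(x, w, idx, group_size_m)
# same computation as pytorch_reference: one matmul per contiguous 128-row block
ref = torch.cat(
    [x[i : i + group_size_m] @ w[idx[i]].T for i in range(0, M_total, group_size_m)]
)
print(torch.allclose(out.float(), ref.float(), atol=0.5, rtol=1e-2))
```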

View File

@@ -0,0 +1,209 @@
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver
class CudaUtils:
"""Helper utilities for CUDA specific Triton features."""
@staticmethod
def is_cuda() -> bool:
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels."""
class KernelParamWrapper:
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
if "nv_tma_desc_type" not in dir(tl):
raise RuntimeError(
"TMA grid constant descriptors not supported in your Triton version"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(
self, name: str
) -> "TmaDescriptorHelper.KernelParamWrapper":
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
HOPPER_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
STANDARD_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
def early_config_prune(configs, args, **kwargs):
"""Filter out configurations that would exceed shared memory capacity."""
k = kwargs.get("K", 0)
valid_configs = [
config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
]
if not valid_configs and configs:
return [
min(
configs,
key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
)
]
return valid_configs
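
The pruning hook above only keeps configs whose `BLOCK_SIZE_K` fits the problem's K dimension, falling back to the smallest-K config otherwise. A tiny illustration, assuming the module is importable from the path used by the sibling kernels in this PR:

```python
from axolotl.kernels.moe.tt_cg_gemm.tma_cuda_autotune import (
    STANDARD_CONFIGS,
    early_config_prune,
)

# With K=48, only the BLOCK_SIZE_K=32 configs survive.
kept = early_config_prune(STANDARD_CONFIGS, args=None, K=48)
print(sorted({c.kwargs["BLOCK_SIZE_K"] for c in kept}))  # [32]
```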

View File

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M
__all__ = [
"grouped_gemm_forward",
"ALIGN_SIZE_M",
]

View File

@@ -0,0 +1,761 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - flat index forward kernel is derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import logging
from typing import Tuple
import torch
import triton
import triton.language as tl
from .tma_autotuning import (
_NV_CONFIGS,
CudaUtils,
early_config_prune,
)
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
_allocator_registered = False
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
return torch.empty(size, device="cuda", dtype=torch.int8)
def _ensure_triton_allocator() -> None:
global _allocator_registered
if not _allocator_registered:
triton.set_allocator(_torch_allocator)
_allocator_registered = True
# ============== Start Triton Kernels ===============
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
a_ptr,
b_ptr,
c_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
USE_EPILOGUE_SUBTILING: tl.constexpr,
# tiles
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
"""Flat index style forward kernel for Hopper using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = c_ptr.dtype.element_ty
n_size = N // G
a_desc = tl.make_tensor_descriptor(
a_ptr,
shape=[M_TOTAL, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl.make_tensor_descriptor(
b_ptr,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
M_end = tl.full([], 0, dtype=tl.int32)
processed_tiles = 0
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
group_num_tiles = num_m_tiles * num_n_tiles
while (
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
):
group_index = tbidx - processed_tiles
tile_m_index = group_index % num_m_tiles
tile_n_index = group_index // num_m_tiles
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
global_n_offset = (g * n_size + n_offset).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
k_remaining = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
a = a_desc.load([m_offset, k_offset])
a_mask = row_mask[:, None] & k_mask[None, :]
a = tl.where(a_mask, a, tl.zeros_like(a))
b = b_desc.load([global_n_offset, k_offset])
b_mask = col_mask[:, None] & k_mask[None, :]
b = tl.where(b_mask, b, tl.zeros_like(b))
accumulator += tl.dot(a, b.T)
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
col_store_mask = local_col_offsets < n_size
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
if USE_EPILOGUE_SUBTILING:
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
tl.store(
ptr0,
acc0.to(c_dtype),
mask=row_store_mask[:, None] & col_mask0[None, :],
)
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
tl.store(
ptr1,
acc1.to(c_dtype),
mask=row_store_mask[:, None] & col_mask1[None, :],
)
else:
ptr = (
c_ptr
+ global_row[:, None] * n_size
+ local_col_offsets[None, :]
)
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS
processed_tiles += group_num_tiles
"""
Backward pass for grouped GEMM with Triton, where grouping is M*G
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
"""
# ---- dx flat linear indexed ----
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
grad_output_ptr,
w_ptr,
grad_input_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
# tiles
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
"""Compute grad_input = grad_output @ w using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = grad_input_ptr.dtype.element_ty
grad_output_desc = tl.make_tensor_descriptor(
grad_output_ptr,
shape=[M_TOTAL, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
w_desc = tl.make_tensor_descriptor(
w_ptr,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
M_end = tl.full([], 0, dtype=tl.int32)
processed_tiles = 0
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
group_num_tiles = num_m_tiles * num_k_tiles
while (
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
):
group_index = tbidx - processed_tiles
tile_m_index = group_index % num_m_tiles
tile_k_index = group_index // num_m_tiles
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
k_offset = tile_k_index * BLOCK_SIZE_K
k_remaining_total = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
for n_offset in range(0, N, BLOCK_SIZE_N):
n_remaining = N - n_offset
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
grad_y = grad_output_desc.load([m_offset, n_offset])
grad_y_mask = row_mask[:, None] & n_mask[None, :]
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
w_tile = w_desc.load([n_offset, k_offset])
w_mask = n_mask[:, None] & k_mask[None, :]
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
accumulator += tl.dot(grad_y, w_tile)
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
col_store_mask = col_offsets < K
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS
processed_tiles += group_num_tiles
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
x_ptr,
grad_output_ptr,
grad_weight_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
# tiles
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
) -> None:
"""Compute grad_weight = grad_output.T @ x using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = grad_weight_ptr.dtype.element_ty
x_desc = tl.make_tensor_descriptor(
x_ptr,
shape=[M_TOTAL, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
grad_output_desc = tl.make_tensor_descriptor(
grad_output_ptr,
shape=[M_TOTAL, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
total_tiles = num_n_tiles * num_k_tiles
for tile_idx in range(tbidx, total_tiles, NUM_SMS):
tile_n_idx = tile_idx % num_n_tiles
tile_k_idx = tile_idx // num_n_tiles
n_offset = tile_n_idx * BLOCK_SIZE_N
n_remaining = N - n_offset
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
k_offset = tile_k_idx * BLOCK_SIZE_K
k_remaining = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
M_end = tl.full([], 0, dtype=tl.int32)
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
rows_remaining = m_size - m_offset_local
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
m_offset = (M_start + m_offset_local).to(tl.int32)
x_block = x_desc.load([m_offset, k_offset])
x_mask = row_mask[:, None] & k_mask[None, :]
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
grad_block = grad_output_desc.load([m_offset, n_offset])
grad_mask = row_mask[:, None] & n_mask[None, :]
grad_block = tl.where(
grad_mask, grad_block, tl.zeros_like(grad_block)
)
contribution = tl.dot(
grad_block.to(tl.float32).T,
x_block.to(tl.float32),
)
accumulator += contribution
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
row_store_mask = row_offsets < N
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
col_store_mask = col_offsets < K
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
# ======== End Triton kernels ========
# ======== Triton wrapper functions ========
# ----- main forward pass wrapper -----
def grouped_gemm_forward(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
tma_size: int = 128,
using_fp8: bool = False,
) -> torch.Tensor:
"""Grouped GEMM forward using Hopper TMA kernels."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
if using_fp8:
raise NotImplementedError(
"FP8 path not implemented with the new Triton API yet"
)
G = m_sizes.shape[0]
assert x.is_contiguous()
assert w.is_contiguous()
assert m_sizes.is_contiguous()
M_total, K = x.shape
N = w.shape[0]
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
if M_total == 0:
return y
NUM_SMS = CudaUtils.get_num_sms()
USE_EPILOGUE_SUBTILING = False
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_forward_hopper[grid](
x,
w,
y,
m_sizes,
M_total,
G,
M_BUCKET,
N,
K,
NUM_SMS,
USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
)
return y
# ======== Improved Backward =============
def grouped_gemm_backward(
grad_output: torch.Tensor,
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
use_tma: bool = True,
tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Unified backward pass for grouped GeMM with M*G grouping.
Uses optimized TMA-based implementations for both dx and dw when available.
Args:
grad_output: Gradient of output, shape [M_total, N]
x: Input tensor from forward pass, shape [M_total, K]
w: Weight tensor from forward pass, shape [N, K]
m_sizes: Group sizes tensor, shape [G]
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
Returns:
Tuple of gradients with respect to x and w: (grad_x, grad_w)
"""
logging.info("Starting unified grouped_gemm_backward")
# do this once, seems expensive
NUM_SMS = CudaUtils.get_num_sms()
# Basic validation
M_total, K_x = x.shape
M_grad, N = grad_output.shape
N_w, K_w = w.shape
# Check dimensions
if K_x != K_w:
raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
if M_total != M_grad:
raise ValueError(
f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
)
# Check total M matches sum of group sizes
sum_m_sizes = m_sizes.sum().item()
if M_total != sum_m_sizes:
raise ValueError(
f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
)
# Make sure inputs are contiguous
grad_output = grad_output.contiguous()
x = x.contiguous()
w = w.contiguous()
m_sizes = m_sizes.contiguous()
# Check TMA support
if use_tma and not CudaUtils.verify_tma():
logging.info("TMA requested but not supported on this device")
use_tma = False
# Compute grad_x using flat linear implementation
try:
logging.info("Computing grad_x with flat linear kernel")
# Use TMA-optimized implementation
grad_x = grouped_gemm_dx_tma(
grad_output=grad_output,
w=w,
m_sizes=m_sizes,
num_sms=NUM_SMS,
tma_size=tma_size,
)
except Exception as e:
logging.error(f"Error in grad_x computation: {e}")
raise
# Compute grad_w using flat linear style implementation
try:
logging.info("Computing grad_w with flat linear kernel")
grad_w = grouped_gemm_dw_tma(
x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
)
except Exception as e:
logging.error(f"Error in grad_w computation: {e}")
raise
return grad_x, grad_w
# ----- dx backward pass wrapper -----
def grouped_gemm_dx_tma(
grad_output: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
num_sms: int = 132,
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_x using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Optimized dx computation requires TMA support")
grad_output = grad_output.contiguous()
w = w.contiguous()
m_sizes = m_sizes.contiguous()
M_total, N = grad_output.shape
N_w, K = w.shape
if N != N_w:
raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})")
if m_sizes.sum().item() != M_total:
raise ValueError("Sum of m_sizes must equal the number of rows in grad_output")
grad_x = torch.empty(
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
)
NUM_SMS = num_sms
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_dx_tma[grid](
grad_output,
w,
grad_x,
m_sizes,
M_total,
m_sizes.shape[0],
M_BUCKET,
N,
K,
NUM_SMS,
)
return grad_x
# ======== dw wrapper function ==========
def grouped_gemm_dw_tma(
x: torch.Tensor,
grad_output: torch.Tensor,
m_sizes: torch.Tensor,
num_sms: int = 132,
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_w using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
x = x.contiguous()
grad_output = grad_output.contiguous()
m_sizes = m_sizes.contiguous()
M_total, K = x.shape
M_grad, N = grad_output.shape
if M_total != M_grad:
raise ValueError("x and grad_output must have matching batch dimension")
if m_sizes.sum().item() != M_total:
raise ValueError("Sum of m_sizes must equal the number of rows in the inputs")
grad_w = torch.zeros((N, K), device=x.device, dtype=x.dtype)
NUM_SMS = num_sms
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_dw_tma[grid](
x,
grad_output,
grad_w,
m_sizes,
M_total,
m_sizes.shape[0],
M_BUCKET,
N,
K,
NUM_SMS,
)
return grad_w
# ======== End Backwards Wrapper Functions =============
# ======== PyTorch wrapper functions ========
class GroupedGemmMg(torch.autograd.Function):
"""
Autograd function for GroupedGEMM with M*G grouping.
Supports both standard and FP8 quantized operations.
"""
@staticmethod
def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
"""
Forward pass of GroupedGEMM.
Args:
x: Input tensor, shape [M_total, K]
w: Weight tensor, shape [N, K]
m_sizes: Tensor of shape [G] containing the size of each group
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
using_fp8: Whether to use FP8 quantization
Returns:
Output tensor, shape [M_total, N]
"""
# Use regular forward without quantization
output = grouped_gemm_forward(
x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
)
# Save inputs and parameters for backward pass
ctx.save_for_backward(x, w, m_sizes)
ctx.use_tma = use_tma
ctx.tma_size = tma_size
return output
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass of M*G GroupedGEMM.
Args:
grad_output: Gradient of output, shape [M_total, N]
Returns:
Tuple of gradients:
- grad_x: Gradient with respect to x, shape [M_total, K]
- grad_w: Gradient with respect to w, shape [N, K]
- None: Gradient with respect to m_sizes (not differentiable)
- None: Gradient with respect to use_tma (not differentiable)
- None: Gradient with respect to tma_size (not differentiable)
- None: Gradient with respect to using_fp8 (not differentiable)
"""
# Retrieve saved tensors and parameters
x, w, m_sizes = ctx.saved_tensors
use_tma = ctx.use_tma
tma_size = ctx.tma_size
# Compute gradients using the unified implementation
grad_x, grad_w = grouped_gemm_backward(
grad_output=grad_output,
x=x,
w=w,
m_sizes=m_sizes,
use_tma=use_tma,
tma_size=tma_size,
)
# Return gradients for all inputs (None for non-differentiable parameters)
return grad_x, grad_w, None, None, None, None
def mg_grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
use_tma: bool = True,
tma_size: int = 128,
using_fp8: bool = False,
) -> torch.Tensor:
"""
Unified differentiable grouped GEMM operation for M*G grouped GEMM.
Supports both standard precision and FP8 quantized operations.
Args:
x: Input tensor, shape [M_total, K]
w: Weight tensor, shape [N, K]
m_sizes: Tensor of shape [G] containing the size of each group
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
using_fp8: Whether to use FP8 quantization
Returns:
Output tensor, shape [M_total, N]
"""
return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
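
A hedged usage sketch for the M*G grouped path. It assumes a Hopper-class GPU (the forward wrapper raises without TMA support) and that each group's row count is already padded to the 128-row alignment, e.g. by `generate_permute_indices` above:

```python
import torch

from axolotl.kernels.moe import mg_grouped_gemm  # forward-only alias exported in __init__

G, n_size, K = 4, 64, 128  # G expert groups, each producing n_size output features
m_sizes = torch.tensor([128, 256, 128, 512], dtype=torch.int32, device="cuda")
M_total = int(m_sizes.sum())  # rows are assumed pre-sorted (and padded) by group

x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
# weights are stacked as [N, K] with N = G * n_size; group g reads w[g*n_size:(g+1)*n_size]
w = torch.randn(G * n_size, K, device="cuda", dtype=torch.bfloat16)

y = mg_grouped_gemm(x, w, m_sizes)  # -> [M_total, n_size]
print(y.shape)
```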

View File

@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import os
import sys
from typing import Dict
import torch
import triton
from triton.runtime import driver # @manual
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# ===== Supporting utils, CUDA and TMA =====
class CudaUtils:
@staticmethod
def is_cuda() -> bool:
"""Check if Triton is running on CUDA backend."""
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
"""Check if TMA is supported on the current device."""
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
"""Get the number of streaming multiprocessors on the current device."""
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels.
Args:
tma_size: Size of the TMA descriptor in bytes
"""
class KernelParamWrapper:
"""Wrapper to implement the TmaDescKernelParam interface."""
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
"""Return the CPU pointer to the TMA descriptor."""
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
"""Initialize a TMA descriptor with the given name.
Call this method outside of the lambda function for grid size.
"""
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
"""Fill a 1D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
"""Fill a 2D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
"""Get the TMA descriptor kernel parameter for the given name."""
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128
_NV_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
)
for block_size_m in [
ALIGN_SIZE_M,
]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
device = torch.cuda.current_device()
# Check for all possible pointer parameter names
if "grad_input_ptr" in named_args:
ptr_name = "grad_input_ptr"
elif "c_ptr" in named_args:
ptr_name = "c_ptr"
elif "grad_weight_ptr" in named_args:
ptr_name = "grad_weight_ptr"
else:
raise KeyError("No recognized pointer parameter found in kernel arguments")
if dtsize is None:
dtsize = named_args[ptr_name].element_size()
if dtype is None:
dtype = named_args[ptr_name].dtype
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
kw["BLOCK_SIZE_M"],
kw["BLOCK_SIZE_N"],
kw["BLOCK_SIZE_K"],
config.num_stages,
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
"max_shared_mem"
]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
M_PER_GROUP = M // G
MIN_M_TILES = 64
# 2. make sure we don't load M tiles that are too big
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
continue
# 3. make sure we don't load M tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = driver.active.utils.get_device_properties(device)[
"multiprocessor_count"
]
N_TILES = N // BLOCK_N
MIN_N_TILES = 64
# 4. make sure we don't load N tiles that are too big
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
pruned_configs.append(config)
return pruned_configs
# ======== End Autotuning utilities ========

View File

@@ -631,7 +631,7 @@ class ModelLoader:
if is_causal_conv1d_available():
raise ImportError(
"The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
"Please uninstall it by running: `uv pip uninstall -y causal-conv1d`"
"Please uninstall it by running: `pip uninstall -y causal-conv1d`"
)
def _configure_zero3_memory_efficient_loading(

View File

@@ -190,6 +190,15 @@ class PatchManager:
apply_mistral_tokenizer_image_patch()
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
LOG.info(
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
)
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:

View File

@@ -9,7 +9,7 @@ def check_mamba_ssm_installed():
mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
if mamba_ssm_spec is None:
raise ImportError(
"MambaLMHeadModel requires mamba_ssm. Please install it with `uv pip install -e .[mamba-ssm]`"
"MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
)

View File

@@ -4,7 +4,6 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio
import copy
import functools
import os
import sys
import torch
@@ -128,8 +127,7 @@ def get_state_dict(self, model, unwrap=True):
if model.zero_gather_16bit_weights_on_model_save():
if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
raise ImportError(
"Deepspeed TP requires deepspeed >= 0.16.4. Update DeepSpeed via "
"`uv pip install -U deepspeed`."
"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
)
state_dict = (
model._consolidated_16bit_state_dict()
@@ -279,11 +277,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
mesh = getattr(accelerator.state, "device_mesh", None)
# Disable memory pinning if requested
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false":
fsdp2_plugin.cpu_offload.pin_memory = False
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
@@ -348,6 +341,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
)
if fsdp2_plugin.cpu_ram_efficient_loading:
offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
fsdp2_load_full_state_dict(
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)

View File

@@ -0,0 +1,401 @@
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
from __future__ import annotations
from typing import Callable
import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM
from axolotl.kernels.moe.indices import generate_permute_indices
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
from axolotl.utils.logging import get_logger
_GROUP_SIZE_M = 128
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
LOG = get_logger(__name__)
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
return False
major, _ = torch.cuda.get_device_capability(hidden_states.device)
if major < 9:
LOG.debug(
"Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
major,
)
return False
return True
def _ensure_combined_expert_weights(
module, dtype: torch.dtype, device: torch.device, backend: str
) -> None:
if not hasattr(module, "_axolotl_original_specs"):
module._axolotl_original_specs = {}
if not hasattr(module, "_axolotl_mg_shapes"):
module._axolotl_mg_shapes = {}
prev_backend = getattr(module, "_axolotl_combined_backend", None)
if getattr(module, "_axolotl_combined_weights", False):
if prev_backend != backend:
_restore_expert_weights(module)
else:
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
param = module.get_parameter(param_name)
if param.device != device or param.dtype != dtype:
module._parameters[param_name] = torch.nn.Parameter(
param.to(device=device, dtype=dtype).contiguous()
)
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
return
module._axolotl_mg_shapes = {}
for name in _COMBINED_SUBMODULES:
weights = []
orig_device = None
orig_dtype = None
orig_shape = None
for expert in module.experts:
lin = expert.get_submodule(name)
weight_param = lin._parameters.get("weight")
if weight_param is None:
raise RuntimeError("Expected expert linear layers to have weights")
if orig_device is None:
orig_device = weight_param.device
orig_dtype = weight_param.dtype
orig_shape = tuple(weight_param.shape)
weights.append(weight_param.detach().to(device=device, dtype=dtype))
if "weight" in lin._parameters:
del lin._parameters["weight"]
if "bias" in lin._parameters:
del lin._parameters["bias"]
if backend == "cg":
combined_weight = torch.stack(weights, dim=0).contiguous()
else:
combined_weight = torch.cat(weights, dim=0).contiguous()
module._axolotl_mg_shapes[name] = orig_shape
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
module._axolotl_combined_weights = True
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
def _restore_expert_weights(module) -> None:
if not getattr(module, "_axolotl_combined_weights", False):
return
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
combined = module._parameters.pop(param_name)
orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
name, (combined.device, combined.dtype, None)
)
rows_per = orig_shape[0] if orig_shape else None
for idx, expert in enumerate(module.experts):
lin = expert.get_submodule(name)
if combined.dim() == 3:
slice_tensor = combined[idx]
elif rows_per is not None:
start = idx * rows_per
end = start + rows_per
slice_tensor = combined[start:end]
else:
raise RuntimeError(
"Unable to recover expert weight shape during restore"
)
lin._parameters["weight"] = torch.nn.Parameter(
slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
)
module._axolotl_combined_weights = False
module._axolotl_combined_dtype = None
module._axolotl_combined_device = None
module._axolotl_combined_backend = None
module._axolotl_original_specs = {}
module._axolotl_mg_shapes = {}
def _run_cg_grouped_gemm(
module,
grouped_hidden: torch.Tensor,
m_sizes: torch.Tensor,
num_experts: int,
group_size_m: int,
hidden_dtype: torch.dtype,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
expert_index_tensor = torch.repeat_interleave(
torch.arange(num_experts, device=device, dtype=torch.int32),
m_sizes.to(torch.int64),
)
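# e.g. m_sizes = [128, 256, 128] -> expert ids [0]*128 + [1]*256 + [2]*128,
# one expert id per row of the padded, expert-grouped activation buffer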
gate_weights = module.get_parameter("gate_proj_weight")
if gate_weights.dim() == 2:
out_dim = gate_weights.shape[0] // num_experts
gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])
up_weights = module.get_parameter("up_proj_weight")
if up_weights.dim() == 2:
out_dim = up_weights.shape[0] // num_experts
up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])
down_weights = module.get_parameter("down_proj_weight")
if down_weights.dim() == 2:
out_dim = down_weights.shape[0] // num_experts
down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])
gate_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
gate_weights,
expert_index_tensor,
group_size_m,
)
up_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
up_weights,
expert_index_tensor,
group_size_m,
)
return (
gate_out.to(hidden_dtype),
up_out.to(hidden_dtype),
down_weights,
expert_index_tensor,
)
def _moe_triton_forward(
module,
hidden_states: torch.Tensor,
topk_indices: torch.Tensor,
topk_weights: torch.Tensor,
group_size_m: int,
backend: str,
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
if not _is_triton_eligible(hidden_states):
return fallback(hidden_states, topk_indices, topk_weights)
device = hidden_states.device
hidden_dtype = hidden_states.dtype
num_tokens, hidden_dim = hidden_states.shape
top_k = topk_indices.size(-1)
expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
expert_assignments = topk_indices.reshape(-1)
if expanded_hidden.numel() == 0:
return torch.zeros_like(hidden_states)
sort_perm = torch.argsort(expert_assignments)
sorted_hidden = expanded_hidden.index_select(0, sort_perm)
sorted_assignments = expert_assignments.index_select(0, sort_perm)
num_experts = len(module.experts)
counts = torch.bincount(sorted_assignments, minlength=num_experts)
total_actual = int(counts.sum().item())
if total_actual == 0:
return torch.zeros_like(hidden_states)
if not getattr(module, "_axolotl_triton_logged", False):
min_tokens = int(counts.min().item())
max_tokens = int(counts.max().item())
LOG.info(
"DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
min_tokens,
max_tokens,
total_actual / max(1, num_experts),
group_size_m,
)
module._axolotl_triton_logged = True
counts_int = counts.to(torch.int32)
aligned_counts = (
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
) * group_size_m
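# e.g. with group_size_m=128: counts [5, 300, 0] -> aligned [128, 384, 128]; every expert
# gets at least one full group and each group length is a multiple of group_size_m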
max_len = int(aligned_counts.sum().item())
permuted_indices, m_sizes, _ = generate_permute_indices(
counts_int.to(device),
experts_per_rank=num_experts,
num_ranks=1,
max_len=max_len,
alignment=group_size_m,
use_cpu=not hidden_states.is_cuda,
)
permuted_indices = permuted_indices.to(device)
m_sizes = m_sizes.to(device)
permuted_indices_long = permuted_indices.to(torch.int64)
valid_mask = permuted_indices_long >= 0
valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
source_indices = permuted_indices_long[valid_mask]
padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)
grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
if valid_positions.numel() > 0:
grouped_hidden.index_copy_(
0,
valid_positions,
sorted_hidden.index_select(0, source_indices),
)
if valid_positions.numel() < max_len:
grouped_hidden.index_fill_(0, padded_positions, 0)
m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
if backend == "mg":
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
gate_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("gate_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
up_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("up_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
else:
gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
module,
grouped_hidden,
m_sizes,
num_experts,
group_size_m,
hidden_dtype,
device,
)
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
if valid_positions.numel() > 0:
gate_valid = gate_out.index_select(0, valid_positions)
up_valid = up_out.index_select(0, valid_positions)
hidden_concat = act_fn(gate_valid) * up_valid
else:
hidden_concat = torch.empty(
(0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
)
intermediate_dim = hidden_concat.shape[-1]
hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
if valid_positions.numel() > 0:
hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
if valid_positions.numel() < max_len:
hidden_grouped.index_fill_(0, padded_positions, 0)
if backend == "mg":
down_out = mg_grouped_gemm(
hidden_grouped,
module.get_parameter("down_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
else:
down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
down_weights,
expert_index_tensor,
group_size_m,
).to(hidden_dtype)
if valid_positions.numel() > 0:
down_valid = down_out.index_select(0, valid_positions)
else:
down_valid = torch.empty(
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
)
sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
if down_valid.numel() > 0:
sorted_outputs.index_copy_(0, source_indices, down_valid)
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
expanded_output.index_copy_(0, sort_perm, sorted_outputs)
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
return weighted.sum(dim=1)
def patch_deepseek_v3_moe(
group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
) -> None:
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
if backend not in {"cg", "mg"}:
raise ValueError(f"Unsupported MoE kernel backend: {backend}")
# Record the unpatched implementation so callers can access a true baseline even
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
return
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
DeepseekV3MoE._axolotl_triton_backend = backend
DeepseekV3MoE._axolotl_group_size_m = group_size_m
def patched_moe(self, hidden_states, topk_indices, topk_weights):
backend_sel = getattr(self, "_axolotl_triton_backend", backend)
group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
if backend_sel == "cg" and group_size_sel != _GROUP_SIZE_M:
LOG.debug(
"Adjusting group_size_m=%s to %s for CG backend",
group_size_sel,
_GROUP_SIZE_M,
)
group_size_sel = _GROUP_SIZE_M
try:
return _moe_triton_forward(
self,
hidden_states,
topk_indices,
topk_weights,
group_size_sel,
backend_sel,
original_moe,
)
except Exception as err: # surface Triton failures explicitly
_restore_expert_weights(self)
LOG.error("DeepseekV3MoE Triton path failed: %s", err)
raise
DeepseekV3MoE.moe = patched_moe
DeepseekV3MoE._axolotl_triton_patch = True
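A usage sketch (assumes a Hopper-class GPU and a transformers build that ships DeepseekV3MoE); the patch is applied once, and the eager implementation remains reachable for baseline comparisons:

from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

patch_deepseek_v3_moe(group_size_m=128, backend="cg")  # backend="mg" is the default
assert getattr(DeepseekV3MoE, "_axolotl_triton_patch", False)
baseline_moe = DeepseekV3MoE._axolotl_triton_original_moe  # unpatched forward for comparisons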

View File

@@ -107,7 +107,7 @@ def patch_llama_rms_norm():
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `uv pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)

View File

@@ -148,7 +148,7 @@ def load_sharded_model(
model = AutoModelForCausalLM.from_pretrained(
model_name,
use_cache=False,
dtype=torch.float32,
torch_dtype=torch.float32,
_attn_implementation=model_config._attn_implementation,
trust_remote_code=cfg.trust_remote_code,
)
@@ -158,7 +158,7 @@ def load_sharded_model(
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
dtype=torch_dtype,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code,
)
return model

View File

@@ -113,6 +113,19 @@ class AxolotlInputConfig(
},
)
moe_kernels: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
},
)
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
default="mg",
json_schema_extra={
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
},
)
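# Example run-config snippet (illustrative):
#   moe_kernels: true
#   moe_kernel_backend: mg   # or "cg" for the contiguous grouped GEMM kernel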
trainer_cls: str | None = Field(
default=None,
json_schema_extra={

View File

@@ -497,9 +497,7 @@ class TrainingValidationMixin:
if importlib.util.find_spec("mistral_common") is None:
raise ImportError(
"mistral-common is required for mistral models. "
"Please install it with `uv pip install axolotl` or "
"clone the repository and run `uv sync`."
"mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`."
)
return tokenizer_use_mistral_common
@@ -818,22 +816,21 @@ class OptimizationValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_cpu_offload_pin_memory(cls, data):
if not (fsdp_config := data.get("fsdp_config")):
return data
if fsdp_config.get("cpu_offload_pin_memory") is False:
if str(data.get("fsdp_version")) != "2":
@model_validator(mode="after")
def check_fsdp2_base_model_quant_ram_efficient_loading(self):
fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
if fsdp_config and fsdp_version == 2:
if fsdp_config.get("cpu_ram_efficient_loading") and (
load_in_8bit or load_in_4bit
):
raise ValueError(
"FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2"
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
)
if not fsdp_config.get("offload_params"):
raise ValueError(
"disabling cpu_offload_pin_memory requires enabling offload_params"
)
return data
return self
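# e.g. (illustrative) a config rejected by this validator:
#   fsdp_version: 2
#   fsdp_config: {cpu_ram_efficient_loading: true}
#   load_in_4bit: true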
@model_validator(mode="before")
@classmethod
@@ -1348,10 +1345,8 @@ class ComplexValidationMixin:
except ImportError as exception:
raise ImportError(
"context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `uv sync --extra ring-flash-attn` (and "
"then `uv pip install flash-attn --no-build-isolation`) or run "
"`uv pip install ring-flash-attn>=0.1.4` followed by "
"`uv pip install flash-attn --no-build-isolation`."
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
LOG.warning(

View File

@@ -109,8 +109,8 @@ def prepare_debug_log(cfg, filename: str = "debug.log") -> str:
cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints")
)
if not append:
log_path.unlink(missing_ok=True)
if not append and log_path.exists():
log_path.unlink()
fh = open(log_path, "a", encoding="utf-8")
fh.flush()

View File

@@ -595,10 +595,6 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
if cfg.fsdp_config.state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type
if cfg.fsdp_config.cpu_offload_pin_memory is not None:
os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(
cfg.fsdp_config.cpu_offload_pin_memory
).lower()
if cfg.fsdp_config.auto_wrap_policy:
os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy
if cfg.fsdp_config.transformer_layer_cls_to_wrap:

View File

@@ -0,0 +1,104 @@
"""
dynamic requirements for axolotl
"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools.command.build_py import build_py as _build_py
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
class BuildPyCommand(_build_py):
"""
custom build_py command to parse dynamic requirements
"""
def finalize_options(self):
super().finalize_options()
install_requires, _ = parse_requirements()
self.distribution.install_requires = install_requires
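Presumably this is consumed by setup() along these lines (a sketch; the actual setup() call is not shown in this diff):

from setuptools import setup

install_requires, dependency_links = parse_requirements()
setup(
    install_requires=install_requires,
    dependency_links=dependency_links,
    cmdclass={"build_py": BuildPyCommand},
)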

View File

@@ -14,7 +14,7 @@ def cleanup_last_run_prepared():
yield
if Path("last_run_prepared").exists():
shutil.rmtree("last_run_prepared", ignore_errors=True)
shutil.rmtree("last_run_prepared")
def test_preprocess_config_not_found(cli_runner):

View File

@@ -160,7 +160,7 @@ def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Gemma2ForCausalLM",
dtype=torch.float16,
torch_dtype=torch.float16,
device_map="cuda:0",
)
peft_config = get_peft_config(

View File

@@ -5,7 +5,7 @@ E2E tests for lora llama
import unittest
import pytest
from transformers.utils import is_gptqmodel_available, is_torch_bf16_gpu_available
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -69,7 +69,7 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_gptqmodel_available(), reason="gptqmodel not installed")
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
def test_lora_gptq_packed(self, temp_dir):
cfg = DictDefault(

View File

@@ -39,7 +39,7 @@ def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2-0.5B",
device_map="auto",
dtype=torch.bfloat16,
torch_dtype=torch.bfloat16,
)
with torch.device(dummy_model.device):
dummy_model.model.embed_tokens = torch.nn.Embedding(

View File

@@ -61,50 +61,12 @@ class TestFSDPValidation:
},
fsdp_version=2,
)
validated_cfg = validate_config(cfg)
assert validated_cfg.fsdp_version == 2
assert validated_cfg.fsdp_config.cpu_ram_efficient_loading is True
def test_fsdp2_cpu_offload_pin_memory_requires_offload_params(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"cpu_offload_pin_memory": False,
"offload_params": False,
},
fsdp_version=2,
)
with pytest.raises(
ValueError,
match="disabling cpu_offload_pin_memory requires enabling offload_params",
match="FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading.",
):
validate_config(cfg)
def test_fsdp1_cpu_offload_pin_memory_not_supported(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"cpu_offload_pin_memory": False,
"offload_params": True,
},
fsdp_version=1,
)
with pytest.raises(
ValueError,
match="FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2",
):
validate_config(cfg)
def test_fsdp2_cpu_offload_pin_memory_w_offload_params(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"cpu_offload_pin_memory": False,
"offload_params": True,
},
fsdp_version=2,
)
validated_cfg = validate_config(cfg)
assert validated_cfg.fsdp_config.cpu_offload_pin_memory is False
assert validated_cfg.fsdp_config.offload_params is True
def test_fsdp_prefixes_removed(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={

uv.lock (generated, 8302 changes)
File diff suppressed because it is too large