Compare commits

..

37 Commits

Author SHA1 Message Date
Dan Saunders
dd85358543 default mg 2025-09-25 16:30:23 -04:00
Dan Saunders
55d98db0d0 fix 2025-09-25 16:08:35 -04:00
Dan Saunders
d90ade3b1b fix 2025-09-25 15:55:08 -04:00
Dan Saunders
824a641cee uniform routing default 2025-09-25 15:47:23 -04:00
Dan Saunders
e003a05177 narrow sweep; compare both backends 2025-09-25 14:54:03 -04:00
Dan Saunders
91393c4dc8 allocator 2025-09-25 14:27:34 -04:00
Dan Saunders
d578c53603 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
4db7a21ff7 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
3b2e05c563 update to new api 2025-09-25 14:27:34 -04:00
Dan Saunders
1037ca3a97 update to new api 2025-09-25 14:27:34 -04:00
Dan Saunders
6369dcd7b8 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
a81612305c fix? 2025-09-25 14:27:34 -04:00
Dan Saunders
d0da67eb17 add mg kernel backend 2025-09-25 14:27:34 -04:00
Dan Saunders
8a1f5ae940 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
146ca48cba vram 2025-09-25 14:27:34 -04:00
Dan Saunders
fd312f6058 dtype 2025-09-25 14:27:34 -04:00
Dan Saunders
ab8fa56b16 dtype 2025-09-25 14:27:34 -04:00
Dan Saunders
1640cd4006 delete config 2025-09-25 14:27:34 -04:00
Dan Saunders
3277d44d71 cfg value 2025-09-25 14:27:34 -04:00
Dan Saunders
d3e1b0ef1a small deepseek script 2025-09-25 14:27:34 -04:00
Dan Saunders
5b97633faa Fix 2025-09-25 14:27:34 -04:00
Dan Saunders
94cbc6d42d log device, dtype 2025-09-25 14:27:34 -04:00
Dan Saunders
493616fc3d reprod tt table 2025-09-25 14:27:34 -04:00
Dan Saunders
d2b25c7327 grid sweep 2025-09-25 14:27:34 -04:00
Dan Saunders
b670c45276 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
61faf4cbe4 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
8d8fa834a2 sweep 2025-09-25 14:27:34 -04:00
Dan Saunders
9d69c6fb3e Fix 2025-09-25 14:27:34 -04:00
Dan Saunders
92f2f6e73c dtype fix 2025-09-25 14:27:34 -04:00
Dan Saunders
e5d2aebe16 uniform routing: 2025-09-25 14:27:34 -04:00
Dan Saunders
4ab9e3f58b add logs 2025-09-25 14:27:34 -04:00
Dan Saunders
5788832812 simplify 2025-09-25 14:27:34 -04:00
Dan Saunders
db782430f8 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
5c74edeefe token shuffle kernel 2025-09-25 14:27:34 -04:00
Dan Saunders
18269ee6a9 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
6a45d804f9 glue 2025-09-25 14:27:34 -04:00
Dan Saunders
95e607574a vendor torchtitan moe kernels 2025-09-25 14:27:34 -04:00
172 changed files with 3794 additions and 5783 deletions

6
.github/FUNDING.yml vendored
View File

@@ -1,13 +1,13 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
ko_fi: axolotl_ai # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480&centerImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -25,6 +25,20 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
@@ -53,20 +67,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -90,6 +90,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-base
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
@@ -121,6 +122,13 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
@@ -142,20 +150,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -15,6 +15,11 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -25,6 +30,7 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -35,17 +41,6 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -55,6 +50,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=ref,event=branch
@@ -92,6 +88,11 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -108,6 +109,7 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -118,17 +120,6 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -138,6 +129,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
@@ -170,6 +162,11 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -197,6 +194,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud-term
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch

View File

@@ -26,6 +26,13 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -40,13 +47,6 @@ jobs:
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -15,12 +15,12 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -31,6 +31,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
@@ -67,12 +68,12 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -83,6 +84,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}

View File

@@ -2,7 +2,7 @@ name: Pre-commit auto-update
on:
schedule:
- cron: '0 0 1 * *' # Run monthly
- cron: '0 0 * * 0' # Run weekly
workflow_dispatch: # Manual kickoff
jobs:

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0"]
pytorch_version: ["2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -102,14 +102,14 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"

View File

@@ -55,14 +55,10 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
@@ -85,20 +81,16 @@ jobs:
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-cache-dir --no-build-isolation -U -e .
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -126,6 +118,10 @@ jobs:
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -134,14 +130,10 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
@@ -160,25 +152,21 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
pip3 show torch
python -m build --no-isolation --sdist
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
pip3 install --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -196,6 +184,10 @@ jobs:
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
@@ -239,10 +231,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -291,15 +289,15 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
# - cuda: 128
# cuda_version: 12.8.1
# python_version: "3.11"
# pytorch: 2.7.1
# num_gpus: 1
# axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -307,12 +305,6 @@ jobs:
num_gpus: 1
gpu_type: "B200"
axolotl_extras: fbgemm-gpu
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.7
rev: v0.12.12
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.0
rev: v1.17.1
hooks:
- id: mypy
additional_dependencies:
@@ -26,7 +26,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.9.2
rev: 1.8.6
hooks:
- id: bandit
args: [

View File

@@ -29,10 +29,6 @@
## 🎉 Latest Updates
- 2025/11: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
@@ -40,12 +36,12 @@
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
@@ -77,7 +73,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.7.1
- PyTorch ≥2.6.0
### Google Colab
@@ -158,13 +154,6 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## 📈 Telemetry
Axolotl has opt-out telemetry that helps us understand how the project is being used
and prioritize improvements. We collect basic system information, model types, and
error rates—never personal data or file paths. Telemetry is enabled by default. To
disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html).
## ❤️ Sponsors
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)

View File

@@ -241,7 +241,6 @@ website:
- docs/installation.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/telemetry.qmd
- docs/config-reference.qmd
- text: "API Reference"
href: docs/api

View File

@@ -32,7 +32,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -1,6 +1,6 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -65,13 +65,8 @@ def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess.
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code)

View File

@@ -13,7 +13,7 @@ datasets:
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model
dataset_prepared_path: temp_debug/axolotl_outputs/data
dataset_num_proc: 1
dataset_processes: 1
sequence_len: 4096
sample_packing: false

View File

@@ -35,24 +35,18 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
RUN if [ "$CUDA" != "130" ] ; then \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
python3 -m pip cache purge; \
fi
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
fi

View File

@@ -30,13 +30,7 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi

View File

@@ -218,13 +218,6 @@ If you have tool arguments with same name but different dtypes (like `"time": st
```
"arguments": "{\"...\": \"...\"}"
```
The same is applicable for tool parameters.
```
"parameters": "{\"...\": \"...\"}"
```
:::
Example config for Llama4:

View File

@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
```yaml
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_num_proc=1", // limits data preprocessing to one process
"--dataset_processes=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step
"--batch_size=1", // minimizes batch size
"--micro_batch_size=1", // minimizes batch size

View File

@@ -63,14 +63,6 @@ description: Frequently asked questions
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
**Q: Can we mix text and text+image datasets for VLM training?**
> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -1,5 +1,5 @@
---
title: "FSDP + QLoRA"
title: "FDSP + QLoRA"
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
format:
html:
@@ -23,12 +23,6 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Enabling Swap for FSDP2
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.

View File

@@ -5,11 +5,10 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order
to improve speed and reduce memory usage during the forward and backward passes of
these calculations.
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
@@ -132,5 +131,6 @@ computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

View File

@@ -27,9 +27,3 @@ learning_rate: 2e-5
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.
::: {.callout-note}
We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
:::

View File

@@ -4,7 +4,7 @@ format:
html:
toc: true
toc-depth: 3
# number-sections: true
number-sections: true
code-tools: true
execute:
enabled: false
@@ -14,18 +14,12 @@ This guide covers advanced training configurations for multi-GPU setups using Ax
## Overview {#sec-overview}
When training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy.
Axolotl supports several methods for multi-GPU training:
You generally cannot combine these strategies; they are mutually exclusive.
1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3.
2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended).
3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above are selected).
These features can often be combined with the strategies above:
* **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP).
* **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP).
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -71,18 +65,12 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers.
::: {.callout-note}
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
:::
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
@@ -100,7 +88,6 @@ fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
fsdp_activation_checkpointing | activation_checkpointing
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config:
@@ -157,6 +144,10 @@ single sequence causes OOM errors during model training.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
## Performance Optimization {#sec-performance}
### Liger Kernel Integration {#sec-liger}

View File

@@ -56,14 +56,10 @@ image_resize_algorithm: bilinear
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
::: {.callout-tip}
::: {.callout-warning}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
:::
::: {.callout-note}
As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this.
:::
### Mllama {#sec-mllama}
```yaml
@@ -124,8 +120,6 @@ Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-3 {#sec-gemma-3}
@@ -174,14 +168,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3-VL {#sec-qwen3-vl}
```yaml
base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}

View File

@@ -219,21 +219,6 @@ DPO supports the following types with the following dataset format:
}
```
#### chat_template.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### chat_template.default
```yaml
@@ -597,116 +582,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### OpenEnv Rollout Functions
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
For example, to implement a simple math-solving environment with step-by-step verification:
```python
# math_env.py
import re
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
"""
Custom rollout function that generates step-by-step math solutions.
Args:
model: The language model
processing_class: The tokenizer/processing_class
prompts: List of prompt dicts (with 'messages' key for chat format)
generation_config: Optional generation configuration
Returns:
List of completion strings
"""
completions = []
for prompt in prompts:
# Apply chat template to prompt
messages = prompt.get("messages", [])
formatted_prompt = processing_class.apply_chat_template(
messages, processing_class=False, add_generation_prompt=True
)
# Generate step-by-step solution
full_response = ""
for step in range(5): # Max 5 reasoning steps
current_input = formatted_prompt + full_response + "\nNext step:"
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=100,
generation_config=generation_config,
)
step_text = processing_class.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
# Check if solution is complete
if "FINAL ANSWER:" in step_text:
full_response += step_text
break
full_response += step_text + "\n"
completions.append(full_response)
return completions
def math_reward(prompts, completions, answers, **kwargs):
"""Reward function that checks mathematical correctness"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extract predicted answer
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
predicted = match.group(1).strip() if match else ""
# Compare with correct answer
reward = 1.0 if predicted == str(correct_answer) else 0.0
rewards.append(reward)
return rewards
def math_transform(cfg, *args, **kwargs):
"""Transform dataset to GRPO format with answer field"""
def transform_fn(example, processing_class=None):
return {
"prompt": [{"role": "user", "content": example["question"]}],
"answer": str(example["answer"]),
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 512
num_generations: 4
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
reward_funcs: ["math_env.math_reward"]
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
type: math_env.math_transform
```
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration
And should return a list of completion strings.
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.

View File

@@ -1,61 +0,0 @@
---
title: Telemetry
description: A description of the telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
Personally identifiable information (PII) is not collected.
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 10 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
## Opt-Out Mechanism
Telemetry is **enabled by default** on an opt-out basis. To disable it, set
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`.
A warning message will be logged on start to clearly inform users about telemetry.
We will remove this after some period.
To hide the warning message about telemetry that is displayed on train, etc. startup,
explicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`
(explicitly disable telemetry).
## Privacy
- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link different telemetry events in a single same training run
- Telemetry is only sent from the main process to avoid duplicate events

View File

@@ -6,8 +6,6 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
Thanks to the team at LiquidAI for giving us early access to prepare for these releases.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
@@ -33,14 +31,6 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
@@ -55,13 +45,14 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
## Optimization Guides
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,7 +1,6 @@
base_model: LiquidAI/LFM2-350M
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
chunked_cross_entropy: true
eot_tokens:
- "<|im_end|>"

View File

@@ -1,59 +0,0 @@
base_model: LiquidAI/LFM2-8B-A1B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
eot_tokens:
- "<|im_end|>"
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: true
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -3,9 +3,6 @@ trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
]
},
{

View File

@@ -1,7 +1,7 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -1,7 +1,7 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp

View File

@@ -2,8 +2,6 @@
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
In October 2025, OpenAI released safeguard models built upon GPT-OSS called [GPT-OSS-Safeguard](https://huggingface.co/collections/openai/gpt-oss-safeguard). They use the same architecture, so the same examples below can be re-used.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
@@ -66,16 +64,6 @@ axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offlo
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```
### How to set reasoning_effort in template?
The harmony template has a feature to set the `reasoning_effort` during prompt building. The default is `medium`. If you would like to adjust this, you can add the following to your config:
```yaml
chat_template_kwargs:
reasoning_effort: "high" # low | medium | high
```
Currently, this applies globally. There is no method to apply per sample yet. If you are interested in adding this, please feel free to create an Issue to discuss.
### Inferencing your fine-tuned model

View File

@@ -1,67 +0,0 @@
base_model: openai/gpt-oss-safeguard-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-safeguard-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

View File

@@ -1,65 +0,0 @@
# Finetune IBM's Granite 4.0 with Axolotl
[Granite 4.0](https://huggingface.co/collections/ibm-granite/granite-40-language-models) are a family of open source models trained by IBM Research.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Granite4 is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/granite4/granite-4.0-tiny-fft.yaml
```
This config uses about 40.8GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### Limitation
Adapter finetuning does not work at the moment. It would error with
```bash
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648)
```
In addition, if adapter training works, `lora_target_linear: true` will not work due to:
```bash
ValueError: Target module GraniteMoeHybridParallelExperts() is not supported.
```
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Granite Docs](https://www.ibm.com/granite/docs/models/granite)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,45 +0,0 @@
base_model: ibm-granite/granite-4.0-tiny-preview
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/model-out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -66,7 +66,6 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -29,7 +29,7 @@ flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true
wandb_project:

View File

@@ -1,50 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_4bit: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
output_dir: ./outputs/opentelemetry-example
adapter: qlora
sequence_len: 512
sample_packing: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# OpenTelemetry Configuration
use_otel_metrics: true
otel_metrics_host: "localhost"
otel_metrics_port: 8000
# Disable WandB
use_wandb: false
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
flash_attention: false
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
Run the thinking model fine-tuning:
```bash
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml
axolotl train magistral-small-think-qlora.yaml
```
This config uses about 19.1 GiB VRAM.

View File

@@ -21,7 +21,7 @@ Before starting, ensure you have:
3. Run the fine-tuning:
```bash
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml
axolotl train magistral-small-vision-24B-qlora.yml
```
This config uses about 17GiB VRAM.

View File

@@ -1,51 +0,0 @@
# Mistral Small 3.1/3.2 Fine-tuning
This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
## Getting Started
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```
3. Run the fine-tuning:
```bash
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
```
This config uses about 29.4 GiB VRAM.
## Dataset Format
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
Example:
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [
{ "type": "text", "text": "What's in this image?"},
{"type": "image", "path": "path/to/image.jpg" }
]},
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
],
}
```
## Limitations
- Sample Packing is not supported for multi-modality training currently.

View File

@@ -39,7 +39,7 @@ wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine

View File

@@ -1,46 +0,0 @@
# Finetune Allenai's Olmo 3 with Axolotl
[Olmo 3](https://huggingface.co/collections/allenai/olmo-3) are a family of 7B and 32B models open source models trained by The Allen Institute for Artificial Intelligence.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- The example config can be re-used for Olmo and Olmo 2.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Olmo 3 Blog](https://allenai.org/blog/olmo3)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: allenai/Olmo-3-7B-Instruct-SFT
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -6,17 +6,21 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
Here is an example of how to install from main for pip:
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
@@ -37,7 +41,9 @@ Let us know how it goes. Happy finetuning! 🚀
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources

View File

@@ -37,7 +37,9 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources

View File

@@ -1,5 +1,5 @@
base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -1,34 +1,35 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.48.2
bitsandbytes==0.47.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.6.3
autoawq==0.2.7.post3
liger-kernel==0.6.1
# END section
packaging==23.2
huggingface_hub>=0.36.0
peft>=0.18.0
tokenizers>=0.22.1
transformers==4.57.3
accelerate==1.11.0
datasets==4.4.1
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
tokenizers>=0.21.1
accelerate==1.10.1
datasets==4.0.0
deepspeed>=0.17.0
trl==0.25.0
hf_xet==1.2.0
kernels>=0.9.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.49.1
gradio==5.41.1
modal==1.0.2
pydantic>=2.10.6
pydantic==2.10.6
addict
fire
PyYAML>=6.0
@@ -36,12 +37,13 @@ requests
wandb
einops
colorama
numba>=0.61.2
numpy>=2.2.6
numba
numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy
scikit-learn==1.4.2
nvidia-ml-py==12.560.30
art
tensorboard
@@ -49,7 +51,7 @@ python-dotenv==1.0.1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2025.3.0
gcsfs>=2024.5.0
adlfs>=2024.5.0
ocifs==1.3.2
@@ -63,13 +65,9 @@ immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
openenv-core==0.1.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
# telemetry
posthog==6.7.11
mistral-common==1.8.5

1
scripts/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Utility scripts package."""

View File

@@ -0,0 +1,5 @@
"""Benchmark helpers."""
from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]

View File

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Instantiate a ~8.3B DeepSeek-V3 MoE model with random weights.
Run this on a GPU-equipped machine (e.g. 1× NVL H100) so the dense
initialization completes quickly:
python scripts/benchmarks/build_deepseek_v3_8b.py --output deepseek-v3-8b-moe
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM
DTYPE_MAP = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
def build_config() -> DeepseekV3Config:
"""Return a DeepSeek V3 configuration totaling ~8.3B parameters."""
return DeepseekV3Config(
vocab_size=32_000,
hidden_size=3_072,
intermediate_size=8_192,
moe_intermediate_size=2_560,
num_hidden_layers=20,
num_attention_heads=24,
num_key_value_heads=24,
n_routed_experts=18,
num_experts_per_tok=4,
n_group=6,
topk_group=4,
kv_lora_rank=192,
q_lora_rank=384,
max_position_embeddings=2_048,
rope_theta=10_000.0,
rope_interleave=True,
hidden_act="silu",
initializer_range=0.02,
attention_dropout=0.0,
attention_bias=False,
n_shared_experts=1,
routed_scaling_factor=2.5,
norm_topk_prob=True,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--output",
type=Path,
required=True,
help="Directory to save the generated model",
)
parser.add_argument(
"--dtype",
default="bfloat16",
choices=DTYPE_MAP.keys(),
help="Storage dtype for the checkpoint",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Torch RNG seed for reproducibility",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
torch.manual_seed(args.seed)
output_dir = args.output
output_dir.mkdir(parents=True, exist_ok=True)
config = build_config()
model = DeepseekV3ForCausalLM(config)
dtype = DTYPE_MAP[args.dtype]
model.to(dtype=dtype)
param_count = sum(p.numel() for p in model.parameters())
print(f"Initialized DeepSeek-V3 MoE with {param_count / 1e9:.3f}B parameters")
model.save_pretrained(output_dir, safe_serialization=True)
config.save_pretrained(output_dir)
print(f"Saved model and config to {output_dir.resolve()}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""
from __future__ import annotations
import argparse
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import torch
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else:
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.kernels.moe import (
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
@dataclass
class Scenario:
num_groups: int
m: int
n: int
k: int
SCENARIOS: tuple[Scenario, ...] = (
Scenario(num_groups=4, m=8192, n=4096, k=7168),
Scenario(num_groups=4, m=8192, n=7168, k=2048),
Scenario(num_groups=8, m=4096, n=4096, k=7168),
Scenario(num_groups=8, m=4096, n=7168, k=2048),
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--device", default="cuda", choices=["cuda"], help="Execution device"
)
parser.add_argument(
"--dtype",
default="bf16",
choices=["bf16", "fp16", "fp32"],
help="Computation dtype",
)
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--group-size",
type=int,
default=128,
help="GROUP_SIZE_M expected by the kernel",
)
return parser.parse_args()
def pick_dtype(name: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[name]
def make_indices(
num_groups: int, group_size: int, device: torch.device
) -> torch.Tensor:
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
return indices.repeat_interleave(group_size)
def timed_call(fn, *args, warmup: int, iters: int) -> float:
for _ in range(warmup):
fn(*args)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
fn(*args)
torch.cuda.synchronize()
return (time.perf_counter() - start) * 1000.0 / iters
def run_scenario(
scenario: Scenario,
*,
dtype: torch.dtype,
device: torch.device,
warmup: int,
iters: int,
group_size_m: int,
) -> dict:
if scenario.m % scenario.num_groups != 0:
raise ValueError(
f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
)
group_size = scenario.m // scenario.num_groups
if group_size % group_size_m != 0:
raise ValueError(
f"Group size {group_size} must be a multiple of GROUP_SIZE_M ({group_size_m}) for the Triton kernel"
)
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
weights = torch.randn(
scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
)
indices = make_indices(scenario.num_groups, group_size, device)
def persistent():
return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)
def baseline():
return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)
persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)
return {
"scenario": scenario,
"persistent_ms": persistent_ms,
"baseline_ms": baseline_ms,
"speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
}
def main() -> None: # pragma: no cover - utility script
args = parse_args()
torch.manual_seed(args.seed)
if args.device != "cuda" or not torch.cuda.is_available():
raise SystemExit("CUDA device required for this benchmark")
dtype = pick_dtype(args.dtype)
device = torch.device(args.device)
print(
f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters} group_size={args.group_size}"
)
print(
f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7} {'persistent':>12} {'baseline':>12} {'speedup':>8}"
)
for result in run_all(
SCENARIOS,
dtype=dtype,
device=device,
warmup=args.warmup,
iters=args.iters,
group_size_m=args.group_size,
):
scen = result["scenario"]
print(
f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
)
def run_all(
scenarios: Iterable[Scenario],
*,
dtype: torch.dtype,
device: torch.device,
warmup: int,
iters: int,
group_size_m: int,
) -> Iterable[dict]:
for scenario in scenarios:
yield run_scenario(
scenario,
dtype=dtype,
device=device,
warmup=warmup,
iters=iters,
group_size_m=group_size_m,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,301 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
from types import MethodType
import torch
try:
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError as exc: # pragma: no cover - utility script
raise SystemExit(
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
) from exc
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
ACCURACY_TOLERANCE = 5e-3
DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
parser.add_argument(
"--moe-intermediate-size",
type=int,
default=8192,
help="MoE intermediate projection size",
)
parser.add_argument(
"--n-experts",
type=int,
default=256,
help="Number of routed experts",
)
parser.add_argument(
"--top-k",
type=int,
default=8,
help="Number of experts per token",
)
parser.add_argument(
"--groups",
type=int,
default=8,
help="Router groups (must divide n-experts)",
)
parser.add_argument(
"--dtype",
choices=DTYPE_MAP.keys(),
default="bf16",
help="Computation dtype",
)
parser.add_argument(
"--device",
default="auto",
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--uniform-routing",
action="store_true",
help="Override router to distribute tokens evenly across experts",
)
parser.add_argument(
"--group-size",
type=int,
default=128,
help="GROUP_SIZE_M used by the Triton kernel",
)
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
return parser.parse_args()
def resolve_device(requested: str) -> torch.device:
if requested == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(requested)
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
config = DeepseekV3Config(
hidden_size=args.hidden_size,
intermediate_size=args.moe_intermediate_size,
moe_intermediate_size=args.moe_intermediate_size,
n_routed_experts=args.n_experts,
num_experts_per_tok=args.top_k,
n_group=args.groups,
topk_group=max(1, min(args.groups, args.top_k)),
n_shared_experts=1,
)
module = DeepseekV3MoE(config)
module.eval()
return module
@torch.no_grad()
def timed_forward(
module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
for _ in range(warmup):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
return (elapsed / iters) * 1000.0
def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
torch.manual_seed(args.seed)
device = resolve_device(args.device)
dtype = DTYPE_MAP[args.dtype]
if args.n_experts % args.groups != 0:
raise SystemExit("n-experts must be divisible by groups")
if args.top_k > args.n_experts:
raise SystemExit("top-k cannot exceed number of experts")
if device.type == "cuda" and not torch.cuda.is_available():
raise SystemExit("CUDA requested but not available")
baseline_module = build_module(args)
original_moe = getattr(
DeepseekV3MoE,
"_axolotl_triton_original_moe",
DeepseekV3MoE.moe,
)
baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict()
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
patched_module = build_module(args)
patched_module.load_state_dict(state_dict)
baseline_module.to(device=device, dtype=dtype)
patched_module.to(device=device, dtype=dtype)
tokens = args.batch * args.seq_len
routed_tokens = tokens * args.top_k
avg_tokens_per_expert = routed_tokens / args.n_experts
inputs = torch.randn(
args.batch,
args.seq_len,
args.hidden_size,
device=device,
dtype=dtype,
)
with torch.no_grad():
flat_inputs = inputs.view(-1, args.hidden_size)
if args.uniform_routing:
total_assignments = flat_inputs.size(0) * args.top_k
base = total_assignments // args.n_experts
remainder = total_assignments % args.n_experts
counts = torch.full(
(args.n_experts,),
base,
dtype=torch.int64,
device=device,
)
if remainder:
counts[:remainder] += 1
assignments = torch.repeat_interleave(
torch.arange(args.n_experts, device=device), counts
)
assignments = assignments[torch.randperm(assignments.size(0))]
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
else:
topk_idx, _ = patched_module.gate(flat_inputs)
tokens_per_expert = torch.bincount(
topk_idx.reshape(-1), minlength=args.n_experts
)
min_tokens = int(tokens_per_expert.min().item())
max_tokens = int(tokens_per_expert.max().item())
if args.uniform_routing:
weights = torch.full(
topk_idx.shape,
1.0 / args.top_k,
device=device,
dtype=torch.float32,
)
def _uniform_gate(self, hidden_states):
flat = hidden_states.view(-1, hidden_states.shape[-1])
token_count = flat.shape[0]
return topk_idx[:token_count], weights[:token_count]
patched_module.gate.forward = _uniform_gate.__get__(
patched_module.gate, patched_module.gate.__class__
)
baseline_module.gate.forward = _uniform_gate.__get__(
baseline_module.gate, baseline_module.gate.__class__
)
with torch.no_grad():
ref_output = baseline_module(inputs)
patched_output = patched_module(inputs)
max_diff = (ref_output - patched_output).abs().max().item()
baseline_vram = patched_vram = None
if device.type == "cuda":
torch.cuda.reset_peak_memory_stats(device)
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
baseline_vram = torch.cuda.max_memory_allocated(device)
torch.cuda.reset_peak_memory_stats(device)
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
patched_vram = torch.cuda.max_memory_allocated(device)
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
return {
"device": device,
"backend": args.backend,
"dtype": dtype,
"baseline_ms": baseline_ms,
"patched_ms": patched_ms,
"speedup": speedup,
"max_diff": max_diff,
"routed_tokens": routed_tokens,
"avg_tokens": avg_tokens_per_expert,
"min_tokens": min_tokens,
"max_tokens": max_tokens,
"baseline_vram": baseline_vram,
"patched_vram": patched_vram,
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
}
def main() -> None: # pragma: no cover - CLI entrypoint
args = parse_args()
result = benchmark_deepseek_v3(args)
print(
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
print(
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
)
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
if result["baseline_vram"] is not None:
print(
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
)
print(
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
)
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
from __future__ import annotations
import argparse
import csv
import logging
import sys
from pathlib import Path
from types import SimpleNamespace
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
ACCURACY_TOLERANCE,
DTYPE_MAP,
benchmark_deepseek_v3,
)
LOG = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--dtype",
choices=DTYPE_MAP.keys(),
default="bf16",
help="Computation dtype for all benchmarks",
)
parser.add_argument(
"--device",
default="auto",
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--group-size",
type=int,
help="Override GROUP_SIZE_M for every configuration",
)
parser.add_argument(
"--backends",
default="mg",
help="Comma separated list of backends to benchmark (subset of cg,mg)",
)
parser.add_argument(
"--no-uniform-routing",
action="store_true",
help="Disable uniform routing for every configuration",
)
parser.add_argument(
"--include-mixtral-long",
action="store_true",
help="Add an 8×8192 Mixtral-style run to the sweep",
)
parser.add_argument(
"--output",
type=Path,
help="Optional CSV file to store results",
)
return parser.parse_args()
def make_namespace(
base: dict, args: argparse.Namespace, backend: str
) -> SimpleNamespace:
combined = dict(base)
combined.update(
{
"dtype": args.dtype,
"device": args.device,
"backend": backend,
"warmup": args.warmup,
"iters": args.iters,
"seed": args.seed,
"uniform_routing": not args.no_uniform_routing,
}
)
if args.group_size is not None:
combined["group_size"] = args.group_size
return SimpleNamespace(**combined)
ARCHETYPES = (
(
"mixtral",
{
"hidden_size": 4096,
"moe_intermediate_size": 14336,
"n_experts": 8,
"top_k": 2,
"groups": 1,
"group_size": 128,
},
[(4, 2048), (8, 4096)],
),
(
"qwen",
{
"hidden_size": 6144,
"moe_intermediate_size": 24576,
"n_experts": 16,
"top_k": 4,
"groups": 8,
"group_size": 128,
},
[(4, 4096), (8, 8192)],
),
(
"deepseek_v3",
{
"hidden_size": 12288,
"moe_intermediate_size": 49152,
"n_experts": 128,
"top_k": 8,
"groups": 16,
"group_size": 128,
},
[(4, 4096), (8, 8192)],
),
)
MIXTRAL_LONG_SHAPES = [(8, 8192)]
def main() -> None: # pragma: no cover - utility script
args = parse_args()
grid = []
for label, base_cfg, shapes in ARCHETYPES:
for batch, seq_len in shapes:
cfg = {
"label": label,
"batch": batch,
"seq_len": seq_len,
**base_cfg,
}
if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
continue
grid.append(cfg)
if args.include_mixtral_long:
base_cfg = ARCHETYPES[0][1]
for batch, seq_len in MIXTRAL_LONG_SHAPES:
grid.append(
{
"label": "mixtral_long",
"batch": batch,
"seq_len": seq_len,
**base_cfg,
}
)
if not grid:
raise SystemExit("No valid parameter combinations produced")
header = (
"model",
"batch",
"seq_len",
"hidden_size",
"moe_intermediate",
"n_experts",
"top_k",
"groups",
"backend",
"baseline_ms",
"patched_ms",
"speedup",
"baseline_vram_mib",
"patched_vram_mib",
"min_tokens",
"max_tokens",
"max_diff",
"accuracy_ok",
)
rows = []
raw_backends = [
token.strip() for token in args.backends.split(",") if token.strip()
]
if not raw_backends:
raw_backends = ["mg"]
valid_backends = []
for backend in raw_backends:
if backend not in {"cg", "mg"}:
raise SystemExit(f"Unsupported backend '{backend}' requested")
if backend not in valid_backends:
valid_backends.append(backend)
uniform_flag = not args.no_uniform_routing
print(
f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
)
print(
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
)
for cfg in grid:
for backend in valid_backends:
ns = make_namespace(cfg, args, backend)
result = benchmark_deepseek_v3(ns)
baseline_vram_mib = (
result["baseline_vram"] / (1024**2)
if result["baseline_vram"] is not None
else float("nan")
)
patched_vram_mib = (
result["patched_vram"] / (1024**2)
if result["patched_vram"] is not None
else float("nan")
)
rows.append(
(
cfg["label"],
cfg["batch"],
cfg["seq_len"],
cfg["hidden_size"],
cfg["moe_intermediate_size"],
cfg["n_experts"],
cfg["top_k"],
cfg["groups"],
backend,
result["baseline_ms"],
result["patched_ms"],
result["speedup"],
baseline_vram_mib,
patched_vram_mib,
result["min_tokens"],
result["max_tokens"],
result["max_diff"],
result["accuracy_ok"],
)
)
status = "OK" if result["accuracy_ok"] else "FAIL"
print(
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
)
if not result["accuracy_ok"]:
LOG.warning(
"Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
cfg["label"],
backend,
result["max_diff"],
ACCURACY_TOLERANCE,
)
if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w", newline="") as fp:
writer = csv.writer(fp)
writer.writerow(header)
writer.writerows(rows)
print(f"Results written to {args.output}")
if __name__ == "__main__":
main()

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
)

View File

@@ -26,6 +26,7 @@ def parse_requirements(extras_require_map):
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
@@ -33,6 +34,7 @@ def parse_requirements(extras_require_map):
"triton",
"mamba-ssm",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
@@ -49,7 +51,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.8.0" # default to torch 2.8.0
torch_version = "2.6.0" # default to torch 2.6
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -62,15 +64,8 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["vllm"] = ["vllm==0.11.1"]
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -79,7 +74,7 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"]
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
@@ -92,6 +87,7 @@ def parse_requirements(extras_require_map):
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
@@ -130,7 +126,7 @@ extras_require = {
"ring-flash-attn>=0.1.7",
],
"deepspeed": [
"deepspeed==0.18.2",
"deepspeed==0.17.5",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -165,13 +161,7 @@ extras_require = {
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
"opentelemetry": [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require

View File

@@ -14,8 +14,6 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -33,8 +31,6 @@ LOG = get_logger(__name__)
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -168,7 +164,6 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg
@send_errors
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
@@ -202,8 +197,6 @@ def load_cfg(
temp_file.close()
cfg.axolotl_config_path = temp_file.name
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
@@ -247,7 +240,6 @@ def load_cfg(
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
cfg_to_log = {
k: "[REDACTED]" if k in API_KEY_FIELDS else v
for k, v in cfg.items()

View File

@@ -85,7 +85,9 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
unpatch_llama4 = patch_llama4_linearized_modeling()
from transformers import Llama4ForConditionalGeneration
model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16)
model_ = Llama4ForConditionalGeneration.from_pretrained(
model, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model)
processor.save_pretrained(output)

View File

@@ -19,10 +19,7 @@ from axolotl.cli.utils.diffusion import (
launch_diffusion_gradio_ui,
)
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import (
get_chat_template_from_config,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -46,7 +43,6 @@ def get_multi_line_input() -> str:
return instruction
@send_errors
def do_inference(
*,
cfg: DictDefault,
@@ -164,7 +160,6 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@send_errors
def do_inference_gradio(
*,
cfg: DictDefault,

View File

@@ -7,14 +7,12 @@ import fire
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -23,7 +23,6 @@ from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.config import load_cfg
from axolotl.telemetry.errors import send_errors
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint
@@ -119,7 +118,6 @@ def _distributed_checkpoint_to_merged_weights(
return save_path_
@send_errors
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,

View File

@@ -17,7 +17,6 @@ from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
@@ -25,7 +24,6 @@ from axolotl.utils.trainer import disable_datasets_caching
LOG = get_logger(__name__)
@send_errors
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.

View File

@@ -69,7 +69,7 @@ def do_quantize(
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", dtype=torch_dtype
model_path, device_map="auto", torch_dtype=torch_dtype
)
LOG.info(

View File

@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
resolve_dtype(cfg)
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
if cfg.deepspeed:
cfg.deepspeed = cfg.deepspeed.to_dict()
# initialize accelerator before model instantiation

View File

@@ -12,9 +12,6 @@ MOE_ARCH_BLOCK = {
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
}

View File

@@ -9,7 +9,6 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -35,7 +34,6 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
)
@send_errors
def load_datasets(
*,
cfg: DictDefault,
@@ -98,7 +96,6 @@ def load_datasets(
)
@send_errors
def load_preference_datasets(
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
) -> TrainDatasetMeta:

View File

@@ -29,13 +29,7 @@ from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import (
is_comet_available,
is_mlflow_available,
is_opentelemetry_available,
)
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
@@ -120,13 +114,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled:
from axolotl.utils.callbacks.dynamic_checkpoint import (
DynamicCheckpointCallback,
)
callbacks.append(DynamicCheckpointCallback(self.cfg))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -147,12 +134,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback,
)
callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
@@ -164,10 +145,6 @@ class TrainerBuilderBase(abc.ABC):
)
)
telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
callbacks.append(TelemetryCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -209,9 +186,9 @@ class TrainerBuilderBase(abc.ABC):
):
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps is not None:
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None:
elif self.cfg.warmup_ratio:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
@@ -514,7 +491,6 @@ class TrainerBuilderBase(abc.ABC):
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
"dataset_num_proc",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -538,6 +514,9 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -12,7 +12,7 @@ from transformers import (
EarlyStoppingCallback,
Trainer,
)
from trl.trainer.reward_trainer import DataCollatorForPreference
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
@@ -28,6 +28,7 @@ from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
@@ -62,6 +63,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.relora:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
# TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
@@ -453,7 +460,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
DataCollatorForPreference,
RewardDataCollatorWithPadding,
]
]
collator_args = [self.tokenizer]
@@ -470,10 +477,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
collator = DataCollatorForPreference
tokenizer = collator_args.pop(0)
kwargs["pad_token_id"] = tokenizer.pad_token_id
kwargs.pop("padding")
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama

View File

@@ -43,7 +43,7 @@ from axolotl.core.trainers.utils import (
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_distributed, is_main_process
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -225,6 +225,17 @@ class AxolotlTrainer(
data_collator = self.data_collator if is_training else self.eval_data_collator
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
if isinstance(dataset, datasets.Dataset):
if is_training:
if not self.args.sample_packing or self.args.pretraining:
@@ -283,18 +294,6 @@ class AxolotlTrainer(
):
self.accelerator.even_batches = False
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
dataloader = DataLoader(dataset, **dataloader_params)
# Accelerator.free_memory() will destroy the references, so
@@ -350,11 +349,6 @@ class AxolotlTrainer(
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
num_tokens = (inputs[inputs_key] != -100).sum()
if is_distributed():
torch.distributed.all_reduce(
num_tokens, op=torch.distributed.ReduceOp.SUM
)
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
@@ -362,11 +356,6 @@ class AxolotlTrainer(
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if hasattr(self.state, "total_tokens"):
self.state.total_tokens += num_tokens
else:
self.state.total_tokens = num_tokens
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -571,6 +560,13 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess()
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
and self.args.fsdp_config["limit_all_gathers"]
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]:
@@ -631,7 +627,6 @@ class AxolotlTrainer(
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
logs["total_tokens"] = int(self.state.total_tokens.item())
del self._stored_metrics[train_eval]

View File

@@ -52,7 +52,6 @@ class GRPOStrategy:
if trl.vllm_mode:
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined]
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
)
@@ -126,9 +125,6 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
return grpo_args_kwargs
@classmethod
@@ -204,32 +200,3 @@ class GRPOStrategy:
raise ValueError(
f"Reward function {reward_func_fqn} not found."
) from exc
@classmethod
def get_rollout_func(cls, rollout_func_fqn: str):
"""
Returns the rollout function from the given fully qualified name.
Args:
rollout_func_fqn (str): Fully qualified name of the rollout function
(e.g. my_module.my_rollout_func)
Returns:
Callable rollout function
"""
try:
rollout_func_module_name = rollout_func_fqn.split(".")[-1]
rollout_func_module = importlib.import_module(
".".join(rollout_func_fqn.split(".")[:-1])
)
rollout_func = getattr(rollout_func_module, rollout_func_module_name)
if not callable(rollout_func):
raise ValueError(
f"Rollout function {rollout_func_fqn} must be callable"
)
return rollout_func
except ModuleNotFoundError as exc:
raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc

View File

@@ -10,7 +10,6 @@ import torch
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.telemetry.errors import send_errors
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
@@ -64,7 +63,6 @@ def evaluate_dataset(
return metrics
@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets.

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
```
## Usage
@@ -31,7 +31,6 @@ plugins:
## Supported Models
- apertus
- arcee
- cohere
- cohere2
@@ -45,29 +44,18 @@ plugins:
- glm
- glm4
- glm4_moe
- glm4v
- glm4v_moe
- gpt_oss
- granite
- granitemoe
- granitemoeshared
- granitemoehybrid
- hunyuan_v1_dense
- hunyuan_v1_moe
- lfm2
- lfm2_moe
- lfm2_vl
- llama
- llama4
- llama4_text
- llava
- mistral
- mistral3
- mixtral
- mllama
- olmo
- olmo2
- olmo3
- phi
- phi3
- phi4_multimodal
@@ -77,8 +65,6 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
- qwen3_vl
- qwen3_vl_moe
- qwen3_next
- smollm3
- seed_oss

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
)

View File

@@ -7,7 +7,7 @@ import torch
from axolotl.utils.logging import get_logger
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__)
@@ -360,7 +360,7 @@ def _diffusion_step(
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = shift_logits_to_input_positions(outputs.logits)
logits = outputs.logits
# Only sample at currently masked positions
if current_mask.any():

View File

@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__)
@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
input_ids=noisy_batch.long(),
attention_mask=bidirectional_mask,
)
logits = shift_logits_to_input_positions(outputs.logits)
logits = outputs.logits
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)

View File

@@ -157,10 +157,3 @@ def create_bidirectional_attention_mask(
# Add head dimension: [batch_size, 1, seq_len, seq_len]
return bidirectional_mask.unsqueeze(1)
def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
"""Align next-token logits with their input token positions for diffusion."""
if logits.size(1) <= 1:
return logits
return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

View File

@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self._loss_function should be LigerFusedLinearKLTopKLogprobLoss
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self._loss_function(
loss = self.loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,

View File

@@ -29,8 +29,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True
loss_fn = LigerFusedLinearKLTopKLogprobLoss(
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
self.args.kd_ce_alpha, # hard label loss
self.args.kd_alpha, # kd loss
self.args.kd_temperature,
@@ -38,14 +37,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
)
target = self.model
# Unwrap PEFT wrapper
if hasattr(target, "get_base_model"):
target = target.get_base_model()
# Set on the actual model instance
target._loss_function = loss_fn
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()

View File

@@ -18,9 +18,6 @@ liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
# FLCE-specific
liger_use_token_scaling: true
```
## Supported Models

View File

@@ -16,7 +16,7 @@
Module for handling LIGER input arguments.
"""
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
@@ -35,15 +35,6 @@ class LigerArgs(BaseModel):
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
liger_use_token_scaling: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables use_token_scaling in fused_linear_cross_entropy. "
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
)
},
)
@model_validator(mode="before")
@classmethod
@@ -84,18 +75,6 @@ class LigerArgs(BaseModel):
)
return data
@model_validator(mode="before")
@classmethod
def check_liger_use_token_scaling_flce(cls, data):
if data.get("liger_use_token_scaling") and not data.get(
"liger_fused_linear_cross_entropy"
):
raise ValueError(
"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
)
return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate

View File

@@ -48,33 +48,6 @@ class LigerPlugin(BasePlugin):
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.liger_use_token_scaling:
# Patch FLCE to set token_scaling=True for function and class API
from liger_kernel.transformers import functional
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
old_liger_fused_linear_cross_entropy = (
functional.liger_fused_linear_cross_entropy
)
def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
kwargs["use_token_scaling"] = True
return old_liger_fused_linear_cross_entropy(*args, **kwargs)
functional.liger_fused_linear_cross_entropy = (
patched_liger_fused_linear_cross_entropy
)
old_init = LigerFusedLinearCrossEntropyLoss.__init__
def patched_init(self, *args, **kwargs):
kwargs["use_token_scaling"] = True
return old_init(self, *args, **kwargs)
LigerFusedLinearCrossEntropyLoss.__init__ = patched_init
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -0,0 +1,21 @@
"""Mixture-of-Experts kernel implementations."""
from .indices import generate_permute_indices
from .tt_cg_gemm import (
ContiguousGroupedGEMM,
ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
"generate_permute_indices",
"mg_grouped_gemm",
]

View File

@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""
from .indices import generate_permute_indices
__all__ = ["generate_permute_indices"]

View File

@@ -0,0 +1,144 @@
"""Vendored token permutation kernels from TorchTitan."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
__all__ = ["generate_permute_indices"]
@triton.jit
def _fill_indices_kernel(
tokens_per_expert_group_ptr,
start_index_values_ptr,
write_offsets_ptr,
output_ptr,
experts_per_rank: tl.constexpr,
num_ranks: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_programs = tl.num_programs(axis=0)
for expert_id in range(pid, experts_per_rank, num_programs):
write_offset = tl.load(write_offsets_ptr + expert_id)
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = tl.load(start_index_values_ptr + idx)
length = tl.load(tokens_per_expert_group_ptr + idx)
offsets = tl.arange(0, BLOCK_SIZE)
for chunk_start in range(0, length, BLOCK_SIZE):
chunk_offsets = chunk_start + offsets
mask = chunk_offsets < length
values = start_index + chunk_offsets
dest_indices = write_offset + chunk_offsets
tl.store(output_ptr + dest_indices, values, mask=mask)
write_offset += length
def fill_indices_wrapper(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
block_size: int = 128,
max_blocks: int = 1024,
):
permuted_indices = torch.full(
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
)
num_blocks = min(experts_per_rank, max_blocks)
grid = (num_blocks,)
_fill_indices_kernel[grid](
tokens_per_expert_group,
start_index_values,
write_offsets,
permuted_indices,
experts_per_rank,
num_ranks,
BLOCK_SIZE=block_size,
)
return permuted_indices
def fill_indices_cpu(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
):
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
for expert_id in range(experts_per_rank):
write_start = write_offsets[expert_id].item()
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = start_index_values[idx].item()
length = tokens_per_expert_group[idx].item()
if length > 0:
end_idx = min(write_start + length, max_len)
permuted_indices[write_start:end_idx] = torch.arange(
start_index,
start_index + (end_idx - write_start),
dtype=torch.int32,
)
write_start += length
return permuted_indices
def generate_permute_indices(
tokens_per_expert_group: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
alignment: int,
use_cpu: bool = False,
):
start_index_values = (
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
)
total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
torch.int32
)
m_offsets = torch.cumsum(m_sizes, 0)
write_offsets = m_offsets - m_sizes
if use_cpu:
permuted_indices = fill_indices_cpu(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
else:
permuted_indices = fill_indices_wrapper(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
return permuted_indices, m_sizes, m_offsets.to(torch.int32)

View File

@@ -0,0 +1,17 @@
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
from .cg_backward import ContiguousGroupedGEMM
from .cg_forward import (
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
]

View File

@@ -0,0 +1,290 @@
"""Vendored backward pass for Triton contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
from .cg_forward import cg_grouped_gemm_forward
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_backward_dx(
grad_output_ptr,
b_ptr,
grad_input_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""Compute gradients with respect to inputs."""
pid = tl.program_id(0)
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
tile_m = pid // num_k_tiles
tile_k = pid % num_k_tiles
m_start = tile_m * BLOCK_SIZE_M
k_start = tile_k * BLOCK_SIZE_K
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_m = offs_m < M_TOTAL
mask_k = offs_k < K
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
for n in range(0, N, BLOCK_SIZE_N):
offs_n = tl.arange(0, BLOCK_SIZE_N) + n
mask_n = offs_n < N
mask_go = mask_m[:, None] & mask_n[None, :]
mask_w = mask_n[:, None] & mask_k[None, :]
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
grad_input += tl.dot(go, w)
grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_gi = mask_m[:, None] & mask_k[None, :]
tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
@triton.jit
def _kernel_cg_backward_dw(
grad_output_ptr,
inputs_ptr,
grad_weights_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
"""Simplified kernel for expert weight gradients."""
pid = tl.program_id(0)
expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
if expert_id < NUM_EXPERTS:
n_tiles = K // BLOCK_SIZE_K
tile_n = position_id // n_tiles
tile_k = position_id % n_tiles
n_start = tile_n * BLOCK_SIZE_N
k_start = tile_k * BLOCK_SIZE_K
if n_start < N and k_start < K:
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_n = offs_n < N
mask_k = offs_k < K
grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
group_start = group_idx * GROUP_SIZE_M
group_expert = tl.load(indices_ptr + group_start)
if group_expert == expert_id:
for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
m_start = group_start + m_offset
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
go_ptrs = (
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
)
mask_go = mask_m[:, None] & mask_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_in = mask_m[:, None] & mask_k[None, :]
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
go_t = tl.trans(go)
grad_weights += tl.dot(go_t, inp)
grad_w_ptrs = (
grad_weights_ptr
+ expert_id * N * K
+ offs_n[:, None] * K
+ offs_k[None, :]
)
mask_gw = mask_n[:, None] & mask_k[None, :]
tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
def cg_grouped_gemm_backward_weights(
grad_output: torch.Tensor,
inputs: torch.Tensor,
expert_indices: torch.Tensor,
num_experts: int,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for expert weights."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
_, K = inputs.shape
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
grad_weights = torch.zeros(
(num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
)
block_size_n = min(128, N)
block_size_k = min(32, K)
block_size_m = min(32, group_size_m)
n_tiles = triton.cdiv(N, block_size_n)
k_tiles = triton.cdiv(K, block_size_k)
grid = (num_experts * n_tiles * k_tiles,)
_kernel_cg_backward_dw[grid](
grad_output,
inputs,
grad_weights,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
BLOCK_SIZE_M=block_size_m,
)
return grad_weights
def cg_grouped_gemm_backward_inputs(
grad_output: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for inputs."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
num_experts, _, K = expert_weights.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
grad_inputs = torch.zeros(
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(K, meta["BLOCK_SIZE_K"]),
)
_kernel_cg_backward_dx[grid](
grad_output,
expert_weights,
grad_inputs,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return grad_inputs
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function with full backward support."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
ctx.save_for_backward(inputs, expert_weights, expert_indices)
ctx.group_size_m = group_size_m
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output):
inputs, expert_weights, expert_indices = ctx.saved_tensors
group_size_m = ctx.group_size_m
grad_output = grad_output.contiguous()
num_experts = expert_weights.shape[0]
grad_inputs = cg_grouped_gemm_backward_inputs(
grad_output=grad_output,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
grad_weights = cg_grouped_gemm_backward_weights(
grad_output=grad_output,
inputs=inputs,
expert_indices=expert_indices,
num_experts=num_experts,
group_size_m=group_size_m,
)
grad_indices = None
grad_group_size_m = None
return grad_inputs, grad_weights, grad_indices, grad_group_size_m

View File

@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Vendored forward Triton contiguous grouped GEMM kernels."""
import torch
import triton
import triton.language as tl
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * super_group_m
group_size_m = min(num_pid_m - first_pid_m, super_group_m)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_SMS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
SUPER_GROUP_M: tl.constexpr = 32,
):
"""
Contiguous Grouped GEMM kernel forward (persistent variant).
"""
c_type = c_ptr.dtype.element_ty
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = SUPER_GROUP_M * num_pid_n
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
tile_m_idx, tile_n_idx = _compute_pid(
tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
m_start = tile_m_idx * BLOCK_SIZE_M
n_start = tile_n_idx * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = (
b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
)
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
accumulator += tl.dot(a, b.T)
tile_id_c += NUM_SMS
tile_m_idx, tile_n_idx = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_c = mask_m[:, None] & mask_n[None, :]
c = accumulator.to(tl.float32)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
tl.store(c_ptrs, c.to(c_type), mask=mask_c)
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""
Contiguous Grouped GEMM kernel forward for aligned inputs.
"""
pid = tl.program_id(0)
c_type = c_ptr.dtype.element_ty
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
tile_m = pid // num_n_tiles
tile_n = pid % num_n_tiles
m_start = tile_m * BLOCK_SIZE_M
n_start = tile_n * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
offs_k = tl.arange(0, BLOCK_SIZE_K) + k
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
acc += tl.dot(a, b.T)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask_c = mask_m[:, None] & mask_n[None, :]
tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
def cg_grouped_gemm_forward(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = (NUM_SMS, 1, 1)
_kernel_cg_persistent_forward[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
NUM_SMS=NUM_SMS,
)
return output
def cg_grouped_gemm_forward_dynamic(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE with autotuned launch."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
_kernel_cg_forward_aligned[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return output
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function for contiguous grouped GEMM forward pass only."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output): # pragma: no cover - not implemented
raise NotImplementedError("Backward pass not implemented")
def cg_grouped_gemm(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Convenience wrapper for the forward-only autograd function."""
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
return ContiguousGroupedGEMM.apply(
inputs, expert_weights, expert_indices, group_size_m
)

View File

@@ -0,0 +1,31 @@
"""Reference implementation for contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
def pytorch_reference(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = 128,
) -> torch.Tensor:
"""Simple PyTorch implementation for verification."""
M_total, K = inputs.shape
num_experts, N, _ = expert_weights.shape
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
for i in range(0, M_total, group_size_m):
end_idx = min(i + group_size_m, M_total)
expert_idx = expert_indices[i].item()
expert_weight = expert_weights[expert_idx]
output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)
return output

View File

@@ -0,0 +1,209 @@
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver
class CudaUtils:
"""Helper utilities for CUDA specific Triton features."""
@staticmethod
def is_cuda() -> bool:
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels."""
class KernelParamWrapper:
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
if "nv_tma_desc_type" not in dir(tl):
raise RuntimeError(
"TMA grid constant descriptors not supported in your Triton version"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(
self, name: str
) -> "TmaDescriptorHelper.KernelParamWrapper":
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
HOPPER_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
STANDARD_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
def early_config_prune(configs, args, **kwargs):
"""Filter out configurations that would exceed shared memory capacity."""
k = kwargs.get("K", 0)
valid_configs = [
config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
]
if not valid_configs and configs:
return [
min(
configs,
key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
)
]
return valid_configs

View File

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M
__all__ = [
"grouped_gemm_forward",
"ALIGN_SIZE_M",
]

View File

@@ -0,0 +1,761 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - flat index forward kernel is derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import logging
from typing import Tuple
import torch
import triton
import triton.language as tl
from .tma_autotuning import (
_NV_CONFIGS,
CudaUtils,
early_config_prune,
)
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
_allocator_registered = False
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
return torch.empty(size, device="cuda", dtype=torch.int8)
def _ensure_triton_allocator() -> None:
global _allocator_registered
if not _allocator_registered:
triton.set_allocator(_torch_allocator)
_allocator_registered = True
# ============== Start Triton Kernels ===============
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
a_ptr,
b_ptr,
c_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
USE_EPILOGUE_SUBTILING: tl.constexpr,
# tiles
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
"""Flat index style forward kernel for Hopper using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = c_ptr.dtype.element_ty
n_size = N // G
a_desc = tl.make_tensor_descriptor(
a_ptr,
shape=[M_TOTAL, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl.make_tensor_descriptor(
b_ptr,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
M_end = tl.full([], 0, dtype=tl.int32)
processed_tiles = 0
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
group_num_tiles = num_m_tiles * num_n_tiles
while (
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
):
group_index = tbidx - processed_tiles
tile_m_index = group_index % num_m_tiles
tile_n_index = group_index // num_m_tiles
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
global_n_offset = (g * n_size + n_offset).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
k_remaining = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
a = a_desc.load([m_offset, k_offset])
a_mask = row_mask[:, None] & k_mask[None, :]
a = tl.where(a_mask, a, tl.zeros_like(a))
b = b_desc.load([global_n_offset, k_offset])
b_mask = col_mask[:, None] & k_mask[None, :]
b = tl.where(b_mask, b, tl.zeros_like(b))
accumulator += tl.dot(a, b.T)
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
col_store_mask = local_col_offsets < n_size
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
if USE_EPILOGUE_SUBTILING:
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
tl.store(
ptr0,
acc0.to(c_dtype),
mask=row_store_mask[:, None] & col_mask0[None, :],
)
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
tl.store(
ptr1,
acc1.to(c_dtype),
mask=row_store_mask[:, None] & col_mask1[None, :],
)
else:
ptr = (
c_ptr
+ global_row[:, None] * n_size
+ local_col_offsets[None, :]
)
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS
processed_tiles += group_num_tiles
"""
Backward pass for grouped GEMM with Triton, where grouping is M*G
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
"""
# ---- dx flat linear indexed ----
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
grad_output_ptr,
w_ptr,
grad_input_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
# tiles
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
"""Compute grad_input = grad_output @ w using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = grad_input_ptr.dtype.element_ty
grad_output_desc = tl.make_tensor_descriptor(
grad_output_ptr,
shape=[M_TOTAL, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
w_desc = tl.make_tensor_descriptor(
w_ptr,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
M_end = tl.full([], 0, dtype=tl.int32)
processed_tiles = 0
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
group_num_tiles = num_m_tiles * num_k_tiles
while (
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
):
group_index = tbidx - processed_tiles
tile_m_index = group_index % num_m_tiles
tile_k_index = group_index // num_m_tiles
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
k_offset = tile_k_index * BLOCK_SIZE_K
k_remaining_total = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
for n_offset in range(0, N, BLOCK_SIZE_N):
n_remaining = N - n_offset
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
grad_y = grad_output_desc.load([m_offset, n_offset])
grad_y_mask = row_mask[:, None] & n_mask[None, :]
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
w_tile = w_desc.load([n_offset, k_offset])
w_mask = n_mask[:, None] & k_mask[None, :]
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
accumulator += tl.dot(grad_y, w_tile)
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
col_store_mask = col_offsets < K
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS
processed_tiles += group_num_tiles
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
x_ptr,
grad_output_ptr,
grad_weight_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
# tiles
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
) -> None:
"""Compute grad_weight = grad_output.T @ x using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = grad_weight_ptr.dtype.element_ty
x_desc = tl.make_tensor_descriptor(
x_ptr,
shape=[M_TOTAL, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
grad_output_desc = tl.make_tensor_descriptor(
grad_output_ptr,
shape=[M_TOTAL, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
total_tiles = num_n_tiles * num_k_tiles
for tile_idx in range(tbidx, total_tiles, NUM_SMS):
tile_n_idx = tile_idx % num_n_tiles
tile_k_idx = tile_idx // num_n_tiles
n_offset = tile_n_idx * BLOCK_SIZE_N
n_remaining = N - n_offset
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
k_offset = tile_k_idx * BLOCK_SIZE_K
k_remaining = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
M_end = tl.full([], 0, dtype=tl.int32)
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
rows_remaining = m_size - m_offset_local
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
m_offset = (M_start + m_offset_local).to(tl.int32)
x_block = x_desc.load([m_offset, k_offset])
x_mask = row_mask[:, None] & k_mask[None, :]
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
grad_block = grad_output_desc.load([m_offset, n_offset])
grad_mask = row_mask[:, None] & n_mask[None, :]
grad_block = tl.where(
grad_mask, grad_block, tl.zeros_like(grad_block)
)
contribution = tl.dot(
grad_block.to(tl.float32).T,
x_block.to(tl.float32),
)
accumulator += contribution
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
row_store_mask = row_offsets < N
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
col_store_mask = col_offsets < K
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
# ======== End Triton kernels ========
# ======== End Triton kernels ========
# ======== Triton wrapper functions ========
# ----- main forward pass wrapper -----
def grouped_gemm_forward(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
tma_size: int = 128,
using_fp8: bool = False,
) -> torch.Tensor:
"""Grouped GEMM forward using Hopper TMA kernels."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
if using_fp8:
raise NotImplementedError(
"FP8 path not implemented with the new Triton API yet"
)
G = m_sizes.shape[0]
assert x.is_contiguous()
assert w.is_contiguous()
assert m_sizes.is_contiguous()
M_total, K = x.shape
N = w.shape[0]
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
if M_total == 0:
return y
NUM_SMS = CudaUtils.get_num_sms()
USE_EPILOGUE_SUBTILING = False
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_forward_hopper[grid](
x,
w,
y,
m_sizes,
M_total,
G,
M_BUCKET,
N,
K,
NUM_SMS,
USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
)
return y
# ======== Improved Backward =============
def grouped_gemm_backward(
grad_output: torch.Tensor,
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
use_tma: bool = True,
tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Unified backward pass for grouped GeMM with M*G grouping.
Uses optimized TMA-based implementations for both dx and dw when available.
Args:
grad_output: Gradient of output, shape [M_total, N]
x: Input tensor from forward pass, shape [M_total, K]
w: Weight tensor from forward pass, shape [N, K]
m_sizes: Group sizes tensor, shape [G]
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
Returns:
Tuple of gradients with respect to x and w: (grad_x, grad_w)
"""
logging.info("Starting unified grouped_gemm_backward")
# do this once, seems expensive
NUM_SMS = CudaUtils.get_num_sms()
# Basic validation
M_total, K_x = x.shape
M_grad, N = grad_output.shape
N_w, K_w = w.shape
# Check dimensions
if K_x != K_w:
raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
if M_total != M_grad:
raise ValueError(
f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
)
# Check total M matches sum of group sizes
sum_m_sizes = m_sizes.sum().item()
if M_total != sum_m_sizes:
raise ValueError(
f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
)
# Make sure inputs are contiguous
grad_output = grad_output.contiguous()
x = x.contiguous()
w = w.contiguous()
m_sizes = m_sizes.contiguous()
# Check TMA support
if use_tma and not CudaUtils.verify_tma():
logging.info("TMA requested but not supported on this device")
use_tma = False
# Compute grad_x using flat linear implementation
try:
logging.info("Computing grad_x with flat linear kernel")
# Use TMA-optimized implementation
grad_x = grouped_gemm_dx_tma(
grad_output=grad_output,
w=w,
m_sizes=m_sizes,
num_sms=NUM_SMS,
tma_size=tma_size,
)
except Exception as e:
logging.error(f"Error in grad_x computation: {e}")
raise
# Compute grad_w using flat linear style implementation
try:
logging.info("Computing grad_w with flat linear kernel")
grad_w = grouped_gemm_dw_tma(
x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
)
except Exception as e:
logging.error(f"Error in grad_w computation: {e}")
raise
return grad_x, grad_w
# ----- dx backward pass wrapper -----
def grouped_gemm_dx_tma(
grad_output: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
num_sms: int = 132,
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_x using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Optimized dx computation requires TMA support")
grad_output = grad_output.contiguous()
w = w.contiguous()
m_sizes = m_sizes.contiguous()
M_total, N = grad_output.shape
N_w, K = w.shape
if N != N_w:
raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})")
if m_sizes.sum().item() != M_total:
raise ValueError("Sum of m_sizes must equal the number of rows in grad_output")
grad_x = torch.empty(
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
)
NUM_SMS = num_sms
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_dx_tma[grid](
grad_output,
w,
grad_x,
m_sizes,
M_total,
m_sizes.shape[0],
M_BUCKET,
N,
K,
NUM_SMS,
)
return grad_x
# ======== dw wrapper function ==========
def grouped_gemm_dw_tma(
x: torch.Tensor,
grad_output: torch.Tensor,
m_sizes: torch.Tensor,
num_sms: int = 132,
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_w using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
x = x.contiguous()
grad_output = grad_output.contiguous()
m_sizes = m_sizes.contiguous()
M_total, K = x.shape
M_grad, N = grad_output.shape
if M_total != M_grad:
raise ValueError("x and grad_output must have matching batch dimension")
if m_sizes.sum().item() != M_total:
raise ValueError("Sum of m_sizes must equal the number of rows in the inputs")
grad_w = torch.zeros((N, K), device=x.device, dtype=x.dtype)
NUM_SMS = num_sms
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_dw_tma[grid](
x,
grad_output,
grad_w,
m_sizes,
M_total,
m_sizes.shape[0],
M_BUCKET,
N,
K,
NUM_SMS,
)
return grad_w
# ======== End Backwards Wrapper Functions =============
# ======== PyTorch wrapper functions ========
class GroupedGemmMg(torch.autograd.Function):
"""
Autograd function for GroupedGEMM with M*G grouping.
Supports both standard and FP8 quantized operations.
"""
@staticmethod
def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
"""
Forward pass of GroupedGEMM.
Args:
x: Input tensor, shape [M_total, K]
w: Weight tensor, shape [N, K]
m_sizes: Tensor of shape [G] containing the size of each group
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
using_fp8: Whether to use FP8 quantization
Returns:
Output tensor, shape [M_total, N]
"""
# Use regular forward without quantization
output = grouped_gemm_forward(
x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
)
# Save inputs and parameters for backward pass
ctx.save_for_backward(x, w, m_sizes)
ctx.use_tma = use_tma
ctx.tma_size = tma_size
ctx.save_for_backward(x, w, m_sizes)
return output
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass of M*G GroupedGEMM.
Args:
grad_output: Gradient of output, shape [M_total, N]
Returns:
Tuple of gradients:
- grad_x: Gradient with respect to x, shape [M_total, K]
- grad_w: Gradient with respect to w, shape [N, K]
- None: Gradient with respect to m_sizes (not differentiable)
- None: Gradient with respect to use_tma (not differentiable)
- None: Gradient with respect to tma_size (not differentiable)
"""
# Retrieve saved tensors and parameters
x, w, m_sizes = ctx.saved_tensors
use_tma = ctx.use_tma
tma_size = ctx.tma_size
# Compute gradients using the unified implementation
grad_x, grad_w = grouped_gemm_backward(
grad_output=grad_output,
x=x,
w=w,
m_sizes=m_sizes,
use_tma=use_tma,
tma_size=tma_size,
)
# Return gradients for all inputs (None for non-differentiable parameters)
return grad_x, grad_w, None, None, None, None
def mg_grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
use_tma: bool = True,
tma_size: int = 128,
using_fp8: bool = False,
) -> torch.Tensor:
"""
Unified differentiable grouped GEMM operation for M*G grouped GEMM.
Supports both standard precision and FP8 quantized operations.
Args:
x: Input tensor, shape [M_total, K]
w: Weight tensor, shape [N, K]
m_sizes: Tensor of shape [G] containing the size of each group
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
using_fp8: Whether to use FP8 quantization
Returns:
Output tensor, shape [M_total, N]
"""
return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)

View File

@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import os
import sys
from typing import Dict
import torch
import triton
from triton.runtime import driver # @manual
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# ===== Supporting utils, CUDA and TMA =====
class CudaUtils:
@staticmethod
def is_cuda() -> bool:
"""Check if Triton is running on CUDA backend."""
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
"""Check if TMA is supported on the current device."""
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
"""Get the number of streaming multiprocessors on the current device."""
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels.
Args:
tma_size: Size of the TMA descriptor in bytes
"""
class KernelParamWrapper:
"""Wrapper to implement the TmaDescKernelParam interface."""
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
"""Return the CPU pointer to the TMA descriptor."""
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
"""Initialize a TMA descriptor with the given name.
Call this method outside of the lambda function for grid size.
"""
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
"""Fill a 1D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
"""Fill a 2D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
"""Get the TMA descriptor kernel parameter for the given name."""
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128
_NV_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
)
for block_size_m in [
ALIGN_SIZE_M,
]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
device = torch.cuda.current_device()
# Check for all possible pointer parameter names
if "grad_input_ptr" in named_args:
ptr_name = "grad_input_ptr"
elif "c_ptr" in named_args:
ptr_name = "c_ptr"
elif "grad_weight_ptr" in named_args:
ptr_name = "grad_weight_ptr"
else:
raise KeyError("No recognized pointer parameter found in kernel arguments")
if dtsize is None:
dtsize = named_args[ptr_name].element_size()
if dtype is None:
dtype = named_args[ptr_name].dtype
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
kw["BLOCK_SIZE_M"],
kw["BLOCK_SIZE_N"],
kw["BLOCK_SIZE_K"],
config.num_stages,
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
"max_shared_mem"
]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
M_PER_GROUP = M // G
MIN_M_TILES = 64
# 2. make sure we don't load M tiles that are too big
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
continue
# 3. make sure we don't load N tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = driver.active.utils.get_device_properties(device)[
"multiprocessor_count"
]
N_TILES = N // BLOCK_N
MIN_N_TILES = 64
# 4. make sure we don't load N tiles that are too big
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
pruned_configs.append(config)
return pruned_configs
# ======== End Autotuning utilities ========

View File

@@ -20,7 +20,6 @@ from peft import (
from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -102,8 +101,6 @@ def load_lora(
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
if cfg.peft_ensure_weight_tying is not None:
lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
# Determine the correct PEFT task type
model_cls = type(model).__name__
@@ -175,7 +172,6 @@ def load_lora(
return model, lora_config
@send_errors
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,

View File

@@ -49,7 +49,6 @@ from axolotl.loaders.utils import (
load_model_config,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.telemetry.errors import send_errors
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
@@ -159,7 +158,6 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
@@ -517,6 +515,9 @@ class ModelLoader:
if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
@@ -551,7 +552,9 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
"load_in_4bit", False
):
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -577,7 +580,9 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
"load_in_8bit", False
):
bnb_config = {
"load_in_8bit": True,
}
@@ -591,6 +596,11 @@ class ModelLoader:
**bnb_config,
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:

View File

@@ -190,6 +190,15 @@ class PatchManager:
apply_mistral_tokenizer_image_patch()
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
LOG.info(
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
)
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
@@ -457,7 +466,7 @@ class PatchManager:
and self.cfg.flash_attention
and not self.inference
):
# TODO(MengqingCao): split these patches separately
# TODO(MengqingCao): split these patches seperately
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,

View File

@@ -1,47 +1,27 @@
"""Processor loading functionality for multi-modal models"""
from typing import Any
import transformers
from transformers import (
AutoProcessor,
PreTrainedTokenizerBase,
)
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: dict[str, Any] = {} # Do we actually need this?
processor_cls = AutoProcessor
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
if cfg.tokenizer_use_mistral_common:
def _patch_mistralcommontokenizer():
"""
Transformers v5 stops reading the sub-processor.
We need to patch this, so both processors use this.
"""
import transformers.tokenization_mistral_common as tokenization_mistral_common
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
_patch_mistralcommontokenizer()
from transformers import VoxtralProcessor
if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained(
cfg.processor_config,
)
from axolotl.utils.mistral import Mistral3Processor
return Mistral3Processor(
@@ -52,6 +32,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)
# Attempt to load image size from processor if available

Some files were not shown because too many files have changed in this diff Show More