Compare commits

23 Commits

| SHA1 |
|---|
| 8f19169eb0 |
| 876941ffd0 |
| d65e1b960c |
| 0a23ae08f7 |
| fc2d63ee5f |
| c119382337 |
| 6c8c73e5a4 |
| a260d330ed |
| da17c7c0d9 |
| cada93cee5 |
| 56162f71db |
| 6c44afaea1 |
| 234931d512 |
| 6a8baf8fa7 |
| 1eaf4d7418 |
| 4b8bc52424 |
| 28cc085283 |
| 8e2a102cca |
| 753906cfc7 |
| b6b8db805a |
| 653f90be25 |
| 945c8aeb10 |
| e672d37f33 |
.github/CONTRIBUTING.md (+5 lines)

@@ -70,6 +70,11 @@ You can skip certain CI checks by including specific keywords in your commit messages

 axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.

+Use the pre-commit linter to ensure that your code is formatted consistently.
+
+```bash
+pre-commit run --all-files
+```
+
 ### Commit Messages

 Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function".
.github/workflows/base.yml (16 changes)

@@ -51,6 +51,14 @@ jobs:
 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
 dockerfile: "Dockerfile-base"
 platforms: "linux/amd64,linux/arm64"
+- cuda: "128"
+cuda_version: 12.8.1
+cudnn_version: ""
+python_version: "3.11"
+pytorch: 2.10.0
+torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+dockerfile: "Dockerfile-base"
+platforms: "linux/amd64,linux/arm64"
 - cuda: "128"
 cuda_version: 12.8.1
 cudnn_version: ""

@@ -173,6 +181,14 @@ jobs:
 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
 dockerfile: "Dockerfile-uv-base"
 platforms: "linux/amd64,linux/arm64"
+- cuda: "128"
+cuda_version: 12.8.1
+cudnn_version: ""
+python_version: "3.11"
+pytorch: 2.10.0
+torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+dockerfile: "Dockerfile-uv-base"
+platforms: "linux/amd64,linux/arm64"
 - cuda: "128"
 cuda_version: 12.8.1
 cudnn_version: ""
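Each of the two hunks above adds the same CUDA 12.8 / PyTorch 2.10.0 build entry, once per base Dockerfile. Assembled as a single YAML list item, the new entry reads roughly as follows (indentation is approximate; the values are exactly those in the added lines):

```yaml
# Sketch of the new build-matrix entry added in both jobs (indentation approximate).
- cuda: "128"
  cuda_version: 12.8.1
  cudnn_version: ""
  python_version: "3.11"
  pytorch: 2.10.0
  torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
  dockerfile: "Dockerfile-base"        # second job uses "Dockerfile-uv-base"
  platforms: "linux/amd64,linux/arm64"
```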
.github/workflows/multi-gpu-e2e.yml (13 changes)

@@ -35,12 +35,6 @@ jobs:
 pytorch: 2.8.0
 axolotl_extras: fbgemm-gpu
 num_gpus: 2
-- cuda: 128
-cuda_version: 12.8.1
-python_version: "3.11"
-pytorch: 2.9.1
-axolotl_extras: "fbgemm-gpu"
-num_gpus: 2
 - cuda: 129
 cuda_version: 12.9.1
 python_version: "3.12"

@@ -55,6 +49,13 @@ jobs:
 axolotl_extras:
 # axolotl_extras: fbgemm-gpu
 num_gpus: 2
+- cuda: 128
+cuda_version: 12.8.1
+python_version: "3.11"
+pytorch: 2.10.0
+axolotl_extras: "fbgemm-gpu"
+num_gpus: 2
+dockerfile: "Dockerfile-uv.jinja"
 runs-on: [self-hosted, modal]
 timeout-minutes: 120
 steps:
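Net effect of the two hunks above: the CUDA 12.8 / PyTorch 2.9.1 matrix entry is removed and a CUDA 12.8 / PyTorch 2.10.0 entry is added further down the matrix. Assembled from the added lines, the new entry reads roughly as follows (indentation is approximate; values are exactly those in the diff):

```yaml
# Sketch of the re-added multi-GPU E2E matrix entry (indentation approximate).
- cuda: 128
  cuda_version: 12.8.1
  python_version: "3.11"
  pytorch: 2.10.0
  axolotl_extras: "fbgemm-gpu"
  num_gpus: 2
  dockerfile: "Dockerfile-uv.jinja"
```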
.github/workflows/tests-nightly.yml (26 changes)

@@ -18,15 +18,27 @@ jobs:
 env:
 SKIP: no-commit-to-branch

+prime-cdn-s3-cache:
+name: Prefetch S3 once to prime the CDN cache
+runs-on: ubuntu-latest
+if: ${{ !github.event.pull_request.draft }}
+timeout-minutes: 10
+steps:
+- name: Restore Cache from S3
+id: hf-cache-restore-s3
+run: |
+curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
+
 pytest:
 name: PyTest
 runs-on: ubuntu-latest
+needs: [prime-cdn-s3-cache]
 strategy:
 fail-fast: false
 max-parallel: 2
 matrix:
-python_version: ["3.11"]
-pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
+python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
+pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
 timeout-minutes: 20

 steps:

@@ -102,16 +114,23 @@ jobs:
 - cuda: 128
 cuda_version: 12.8.1
 python_version: "3.11"
-pytorch: 2.8.0
+pytorch: 2.9.1
 num_gpus: 1
 axolotl_extras:
 nightly_build: "true"
 - cuda: 128
 cuda_version: 12.8.1
 python_version: "3.11"
+pytorch: 2.10.0
+num_gpus: 1
+axolotl_extras:
+- cuda: 130
+cuda_version: 13.0.0
+python_version: "3.12"
 pytorch: 2.9.1
 num_gpus: 1
 axolotl_extras:
+dockerfile: "Dockerfile-uv.jinja"
 nightly_build: "true"
 steps:
 - name: Checkout

@@ -132,6 +151,7 @@ jobs:
 echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
 echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
 echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
 echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
 echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
 - name: Run tests job on Modal
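This workflow and tests.yml (below) gain the same cache-priming job, and their PyTest jobs now declare `needs: [prime-cdn-s3-cache]` so the CDN cache is warm before the test matrix starts. Assembled from the added lines, the new job reads roughly as follows (indentation is approximate; names and values are exactly those in the diff):

```yaml
# Sketch of the added cache-priming job (indentation approximate).
prime-cdn-s3-cache:
  name: Prefetch S3 once to prime the CDN cache
  runs-on: ubuntu-latest
  if: ${{ !github.event.pull_request.draft }}
  timeout-minutes: 10
  steps:
    - name: Restore Cache from S3
      id: hf-cache-restore-s3
      run: |
        curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
```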
.github/workflows/tests.yml (56 changes)

@@ -46,21 +46,32 @@ jobs:
 env:
 SKIP: no-commit-to-branch

+prime-cdn-s3-cache:
+name: Prefetch S3 once to prime the CDN cache
+runs-on: ubuntu-latest
+if: ${{ !github.event.pull_request.draft }}
+timeout-minutes: 10
+steps:
+- name: Restore Cache from S3
+id: hf-cache-restore-s3
+run: |
+curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
+
 pytest:
 name: PyTest
 runs-on: ubuntu-latest
 if: ${{ !github.event.pull_request.draft }}
-# needs: [preload-cache]
+needs: [prime-cdn-s3-cache]
 strategy:
 fail-fast: false
 matrix:
-python_version: ["3.11", "3.12"]
-pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
-exclude:
-- python_version: "3.12"
-pytorch_version: "2.8.0"
-- python_version: "3.12"
-pytorch_version: "2.9.0"
+python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
+pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
+# exclude:
+# - python_version: "3.14"
+# pytorch_version: "2.8.0"
+# - python_version: "3.14"
+# pytorch_version: "2.9.1"
 timeout-minutes: 20

 steps:

@@ -146,17 +157,18 @@ jobs:
 name: PyTest from Source Dist
 runs-on: ubuntu-latest
 if: ${{ !github.event.pull_request.draft }}
+needs: [prime-cdn-s3-cache]
 strategy:
 fail-fast: false
 matrix:
-python_version: ["3.11", "3.12"]
-pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
-exclude:
-- python_version: "3.12"
-pytorch_version: "2.8.0"
-- python_version: "3.12"
-pytorch_version: "2.9.0"
-timeout-minutes: 20
+python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
+pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
+# exclude:
+# - python_version: "3.14"
+# pytorch_version: "2.8.0"
+# - python_version: "3.14"
+# pytorch_version: "2.9.1"
+timeout-minutes: 30

 steps:
 - name: cleanup node

@@ -326,6 +338,12 @@ jobs:
 pytorch: 2.9.1
 num_gpus: 1
 axolotl_extras:
+- cuda: 128
+cuda_version: 12.8.1
+python_version: "3.11"
+pytorch: 2.10.0
+num_gpus: 1
+axolotl_extras:
 - cuda: 130
 cuda_version: 13.0.0
 python_version: "3.11"

@@ -369,9 +387,9 @@ jobs:
 fail-fast: false
 matrix:
 include:
-- cuda: 129
-cuda_version: 12.9.1
-python_version: "3.12"
+- cuda: 128
+cuda_version: 12.8.1
+python_version: "3.11"
 pytorch: 2.9.1
 num_gpus: 1
 axolotl_extras:
@@ -11,7 +11,7 @@ repos:
 - id: no-commit-to-branch
 args: ['--branch', 'main']
 - repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.14.10
+rev: v0.15.4
 hooks:
 - id: ruff
 args: [--fix]

@@ -26,7 +26,7 @@ repos:
 'pydantic>=2.5.3',
 ]
 - repo: https://github.com/PyCQA/bandit
-rev: 1.9.2
+rev: 1.9.4
 hooks:
 - id: bandit
 args: [
README.md (32 changes)

@@ -29,8 +29,23 @@

 ## 🎉 Latest Updates

-- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
+- 2026/03:
+  - New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
+  - [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compatible).
+- 2026/02:
+  - [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support: LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
+  - Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
+- 2026/01:
+  - New integrations for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), which weights the loss by the entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), which improves long-context attention.
+- 2025/12:
+  - Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
+  - [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
 - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
+
+<details>
+
+<summary>Expand older updates</summary>
+
 - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
 - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
 - 2025/07:

@@ -39,15 +54,10 @@
 - FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
 - [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
 - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
-- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
-
-<details>
-
-<summary>Expand older updates</summary>
-
-- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
 - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
+- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
 - 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
+- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
 - 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
 - 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
 - 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!

@@ -62,10 +72,10 @@ Axolotl is a free and open-source tool designed to streamline post-training and…
 Features:

 - **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
-- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
+- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
-- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
+- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
 - **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
-- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
+- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
 - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
 - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
@@ -331,6 +331,7 @@ website:
 - docs/sequence_parallelism.qmd
 - docs/gradient_checkpointing.qmd
 - docs/nd_parallelism.qmd
+- docs/expert_quantization.qmd

 - section: "Troubleshooting"
 contents:
@@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
 RUN uv pip install packaging==26.0 setuptools==75.8.0
 RUN uv pip install torchvision
+RUN uv pip uninstall causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
 uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
 else \
@@ -33,6 +33,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
 fi

 RUN pip install packaging==26.0 setuptools==75.8.0 psutil
+RUN pip uninstall -y causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
 pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
 else \
@@ -3,6 +3,12 @@ set -e

 python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

+# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
+hf download "NousResearch/Meta-Llama-3-8B"
+hf download "NousResearch/Meta-Llama-3-8B-Instruct"
+hf download "microsoft/Phi-4-reasoning"
+hf download "microsoft/Phi-3.5-mini-instruct"
+
 # Run unit tests with initial coverage report
 pytest -v --durations=10 -n8 \
 --ignore=tests/e2e/ \
@@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
 WORKDIR /workspace/axolotl

 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
+RUN pip uninstall -y causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
 BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
 else \
@@ -22,6 +22,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
 WORKDIR /workspace/axolotl

 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
+RUN uv pip uninstall causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
 BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
 else \
docs/expert_quantization.qmd (new file, 67 lines)

---
title: "MoE Expert Quantization"
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
---

Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
weights being loaded in full bf16 precision and causing massive VRAM usage.

`quantize_moe_experts` solves this by quantizing expert weights during model loading.
It intercepts the weight loading process, quantizes each expert tensor on the fly, and
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.

## Usage

Enable expert quantization in your Axolotl config:

```yaml
quantize_moe_experts: true
```

This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.

### Expert LoRA targeting

You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:

```yaml
lora_target_parameters:
  - mlp.experts.gate_up_proj
  - mlp.experts.down_proj
  # - mlp.gate.weight # router
```

::: {.callout-note}
`lora_dropout` must be `0` when using `lora_target_parameters`.
:::
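Putting the pieces together, a minimal end-to-end QLoRA config might look like the sketch below. This is illustrative only: the base model and dataset are placeholders, and the remaining keys simply mirror the example configs added elsewhere in this changeset.

```yaml
# Sketch only — substitute your own MoE base model and dataset.
base_model: zai-org/GLM-4.7-Flash        # placeholder MoE model
load_in_4bit: true                       # QLoRA; use load_in_8bit: true for LoRA instead
quantize_moe_experts: true               # quantize fused expert tensors while loading

adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0                          # must be 0 when using lora_target_parameters
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
lora_target_parameters:                  # optional: adapt the expert weights themselves
  - mlp.experts.gate_up_proj
  - mlp.experts.down_proj

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test   # placeholder dataset
    type: chat_template
output_dir: ./outputs/moe-expert-qlora
```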
## Requirements
|
||||||
|
|
||||||
|
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
|
||||||
|
- CUDA GPUs only (not tested with ROCm or other backends)
|
||||||
|
- FSDP2 compatible for distributed training
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
|
||||||
|
- `cpu_ram_efficient_loading` hangs / takes long time with FSDP2 + QLoRA.
|
||||||
|
- Total model parameter count may display incorrectly (trainable param count is correct).
|
||||||
|
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
|
||||||
|
- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.
|
||||||
|
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
|
||||||
|
- DeepSpeed has not been tested.
|
||||||
|
|
||||||
|
## Implementation details
|
||||||
|
|
||||||
|
The quantization is applied by patching transformers to intercept weight loading.
|
||||||
|
When a 3D+ CUDA tensor with "expert" in its name is detected:
|
||||||
|
|
||||||
|
- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
|
||||||
|
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
|
||||||
|
|
||||||
|
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
|
||||||
|
transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
|
||||||
|
|
||||||
|
For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).
|
||||||
@@ -66,6 +66,15 @@ Provides efficient Triton kernels to improve training speed and reduce memory usage

 - **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)

+### Expert Kernels
+
+Optimized kernel implementations for Mixture of Experts (MoE) model training.
+
+- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
+- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
+
+- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)
+
 ## Long Context Models

 Techniques to train models on sequences longer than their original context window.

@@ -131,3 +140,10 @@ Simulates quantization effects during training, helping the model adapt…
 Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.

 - **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
+
+### MoE Expert Quantization
+
+Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.
+
+- **Config:** `quantize_moe_experts: true`
+- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)
@@ -40,7 +40,7 @@
 "%%capture\n",
 "# This step can take ~5-10 minutes to install dependencies\n",
 "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
+"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
 ]
 },
 {
examples/glm45/README.md (new file, 72 lines)

# Finetune Z.ai's GLM-4.5-Air with Axolotl

[GLM-4.5-Air](https://huggingface.co/zai-org/GLM-4.5-Air) is a MoE model by Z.ai.

This guide shows how to fine-tune it with Axolotl.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

3. Run the finetuning example:

```bash
# QLoRA (1x80GB @ ~63.4GiB/GPU)
axolotl train examples/glm45/glm-45-air-qlora.yaml
```

### Dataset

In addition to the standard OpenAI Messages format, GLM-4.5 supports an extra parameter for thinking in the assistant section.

```json
{
    "role": "assistant",
    "reasoning_content": "...", // or use <think>...</think> in `content`
    "content": "..."
}
```

Make sure you set the extra attributes below if needed:

```yaml
datasets:
  - path: ...
    type: chat_template
    message_property_mappings:
      role: role
      content: content

      # tool_calls: tool_calls # uncomment if using tools
      # reasoning_content: reasoning_content # uncomment if you have reasoning

# Uncomment if training on the tool role (you would rarely if ever need this)
# eot_tokens:
#   - <|observation|>
```

### Tips

- The role name for tools in this template is `tool`.
- You will see the following Axolotl WARNING — this is expected, as the template does not use EOS:
  ```
  EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
  ```
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config.
- **LoRA kernels**: Incompatible with this model. They must be explicitly disabled (`lora_*_kernel: false`).
- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

## Optimization Guides

Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

## Related Resources

- [GLM-4.5-Air on HuggingFace](https://huggingface.co/zai-org/GLM-4.5-Air)
- [GLM-4.5 Blog](https://z.ai/blog/glm-4.5)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
examples/glm45/glm-45-air-qlora.yaml (new file, 64 lines)

base_model: zai-org/GLM-4.5-Air

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: false
load_in_4bit: true

quantize_moe_experts: true # important

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 16
lora_alpha: 8
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
examples/glm47-flash/README.md (new file, 65 lines)

# Finetune Z.ai's GLM-4.7-Flash with Axolotl

[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.

This guide shows how to fine-tune it with Axolotl.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

3. Run the finetuning example:

```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm47-flash/qlora.yaml

# QLoRA FSDP2, no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm47-flash/qlora_fsdp.yaml
```

```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM; projected ~45-50GiB/GPU)
axolotl train examples/glm47-flash/lora.yaml

# LoRA FSDP2, no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm47-flash/lora_fsdp.yaml
```

### MoE Expert Quantization & Expert LoRA

This model quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.

## Limitations

- **lora_target_linear**: Incompatible with this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). They must be explicitly disabled (`lora_*_kernel: false`).

### Tips

- For inference, the Z.ai team recommends these default settings (most tasks):
  - `temperature: 1.0`
  - `top_p: 0.95`
  - `max_new_tokens: 131072`
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested it.
- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

## Optimization Guides

Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

## Related Resources

- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
examples/glm47-flash/lora.yaml (new file, 65 lines)

base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
examples/glm47-flash/lora_fsdp.yaml (new file, 75 lines)

base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true
examples/glm47-flash/qlora.yaml (new file, 65 lines)

base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
examples/glm47-flash/qlora_fsdp.yaml (new file, 75 lines)

base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true
examples/llama-3/3b-qat-mxfp4.yaml (new file, 65 lines)

base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
    split: train[:95%]

output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared

sequence_len: 2048
flash_attention: true

qat:
  activation_dtype: mxfp4
  weight_dtype: mxfp4
  group_size: 32

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_checkpointing: true
activation_offloading: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit

cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0

special_tokens:
  pad_token: <|finetune_right_pad_id|>

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 ## Getting started

-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
-
-Here is an example of how to install from main for pip:
-
-```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
-
-pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
-
-# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
-python scripts/cutcrossentropy_install.py | sh
-```
-
-2. Install Qwen3-Next transformers commit
-```bash
-pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
-```
-
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+
 3. Install FLA for improved performance
 ```bash
-pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
+pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
 ```

 4. Run the finetuning example:

@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
 axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
 ```

-This config uses about 45.62 GiB VRAM.
+This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.

 Let us know how it goes. Happy finetuning! 🚀
@@ -9,6 +9,8 @@ plugins:
 load_in_8bit: false
 load_in_4bit: true

+quantize_moe_experts: true
+
 datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
@@ -25,7 +27,7 @@ sample_packing: true

 lora_r: 16
 lora_alpha: 8
-lora_dropout: 0.05
+lora_dropout: 0
 lora_target_modules:
   - linear_attn.in_proj_ba
   - linear_attn.in_proj_qkvz
@@ -34,12 +36,19 @@ lora_target_modules:
   - shared_expert.down_proj
   - shared_expert.gate_proj
   - shared_expert_gate
-  - mlp.gate
   - q_proj
   - v_proj
   - k_proj
   - o_proj

+# lora_target_parameters:
+#   - mlp.experts.gate_up_proj
+#   - mlp.experts.down_proj
+
+lora_mlp_kernel: false
+lora_qkv_kernel: false
+lora_o_kernel: false
+
 wandb_project:
 wandb_entity:
 wandb_watch:
examples/qwen3.5/122b-a10b-moe-qlora.yaml (new file, 71 lines)
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: Qwen/Qwen3.5-122B-A10B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
#lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
examples/qwen3.5/27b-qlora.yaml (new file, 72 lines)
@@ -0,0 +1,72 @@
|
|||||||
|
base_model: Qwen/Qwen3.5-27B
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
|
||||||
|
# the text-only path. For multimodal (image+text) fine-tuning, add image
|
||||||
|
# columns to your dataset following axolotl's multimodal dataset format.
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
# Uncomment below to also target the linear attention projections.
|
||||||
|
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
|
||||||
|
# - linear_attn.in_proj_qkv
|
||||||
|
# - linear_attn.in_proj_z
|
||||||
|
# - linear_attn.out_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
examples/qwen3.5/35b-a3b-moe-qlora.yaml (new file, 70 lines)
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: Qwen/Qwen3.5-35B-A3B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
#lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_4bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
examples/qwen3.5/7b-lora-vision.yaml (new file, 72 lines)
@@ -0,0 +1,72 @@
|
|||||||
|
base_model: Qwen/Qwen3.5-7B
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
|
||||||
|
# Vision and text tokens are processed together by the same transformer layers.
|
||||||
|
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
|
||||||
|
|
||||||
|
# These 3 lines are required for vision/multimodal training
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: qwen3_5
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
# Targets the language model attention and MLP layers.
|
||||||
|
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
|
||||||
|
# the same transformer stack, so standard attention targets work for both modalities.
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
# Uncomment to also target the linear attention (GatedDeltaNet) projections:
|
||||||
|
# - linear_attn.in_proj_qkv
|
||||||
|
# - linear_attn.in_proj_z
|
||||||
|
# - linear_attn.out_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
examples/qwen3.5/README.md (new file, 61 lines)
@@ -0,0 +1,61 @@
# Finetune Qwen3.5 with Axolotl

[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.

Available configs:

| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:

```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```

> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false`, so FLA is optional there.

4. Run a finetuning example:

```bash
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml

# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml

# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml

# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
```

### TIPS

- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules`; they are commented out by default.

## Optimization Guides

- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)

## Related Resources

- [Qwen3.5 Blog](https://qwenlm.github.io/blog/qwen3.5/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
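The TIPS above suggest sampling settings for inference. Below is a minimal, hedged sketch of generating with those settings via `transformers`; the checkpoint path is a placeholder, and for the vision-language variants the appropriate auto class may be `AutoModelForImageTextToText` rather than `AutoModelForCausalLM`.

```python
# Minimal generation sketch using the sampling settings suggested in TIPS.
# The checkpoint path is a placeholder; point it at your fine-tuned/merged model.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "./outputs/out/merged"  # hypothetical path, not part of the examples
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

messages = [{"role": "user", "content": "Give me a one-line summary of Gated DeltaNet."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(
    input_ids,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.8,
    top_k=20,
    min_p=0.0,
)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```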
@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 1. Install Axolotl from main following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).

-2. Run the finetuning example:
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+
+3. Run the finetuning example:

 ```bash
 axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
 ```

-This config uses about 24.9 GiB VRAM.
+This config uses about 24.9 GiB VRAM (without CCE).

 Let us know how it goes. Happy finetuning! 🚀

@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀

 Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

-## Limitations
-
-**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
-
 ## Related Resources

 - [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
@@ -1,5 +1,4 @@
 base_model: arcee-ai/Trinity-Nano-Preview
-trust_remote_code: true
 revision_of_model: 2ee94b0

 # Automatically upload checkpoint and final model to HF
@@ -12,13 +12,16 @@ packaging==26.0
 huggingface_hub>=1.1.7
 peft>=0.18.1
 tokenizers>=0.22.1
-transformers==5.2.0
+transformers==5.3.0
-accelerate==1.12.0
+accelerate==1.13.0
 datasets==4.5.0
-deepspeed>=0.18.3
+deepspeed>=0.18.6,<0.19.0
-trl==0.28.0
+trl==0.29.0
-hf_xet==1.2.0
+hf_xet==1.3.2
-kernels==0.12.1
+kernels==0.12.2
+
+fla-core==0.4.1
+flash-linear-attention==0.4.1

 trackio>=0.16.1
 typing-extensions>=4.15.0
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
|
||||||
)
|
)
|
||||||
|
|||||||
setup.py (11 lines changed)
@@ -27,9 +27,16 @@ def parse_requirements(extras_require_map):
|
|||||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||||
install_xformers = platform.machine() != "aarch64"
|
install_xformers = platform.machine() != "aarch64"
|
||||||
if platform.machine() == "aarch64":
|
if platform.machine() == "aarch64":
|
||||||
# skip torchao on ARM64
|
# skip on ARM64
|
||||||
|
skip_packages = [
|
||||||
|
"torchao",
|
||||||
|
"fla-core",
|
||||||
|
"flash-linear-attention",
|
||||||
|
]
|
||||||
_install_requires = [
|
_install_requires = [
|
||||||
req for req in _install_requires if "torchao" not in req
|
req
|
||||||
|
for req in _install_requires
|
||||||
|
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||||
]
|
]
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# skip packages not compatible with OSX
|
# skip packages not compatible with OSX
|
||||||
|
|||||||
@@ -6,5 +6,6 @@ from axolotl.logging_config import configure_logging
|
|||||||
|
|
||||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||||
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
|
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
|
||||||
|
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|||||||
merge_lora=True,
|
merge_lora=True,
|
||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
|
quantize_moe_experts=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
context_parallel_size=None,
|
context_parallel_size=None,
|
||||||
deepspeed=None,
|
deepspeed=None,
|
||||||
|
|||||||
@@ -12,10 +12,14 @@ MOE_ARCH_BLOCK = {
|
|||||||
"mixtral": "MixtralSparseMoeBlock",
|
"mixtral": "MixtralSparseMoeBlock",
|
||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
|
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
||||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
"deepseek_v3": "DeepseekV3MoE",
|
"deepseek_v3": "DeepseekV3MoE",
|
||||||
"gpt_oss": "GptOssDecoderLayer",
|
"gpt_oss": "GptOssDecoderLayer",
|
||||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||||
"afmoe": "AfmoeMoE",
|
"afmoe": "AfmoeMoE",
|
||||||
|
"glm4_moe": "Glm4MoeDecoderLayer",
|
||||||
|
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
|
||||||
|
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
|
||||||
}
|
}
|
||||||
|
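The `MOE_ARCH_BLOCK` table in the hunk above maps a HuggingFace `model_type` to the name of its sparse-MoE block (or decoder layer) class. A hedged sketch of how such a mapping can be resolved to the actual class at runtime follows; the helper name `lookup_moe_block` is illustrative, not Axolotl's API, and only a subset of the table is reproduced.

```python
# Hedged sketch: resolve a model_type -> MoE block-class name mapping to the
# real transformers class via a dynamic import. Helper name is hypothetical.
import importlib

MOE_ARCH_BLOCK = {
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "mixtral": "MixtralSparseMoeBlock",
}

def lookup_moe_block(model_type: str):
    """Import and return the MoE block class for a given HF model_type."""
    cls_name = MOE_ARCH_BLOCK[model_type]
    module = importlib.import_module(
        f"transformers.models.{model_type}.modeling_{model_type}"
    )
    return getattr(module, cls_name)

# e.g. lookup_moe_block("qwen3_moe") returns transformers' Qwen3MoeSparseMoeBlock
```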
|||||||
@@ -120,11 +120,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
|
||||||
if self.cfg.max_prompt_len:
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
|
||||||
else:
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl is RLType.SIMPO:
|
if self.cfg.rl is RLType.SIMPO:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from transformers import PreTrainedModel, Trainer
|
|||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||||
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.experimental.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
@@ -720,12 +720,16 @@ class AxolotlTrainer(
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||||
|
|
||||||
# fix for Context Parallel save
|
# fix for Context Parallel save: CP eval invalidates tensor storage
|
||||||
if state_dict is None:
|
# pointers, so clone to CPU to get fresh valid storage for safetensors
|
||||||
state_dict = self.accelerator.get_state_dict(self.model)
|
if (
|
||||||
if state_dict is not None:
|
state_dict is not None
|
||||||
|
and self.axolotl_cfg
|
||||||
|
and self.axolotl_cfg.context_parallel_size
|
||||||
|
and self.axolotl_cfg.context_parallel_size > 1
|
||||||
|
):
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k: v.clone() if isinstance(v, torch.Tensor) else v
|
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
|
||||||
for k, v in state_dict.items()
|
for k, v in state_dict.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -761,7 +765,11 @@ class AxolotlTrainer(
|
|||||||
metadata={"format": "pt"},
|
metadata={"format": "pt"},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
self.model.save_pretrained(
|
||||||
|
output_dir,
|
||||||
|
state_dict=state_dict,
|
||||||
|
is_main_process=self.accelerator.is_main_process,
|
||||||
|
)
|
||||||
|
|
||||||
if self.processing_class is not None:
|
if self.processing_class is not None:
|
||||||
self.processing_class.save_pretrained(output_dir)
|
self.processing_class.save_pretrained(output_dir)
|
||||||
|
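For readers skimming the hunk above: when `context_parallel_size > 1`, the fix copies every tensor in the state dict to fresh CPU storage before safetensors serialization, because context-parallel evaluation can leave the gathered state dict pointing at invalidated storage. A minimal sketch of that idea follows; the helper name and the commented usage lines are hypothetical, not the trainer's API.

```python
# Hedged sketch of the context-parallel save fix: detach and move each tensor
# to CPU so safetensors never sees stale storage pointers from CP eval.
import torch

def state_dict_to_cpu(state_dict: dict) -> dict:
    return {
        k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
        for k, v in state_dict.items()
    }

# safe_state_dict = state_dict_to_cpu(accelerator.get_state_dict(model))
# model.save_pretrained(output_dir, state_dict=safe_state_dict)
```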
|||||||
@@ -25,17 +25,13 @@ class DPOStrategy:
|
|||||||
# Label smoothing is not compatible with IPO
|
# Label smoothing is not compatible with IPO
|
||||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||||
training_args_kwargs["max_completion_length"] = None
|
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
if cfg.dpo_padding_free is not None:
|
if cfg.dpo_padding_free is not None:
|
||||||
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||||
if cfg.dpo_norm_loss is not None:
|
if cfg.dpo_norm_loss is not None:
|
||||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||||
if cfg.dpo_use_logits_to_keep is not None:
|
|
||||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
|
||||||
if cfg.dpo_use_liger_kernel is not None:
|
if cfg.dpo_use_liger_kernel is not None:
|
||||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
@@ -103,10 +103,10 @@ class AxolotlDPOTrainer(
|
|||||||
) -> dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
if self.args.dpo_norm_loss:
|
if self.args.dpo_norm_loss:
|
||||||
# fmt: off
|
# fmt: off
|
||||||
loss_type: str = self.loss_type # type: ignore[has-type]
|
loss_type: list[str] = self.loss_type # type: ignore[has-type]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
# concatenated_forward handles avg token logprob for ipo case already
|
# concatenated_forward handles avg token logprob for ipo case already
|
||||||
self.loss_type = "ipo"
|
self.loss_type = ["ipo"]
|
||||||
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||||
self.loss_type = loss_type
|
self.loss_type = loss_type
|
||||||
return res
|
return res
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
|
|||||||
|
|
||||||
return optimizer_grouped_parameters
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self, model=None):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.embedding_lr_scale is None
|
and self.args.embedding_lr_scale is None
|
||||||
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
|
|||||||
and self.args.lr_groups is None
|
and self.args.lr_groups is None
|
||||||
and self.optimizer_cls_and_kwargs is None
|
and self.optimizer_cls_and_kwargs is None
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(model=model)
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model if model is None else model
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not self.optimizer
|
not self.optimizer
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -88,9 +88,9 @@ plugins:
|
|||||||
- qwen2_vl
|
- qwen2_vl
|
||||||
- qwen3
|
- qwen3
|
||||||
- qwen3_5
|
- qwen3_5
|
||||||
|
- qwen3_5_text
|
||||||
- qwen3_5_moe
|
- qwen3_5_moe
|
||||||
- qwen3_5_moe_vl
|
- qwen3_5_moe_text
|
||||||
- qwen3_5_vl
|
|
||||||
- qwen3_moe
|
- qwen3_moe
|
||||||
- qwen3_next
|
- qwen3_next
|
||||||
- qwen3_vl
|
- qwen3_vl
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ class ExpertsInterface(GeneralInterface):
 }
 ```

-In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
+In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.

 ## Usage

@@ -21,23 +21,57 @@ plugins:
   - axolotl.integrations.kernels.KernelsPlugin

 use_kernels: true

+# Choose one (mutually exclusive):
 use_scattermoe: true
+# OR
+use_sonicmoe: true
 ```

-**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
+**Important:** Setting `experts_implementation` is incompatible with the custom kernel options.
+
+### SonicMoE installation
+
+**Prerequisites:**
+- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
+- CUDA 12.9+ (13.0+ for B300)
+- PyTorch 2.7+ (2.9.1 recommended)
+- For B300: Triton 3.6.0
+
+```bash
+pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
+```
+
+See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.
+
+**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.

 ## How It Works

 The `KernelsPlugin` runs before model loading and:

-1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
+### ScatterMoE
+1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
 2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.

-This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
+### SonicMoE
+1. Resolves the model's MoE block class(es) from `constants.py`.
+2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
+3. Supports both softmax->topk and sigmoid->topk routing strategies.
+
+Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
+
+#### Supported Models
+
+See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).

 ## Limitations

-ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
+ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, etc.). It is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash).
+
+SonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures.

 ## Note on MegaBlocks
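As a rough illustration of the shared resolution path described above, here is a hedged sketch of how `resolve_moe_block_classes` from `constants.py` could be used to swap each resolved block's forward method. `patch_moe_forward` and `my_optimized_forward` are hypothetical names for illustration, not the plugin's actual functions.

```python
# Hedged sketch (not the exact plugin code) of using the shared
# resolve_moe_block_classes helper to patch every MoE block class of a model type.
from axolotl.integrations.kernels.constants import resolve_moe_block_classes

def patch_moe_forward(model_type: str, new_forward) -> None:
    """Replace the forward of each resolved MoE block class with `new_forward`."""
    for moe_cls in resolve_moe_block_classes(model_type):
        moe_cls._original_forward = moe_cls.forward  # keep a handle to the original
        moe_cls.forward = new_forward

# e.g. patch_moe_forward("qwen3_moe", my_optimized_forward)
```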
|||||||
@@ -6,7 +6,18 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class KernelsArgs(BaseModel):
|
class KernelsArgs(BaseModel):
|
||||||
use_scattermoe: bool | None = True
|
use_scattermoe: bool | None = None
|
||||||
|
use_sonicmoe: bool | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_mutually_exclusive(cls, data):
|
||||||
|
if data.get("use_scattermoe") and data.get("use_sonicmoe"):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot use both ScatterMoE and SonicMoE simultaneously. "
|
||||||
|
"Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -36,11 +47,11 @@ class KernelsArgs(BaseModel):
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def disable_mlp_kernel_scattermoe(cls, data):
|
def disable_mlp_kernel(cls, data):
|
||||||
if data.get("use_scattermoe") is True:
|
if data.get("use_scattermoe") is True or data.get("use_sonicmoe") is True:
|
||||||
if data.get("lora_mlp_kernel") is True:
|
if data.get("lora_mlp_kernel") is True:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
|
"Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues."
|
||||||
)
|
)
|
||||||
data["lora_mlp_kernel"] = False
|
data["lora_mlp_kernel"] = False
|
||||||
data["mlp_kernel"] = False
|
data["mlp_kernel"] = False
|
||||||
|
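A hedged sketch of what the new mutual-exclusivity validator above rejects; the import path for `KernelsArgs` is assumed for illustration and may differ in the repo.

```python
# Hedged sketch: the validator added above raises when both kernel flags are set.
# The import path below is an assumption, not confirmed by this diff.
from pydantic import ValidationError

from axolotl.integrations.kernels.args import KernelsArgs  # assumed module path

KernelsArgs(use_scattermoe=True)   # ok
KernelsArgs(use_sonicmoe=True)     # ok
try:
    KernelsArgs(use_scattermoe=True, use_sonicmoe=True)
except ValidationError as err:
    print(err)  # "Cannot use both ScatterMoE and SonicMoE simultaneously. ..."
```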
|||||||
src/axolotl/integrations/kernels/constants.py (new file, 68 lines)
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
Supported MoE block mappings for kernel integrations.
|
||||||
|
|
||||||
|
Maps model_type to the SparseMoeBlock class name(s) in transformers.
|
||||||
|
Used by both ScatterMoE and SonicMoE kernel paths.
|
||||||
|
|
||||||
|
Values can be a single class name (str) or a list of class names for models
|
||||||
|
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
SPARSE_MOE_BLOCK = {
|
||||||
|
# softmax -> topk routing
|
||||||
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
|
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
||||||
|
"qwen3_next": "Qwen3NextSparseMoeBlock",
|
||||||
|
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||||
|
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
|
||||||
|
"qwen3_omni_moe": [
|
||||||
|
"Qwen3OmniMoeThinkerTextSparseMoeBlock",
|
||||||
|
"Qwen3OmniMoeTalkerTextSparseMoeBlock",
|
||||||
|
],
|
||||||
|
"olmoe": "OlmoeSparseMoeBlock",
|
||||||
|
"mixtral": "MixtralSparseMoeBlock",
|
||||||
|
"minimax": "MiniMaxSparseMoeBlock",
|
||||||
|
# sigmoid -> topk routing (with group-based expert selection)
|
||||||
|
"glm_moe_dsa": "GlmMoeDsaMoE",
|
||||||
|
"deepseek_v3": "DeepseekV3MoE",
|
||||||
|
"glm4_moe": "Glm4MoeMoE",
|
||||||
|
"glm4_moe_lite": "Glm4MoeLiteMoE",
|
||||||
|
"glm4v_moe": "Glm4vMoeTextMoE",
|
||||||
|
# sigmoid -> topk routing (no group selection)
|
||||||
|
"minimax_m2": "MiniMaxM2SparseMoeBlock",
|
||||||
|
# Models below need custom routing (not yet implemented):
|
||||||
|
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
|
||||||
|
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
|
||||||
|
# "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing
|
||||||
|
# "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_moe_block_classes(model_type: str):
|
||||||
|
"""Resolve all MoE block classes from transformers for the given model type.
|
||||||
|
|
||||||
|
Returns a list of classes (one for most models, multiple for models with
|
||||||
|
distinct MoE block types like qwen3_omni_moe).
|
||||||
|
"""
|
||||||
|
entry = SPARSE_MOE_BLOCK.get(model_type)
|
||||||
|
if entry is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported MoE model type '{model_type}'. "
|
||||||
|
f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
cls_names = entry if isinstance(entry, list) else [entry]
|
||||||
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
|
||||||
|
classes = []
|
||||||
|
for cls_name in cls_names:
|
||||||
|
moe_cls = getattr(module, cls_name, None)
|
||||||
|
if moe_cls is None:
|
||||||
|
raise ValueError(f"Could not find class '{cls_name}' in '{module_path}'")
|
||||||
|
classes.append(moe_cls)
|
||||||
|
|
||||||
|
return classes
|
||||||
@@ -1,14 +1,59 @@
|
|||||||
|
import importlib
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from kernels import (
|
import torch
|
||||||
LocalLayerRepository,
|
|
||||||
Mode,
|
|
||||||
register_kernel_mapping,
|
|
||||||
replace_kernel_forward_from_hub,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_sonicmoe_gpu_compat():
|
||||||
|
"""Validate GPU compute capability for SonicMoE and configure env.
|
||||||
|
|
||||||
|
Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103).
|
||||||
|
B300 (sm_103) additionally requires Triton 3.6.0.
|
||||||
|
"""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
cc = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
if cc < (9, 0):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, "
|
||||||
|
f"but detected sm_{cc[0]}{cc[1]}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cc > (10, 3):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. "
|
||||||
|
f"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103)."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Blackwell (sm_100+): enable QuACK GEMM kernels
|
||||||
|
if cc >= (10, 0):
|
||||||
|
os.environ.setdefault("USE_QUACK_GEMM", "1")
|
||||||
|
LOG.info(
|
||||||
|
f"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# B300 (sm_103): requires Triton 3.6.0
|
||||||
|
if cc == (10, 3):
|
||||||
|
triton_spec = importlib.util.find_spec("triton")
|
||||||
|
if triton_spec is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"B300 (sm_103) requires Triton 3.6.0, but Triton is not installed."
|
||||||
|
)
|
||||||
|
import triton
|
||||||
|
|
||||||
|
triton_version = tuple(int(x) for x in triton.__version__.split(".")[:2])
|
||||||
|
if triton_version != (3, 6):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KernelsPlugin(BasePlugin):
|
class KernelsPlugin(BasePlugin):
|
||||||
@@ -19,8 +64,32 @@ class KernelsPlugin(BasePlugin):
|
|||||||
if cfg.use_scattermoe:
|
if cfg.use_scattermoe:
|
||||||
self._register_kernels()
|
self._register_kernels()
|
||||||
self._kernelize_model(cfg.model_config_type)
|
self._kernelize_model(cfg.model_config_type)
|
||||||
|
elif cfg.use_sonicmoe:
|
||||||
|
if not importlib.util.find_spec("sonicmoe"):
|
||||||
|
raise RuntimeError(
|
||||||
|
"SonicMoE is not installed. See installation instructions at "
|
||||||
|
"https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation"
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_sonicmoe_gpu_compat()
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
|
||||||
|
)
|
||||||
|
patch_sonicmoe(
|
||||||
|
cfg.model_config_type,
|
||||||
|
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
||||||
|
)
|
||||||
|
|
||||||
def _register_kernels(self):
|
def _register_kernels(self):
|
||||||
|
from kernels import (
|
||||||
|
LocalLayerRepository,
|
||||||
|
Mode,
|
||||||
|
register_kernel_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
plugin_root = Path(__file__).parent
|
plugin_root = Path(__file__).parent
|
||||||
register_kernel_mapping(
|
register_kernel_mapping(
|
||||||
{
|
{
|
||||||
@@ -42,25 +111,11 @@ class KernelsPlugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _kernelize_model(self, model_type: str):
|
def _kernelize_model(self, model_type: str):
|
||||||
if model_type == "olmoe":
|
from kernels import replace_kernel_forward_from_hub
|
||||||
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
|
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
|
||||||
|
|
||||||
|
for model_moe_cls in resolve_moe_block_classes(model_type):
|
||||||
replace_kernel_forward_from_hub(
|
replace_kernel_forward_from_hub(
|
||||||
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
|
model_moe_cls, "HFScatterMoEParallelExperts"
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
try:
|
|
||||||
model_moe_cls = get_model_moe_block(model_type)
|
|
||||||
replace_kernel_forward_from_hub(
|
|
||||||
model_moe_cls, "HFScatterMoEParallelExperts"
|
|
||||||
)
|
|
||||||
except Exception as err:
|
|
||||||
raise ValueError(f"Unsupported model type: {model_type}") from err
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_moe_block(model_type: str):
|
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
|
||||||
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
|
|
||||||
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
|
|
||||||
return model_cls
|
|
||||||
|
|||||||
src/axolotl/integrations/kernels/sonicmoe/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
|
|||||||
|
from .patch import patch_sonicmoe
|
||||||
|
|
||||||
|
__all__ = ["patch_sonicmoe"]
|
||||||
src/axolotl/integrations/kernels/sonicmoe/patch.py (new file, 213 lines)
@@ -0,0 +1,213 @@
|
|||||||
|
"""
|
||||||
|
SonicMoE patching for SparseMoeBlock forward pass.
|
||||||
|
|
||||||
|
Monkeypatches the SparseMoeBlock class for a given model type to use
|
||||||
|
SonicMoE's optimized kernels. Two forward paths are supported:
|
||||||
|
|
||||||
|
1. **General routing path** (routing_fn is not None):
|
||||||
|
Uses a custom routing function + ``moe_general_routing_inputs``.
|
||||||
|
Suitable for models with non-standard routing (softmax->topk, sigmoid->topk).
|
||||||
|
|
||||||
|
2. **Fused topk->softmax path** (routing_fn is None):
|
||||||
|
Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation.
|
||||||
|
Suitable for models with simple topk->softmax routing.
|
||||||
|
|
||||||
|
Weight format conversion (interleave/deinterleave) is handled by the
|
||||||
|
WeightConverter system, so the forward assumes weights are already in
|
||||||
|
interleaved format.
|
||||||
|
|
||||||
|
Shared experts are handled generically: if the block has a ``shared_expert``
|
||||||
|
or ``shared_experts`` attribute, its output is computed alongside the routed
|
||||||
|
experts and added to the final output. An optional ``shared_expert_gate``
|
||||||
|
applies sigmoid gating to the shared expert contribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
|
||||||
|
"""Main entry point: patch SparseMoeBlock for SonicMoE support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The HuggingFace model type (e.g. "qwen3_moe").
|
||||||
|
torch_compile: If True, wrap routing functions with torch.compile
|
||||||
|
for kernel fusion (fuses softmax+topk+renorm into fewer launches).
|
||||||
|
"""
|
||||||
|
from .routing import get_model_moe_config
|
||||||
|
from .weight_converter import register_sonicmoe_weight_converter
|
||||||
|
|
||||||
|
routing_fn, activation, router_attr = get_model_moe_config(model_type)
|
||||||
|
|
||||||
|
if torch_compile and routing_fn is not None:
|
||||||
|
routing_fn = _try_compile_routing(routing_fn)
|
||||||
|
|
||||||
|
for moe_cls in resolve_moe_block_classes(model_type):
|
||||||
|
_patch_forward(moe_cls, routing_fn, activation, router_attr)
|
||||||
|
register_sonicmoe_weight_converter(model_type)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_compile_routing(routing_fn):
|
||||||
|
"""Attempt to torch.compile the routing function, fall back to eager on failure."""
|
||||||
|
try:
|
||||||
|
compiled_fn = torch.compile(routing_fn, mode="reduce-overhead", dynamic=False)
|
||||||
|
LOG.info(f"torch.compile enabled for routing function: {routing_fn.__name__}")
|
||||||
|
return compiled_fn
|
||||||
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
|
LOG.warning(
|
||||||
|
f"torch.compile failed for routing function {routing_fn.__name__}, "
|
||||||
|
f"falling back to eager: {exc}"
|
||||||
|
)
|
||||||
|
return routing_fn
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_forward(moe_cls, routing_fn, activation, router_attr):
|
||||||
|
"""Monkeypatch the SparseMoeBlock class with a SonicMoE forward.
|
||||||
|
|
||||||
|
The patched forward handles shared experts generically: if
|
||||||
|
``self.shared_expert`` or ``self.shared_experts`` exists, it is computed
|
||||||
|
and added to the routed output. If ``self.shared_expert_gate`` also exists,
|
||||||
|
it applies sigmoid gating to the shared expert contribution (as in qwen2_moe).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
moe_cls: The SparseMoeBlock class to patch.
|
||||||
|
routing_fn: Routing function (e.g. softmax_topk_routing), or None
|
||||||
|
for the fused moe_TC_softmax_topk_layer path.
|
||||||
|
activation: SonicMoE ActivationType enum value.
|
||||||
|
router_attr: Name of the router module attribute on the MoE block.
|
||||||
|
"""
|
||||||
|
if hasattr(moe_cls, "_original_forward"):
|
||||||
|
LOG.info(f"{moe_cls.__name__}.forward already patched with SonicMoE, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
original_forward = moe_cls.forward
|
||||||
|
|
||||||
|
if routing_fn is not None:
|
||||||
|
_make_general_forward(moe_cls, routing_fn, activation)
|
||||||
|
else:
|
||||||
|
_make_fused_forward(moe_cls, activation, router_attr)
|
||||||
|
|
||||||
|
moe_cls._original_forward = original_forward
|
||||||
|
LOG.info(f"Patched {moe_cls.__name__}.forward with SonicMoE implementation")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_general_forward(moe_cls, routing_fn, activation):
|
||||||
|
"""Create forward using routing_fn + moe_general_routing_inputs."""
|
||||||
|
|
||||||
|
def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
from sonicmoe import moe_general_routing_inputs
|
||||||
|
|
||||||
|
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||||
|
hidden_states_flat = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
|
# Shared expert (computed early, matching original model ordering)
|
||||||
|
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||||
|
|
||||||
|
# Routing
|
||||||
|
router_scores, token_indices, expert_indices, _router_logits = routing_fn(
|
||||||
|
hidden_states_flat, self
|
||||||
|
)
|
||||||
|
|
||||||
|
# Permute weights to SonicMoE layout:
|
||||||
|
# gate_up: [E, 2*I, H] -> [2*I, H, E]
|
||||||
|
# down: [E, H, I] -> [H, I, E]
|
||||||
|
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
|
||||||
|
down_weight = self.experts.down_proj.permute(1, 2, 0)
|
||||||
|
E = gate_up_weight.shape[-1]
|
||||||
|
|
||||||
|
output, _ = moe_general_routing_inputs(
|
||||||
|
hidden_states_flat,
|
||||||
|
router_scores,
|
||||||
|
token_indices,
|
||||||
|
expert_indices,
|
||||||
|
gate_up_weight,
|
||||||
|
None, # b1 (no gate/up bias)
|
||||||
|
down_weight,
|
||||||
|
None, # b2 (no down bias)
|
||||||
|
E,
|
||||||
|
torch.cuda.current_stream().cuda_stream,
|
||||||
|
activation,
|
||||||
|
False, # is_inference_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add shared expert contribution if present
|
||||||
|
if shared_expert_output is not None:
|
||||||
|
if hasattr(self, "shared_expert_gate"):
|
||||||
|
shared_expert_output = (
|
||||||
|
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
|
||||||
|
* shared_expert_output
|
||||||
|
)
|
||||||
|
output = output + shared_expert_output
|
||||||
|
|
||||||
|
return output.view(batch_size, sequence_length, hidden_dim)
|
||||||
|
|
||||||
|
moe_cls.forward = sonicmoe_forward
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fused_forward(moe_cls, activation, router_attr):
|
||||||
|
"""Create forward using moe_TC_softmax_topk_layer (topk -> softmax)."""
|
||||||
|
|
||||||
|
def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
from sonicmoe import moe_TC_softmax_topk_layer
|
||||||
|
|
||||||
|
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||||
|
hidden_states_flat = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
|
# Shared expert (computed early, matching original model ordering)
|
||||||
|
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||||
|
|
||||||
|
router = getattr(self, router_attr)
|
||||||
|
|
||||||
|
# Permute weights to SonicMoE layout:
|
||||||
|
# gate_up: [E, 2*I, H] -> [2*I, H, E]
|
||||||
|
# down: [E, H, I] -> [H, I, E]
|
||||||
|
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
|
||||||
|
down_weight = self.experts.down_proj.permute(1, 2, 0)
|
||||||
|
|
||||||
|
output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
|
||||||
|
hidden_states_flat,
|
||||||
|
router.weight,
|
||||||
|
gate_up_weight,
|
||||||
|
None, # b1 (no gate/up bias)
|
||||||
|
down_weight,
|
||||||
|
None, # b2 (no down bias)
|
||||||
|
router.top_k,
|
||||||
|
torch.cuda.current_stream().cuda_stream,
|
||||||
|
activation,
|
||||||
|
False, # is_inference_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add shared expert contribution if present
|
||||||
|
if shared_expert_output is not None:
|
||||||
|
if hasattr(self, "shared_expert_gate"):
|
||||||
|
shared_expert_output = (
|
||||||
|
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
|
||||||
|
* shared_expert_output
|
||||||
|
)
|
||||||
|
output = output + shared_expert_output
|
||||||
|
|
||||||
|
return output.view(batch_size, sequence_length, hidden_dim)
|
||||||
|
|
||||||
|
moe_cls.forward = sonicmoe_fused_forward
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_shared_expert(moe_block, hidden_states_flat):
|
||||||
|
"""Compute shared expert output if the block has one.
|
||||||
|
|
||||||
|
Handles singular (qwen2_moe: ``shared_expert``), plural
|
||||||
|
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
|
||||||
|
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
|
||||||
|
"""
|
||||||
|
shared_expert = (
|
||||||
|
getattr(moe_block, "shared_expert", None)
|
||||||
|
or getattr(moe_block, "shared_experts", None)
|
||||||
|
or getattr(moe_block, "shared_mlp", None)
|
||||||
|
)
|
||||||
|
if shared_expert is not None:
|
||||||
|
return shared_expert(hidden_states_flat)
|
||||||
|
return None
|
||||||
src/axolotl/integrations/kernels/sonicmoe/routing.py (new file, 219 lines)
@@ -0,0 +1,219 @@
|
|||||||
|
"""
|
||||||
|
Routing functions for SonicMoE integration.
|
||||||
|
|
||||||
|
Different MoE architectures use different routing strategies:
|
||||||
|
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
|
||||||
|
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
|
||||||
|
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
|
||||||
|
|
||||||
|
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||||
|
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_moe_config(model_type: str):
|
||||||
|
"""Returns (routing_fn, activation, router_attr) for a given model type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: HuggingFace model type string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
routing_fn: Callable or None. None signals the fused
|
||||||
|
moe_TC_softmax_topk_layer path (topk -> softmax models).
|
||||||
|
activation: SonicMoE ActivationType enum value.
|
||||||
|
router_attr: Name of the router module attribute on the MoE block
|
||||||
|
(e.g. "gate" or "router").
|
||||||
|
|
||||||
|
The activation type cannot be derived from config.hidden_act because
|
||||||
|
e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU
|
||||||
|
(act_fn(gate) * up pattern). So we specify it per model type.
|
||||||
|
"""
|
||||||
|
from sonicmoe.enums import ActivationType
|
||||||
|
|
||||||
|
if model_type in (
|
||||||
|
"qwen2_moe",
|
||||||
|
"qwen3_moe",
|
||||||
|
"qwen3_5_moe",
|
||||||
|
"qwen3_next",
|
||||||
|
"qwen3_vl_moe",
|
||||||
|
"qwen3_omni_moe",
|
||||||
|
"olmoe",
|
||||||
|
"mixtral",
|
||||||
|
"minimax",
|
||||||
|
):
|
||||||
|
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
|
||||||
|
elif model_type in (
|
||||||
|
"glm_moe_dsa",
|
||||||
|
"deepseek_v3",
|
||||||
|
"glm4_moe",
|
||||||
|
"glm4_moe_lite",
|
||||||
|
"glm4v_moe",
|
||||||
|
"minimax_m2",
|
||||||
|
):
|
||||||
|
return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
|
||||||
|
# elif model_type in ("ernie4_5_moe",):
|
||||||
|
# # Softmax→topk with e_score_correction_bias applied between softmax and topk.
|
||||||
|
# return ..., ActivationType.SWIGLU, "gate"
|
||||||
|
# elif model_type in ("deepseek_v2",):
|
||||||
|
# # Softmax→topk with group_limited_greedy. Different attr names: num_group
|
||||||
|
# # (not n_group), gate is nn.Linear (not a router class).
|
||||||
|
# return ..., ActivationType.SWIGLU, "gate"
|
||||||
|
# elif model_type in ("hunyuan_v1_moe",):
|
||||||
|
# # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
|
||||||
|
# # top_k on block not gate, creates scatter routing matrix.
|
||||||
|
# return ..., ActivationType.SWIGLU, "gate"
|
||||||
|
# Fused topk -> softmax path (routing_fn=None):
|
||||||
|
# elif model_type in ("gpt_oss",):
|
||||||
|
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
|
||||||
|
# # ignores (it only takes router_w, not bias). Also has transposed
|
||||||
|
# # weight layout [E, H, 2*I] and custom GLU activation.
|
||||||
|
# return None, ActivationType.SWIGLU, "router"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"SonicMoE: unsupported model type '{model_type}'")
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_topk_routing(
|
||||||
|
hidden_states: torch.Tensor, moe_block
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states: [T, H] flattened token representations
|
||||||
|
moe_block: MoE block module (accesses moe_block.gate.*)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
router_scores: [T*K] flattened scores (float32)
|
||||||
|
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
|
||||||
|
expert_indices: [T*K] which expert (int32)
|
||||||
|
router_logits: [T, E] original logits for aux loss
|
||||||
|
"""
|
||||||
|
gate = moe_block.gate
|
||||||
|
T, H = hidden_states.shape
|
||||||
|
K = gate.top_k
|
||||||
|
|
||||||
|
# Compute router logits and softmax over all experts
|
||||||
|
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||||
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||||
|
|
||||||
|
# Select top-k experts per token
|
||||||
|
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
|
||||||
|
|
||||||
|
# Renormalize if configured (default True for models without the attribute,
|
||||||
|
# e.g. Mixtral/MiniMax which always normalize)
|
||||||
|
if getattr(gate, "norm_topk_prob", True):
|
||||||
|
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# no-op: matches transformers which casts to softmax output dtype (float32).
|
||||||
|
# top_values = top_values.to(router_probs.dtype)
|
||||||
|
|
||||||
|
# Flatten for moe_general_routing_inputs.
|
||||||
|
# Token indices are naturally sorted ascending from the [T, K] layout:
|
||||||
|
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
|
||||||
|
# Expert sorting is handled internally by general_routing_router_metadata.
|
||||||
|
token_indices = (
|
||||||
|
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(T, K)
|
||||||
|
)
|
||||||
|
|
||||||
|
flat_scores = top_values.reshape(-1) # [T*K]
|
||||||
|
flat_token_idx = token_indices.reshape(-1) # [T*K]
|
||||||
|
flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||||
|
|
||||||
|
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid_topk_routing(
|
||||||
|
hidden_states: torch.Tensor, moe_block
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Sigmoid-based routing: sigmoid -> optional group selection -> topk.
|
||||||
|
|
||||||
|
Supports two variants:
|
||||||
|
- **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,
|
||||||
|
bias on gate, group-based masking before topk.
|
||||||
|
- **No group selection** (minimax_m2): n_group == 1 (or absent),
|
||||||
|
bias on moe_block, straight topk from all experts.
|
||||||
|
|
||||||
|
Final routing weights come from the original sigmoid scores (not
|
||||||
|
bias-corrected), with optional renormalization and scaling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states: [T, H] flattened token representations
|
||||||
|
moe_block: MoE block module (accesses moe_block.gate.* and
|
||||||
|
optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,
|
||||||
|
.routed_scaling_factor, .n_routed_experts)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
router_scores: [T*K] flattened scores (float32)
|
||||||
|
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
|
||||||
|
expert_indices: [T*K] which expert (int32)
|
||||||
|
router_logits: [T, E] original logits for aux loss
|
||||||
|
"""
|
||||||
|
gate = moe_block.gate
|
||||||
|
T, H = hidden_states.shape
|
||||||
|
K = moe_block.top_k
|
||||||
|
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
|
||||||
|
n_group = getattr(moe_block, "n_group", 1)
|
||||||
|
|
||||||
|
# Compute router logits and sigmoid probabilities
|
||||||
|
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
|
||||||
|
router_probs = router_logits.sigmoid() # [T, E]
|
||||||
|
|
||||||
|
# Bias-corrected scores for expert selection (not used for final weights).
|
||||||
|
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
|
||||||
|
e_score_correction_bias = getattr(gate, "e_score_correction_bias", None)
|
||||||
|
if e_score_correction_bias is None:
|
||||||
|
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
|
||||||
|
if e_score_correction_bias is None:
|
||||||
|
raise AttributeError(
|
||||||
|
f"sigmoid_topk_routing requires e_score_correction_bias on "
|
||||||
|
f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
|
||||||
|
)
|
||||||
|
scores_for_choice = router_probs + e_score_correction_bias
|
||||||
|
|
||||||
|
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
|
||||||
|
if n_group > 1:
|
||||||
|
group_scores = (
|
||||||
|
scores_for_choice.view(-1, n_group, E // n_group)
|
||||||
|
.topk(2, dim=-1)[0]
|
||||||
|
.sum(dim=-1)
|
||||||
|
) # [T, n_group]
|
||||||
|
group_idx = torch.topk(
|
||||||
|
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
|
||||||
|
)[1]
|
||||||
|
group_mask = torch.zeros_like(group_scores)
|
||||||
|
group_mask.scatter_(1, group_idx, 1)
|
||||||
|
score_mask = (
|
||||||
|
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||||
|
)
|
||||||
|
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||||
|
|
||||||
|
# Final topk from (possibly masked) scores
|
||||||
|
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||||
|
|
||||||
|
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||||
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Optional renormalization + scaling
|
||||||
|
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||||
|
if norm_topk_prob:
|
||||||
|
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
||||||
|
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
|
||||||
|
topk_weights = topk_weights * routed_scaling_factor
|
||||||
|
|
||||||
|
# Flatten for moe_general_routing_inputs.
|
||||||
|
# Token indices are naturally sorted ascending from the [T, K] layout.
|
||||||
|
token_indices = (
|
||||||
|
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(T, K)
|
||||||
|
)
|
||||||
|
|
||||||
|
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
|
||||||
|
flat_token_idx = token_indices.reshape(-1) # [T*K]
|
||||||
|
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||||
|
|
||||||
|
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||||
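The shape contract of `softmax_topk_routing` is easiest to see with a toy gate. A minimal, self-contained sketch (the `SimpleNamespace` MoE block and the toy sizes are hypothetical stand-ins, not real model classes; the import path follows the file shown above):

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

T, H, E, K = 6, 32, 8, 2  # tokens, hidden size, experts, top-k (toy sizes)

gate = nn.Linear(H, E, bias=False)      # gate.weight has shape [E, H]
gate.top_k = K
gate.norm_topk_prob = True
moe_block = SimpleNamespace(gate=gate)  # stand-in exposing only what the routing fn reads

hidden = torch.randn(T, H)
scores, token_idx, expert_idx, logits = softmax_topk_routing(hidden, moe_block)

assert scores.shape == (T * K,) and logits.shape == (T, E)
assert token_idx.dtype == torch.int32 and expert_idx.dtype == torch.int32
# token indices arrive sorted ascending, as SonicMoE requires: [0, 0, 1, 1, ..., T-1, T-1]
assert torch.equal(token_idx, torch.arange(T, dtype=torch.int32).repeat_interleave(K))
```

The same flattened (scores, token_indices, expert_indices, logits) contract holds for `sigmoid_topk_routing`; only the score computation differs.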
src/axolotl/integrations/kernels/sonicmoe/weight_converter.py (new file, 181 lines)
@@ -0,0 +1,181 @@
"""
Custom WeightConverter operations for SonicMoE weight format conversion.

SonicMoE requires gate_up_proj weights in interleaved format:
- Standard (concatenated): [E, 2*I, H] where first I rows are gate, last I rows are up
- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...]

These ConversionOps integrate with transformers' WeightConverter system so that
weights are transparently converted during loading and reverted during saving.
"""

from typing import Any

import torch
from einops import rearrange
from transformers.core_model_loading import ConversionOps

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension."""
|
||||||
|
return rearrange(tensor, "... (two out) h -> ... (out two) h", two=2)
|
||||||
|
|
||||||
|
|
||||||
|
def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension."""
|
||||||
|
return rearrange(tensor, "... (out two) h -> ... (two out) h", two=2)
|
||||||
|
|
||||||
|
|
||||||
|
class ConcatenatedToInterleaved(ConversionOps):
|
||||||
|
"""Convert concatenated gate/up projections to interleaved format.
|
||||||
|
|
||||||
|
Input: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
|
||||||
|
Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
|
||||||
|
|
||||||
|
This operation is applied along ``dim`` (default 1, the 2*I dimension).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 1):
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(
|
||||||
|
self,
|
||||||
|
input_dict: dict[str, Any],
|
||||||
|
source_patterns: list[str],
|
||||||
|
target_patterns: list[str],
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
target_pattern = self._get_target_pattern(
|
||||||
|
input_dict, source_patterns, target_patterns
|
||||||
|
)
|
||||||
|
tensors = next(iter(input_dict.values()))
|
||||||
|
tensor = tensors[0] if isinstance(tensors, list) else tensors
|
||||||
|
|
||||||
|
interleaved = interleave_gate_up(tensor)
|
||||||
|
|
||||||
|
return {target_pattern: interleaved}
|
||||||
|
|
||||||
|
def _get_target_pattern(
|
||||||
|
self,
|
||||||
|
input_dict: dict[str, Any],
|
||||||
|
source_patterns: list[str],
|
||||||
|
target_patterns: list[str],
|
||||||
|
) -> str:
|
||||||
|
# Follow the same logic as Transpose.get_target_pattern
|
||||||
|
if len(input_dict) != 1:
|
||||||
|
raise ValueError("Undefined Operation encountered!")
|
||||||
|
if len(target_patterns) > 1:
|
||||||
|
if len(source_patterns) == 1:
|
||||||
|
return source_patterns[0]
|
||||||
|
raise ValueError("Undefined Operation encountered!")
|
||||||
|
return target_patterns[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reverse_op(self) -> ConversionOps:
|
||||||
|
return InterleavedToConcatenated(self.dim)
|
||||||
|
|
||||||
|
|
||||||
|
class InterleavedToConcatenated(ConversionOps):
|
||||||
|
"""Convert interleaved gate/up projections back to concatenated format.
|
||||||
|
|
||||||
|
Input: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
|
||||||
|
Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
|
||||||
|
|
||||||
|
This is the reverse of ``ConcatenatedToInterleaved``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 1):
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert(
|
||||||
|
self,
|
||||||
|
input_dict: dict[str, Any],
|
||||||
|
source_patterns: list[str],
|
||||||
|
target_patterns: list[str],
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
target_pattern = self._get_target_pattern(
|
||||||
|
input_dict, source_patterns, target_patterns
|
||||||
|
)
|
||||||
|
tensors = next(iter(input_dict.values()))
|
||||||
|
tensor = tensors[0] if isinstance(tensors, list) else tensors
|
||||||
|
|
||||||
|
concatenated = deinterleave_gate_up(tensor)
|
||||||
|
|
||||||
|
return {target_pattern: concatenated}
|
||||||
|
|
||||||
|
def _get_target_pattern(
|
||||||
|
self,
|
||||||
|
input_dict: dict[str, Any],
|
||||||
|
source_patterns: list[str],
|
||||||
|
target_patterns: list[str],
|
||||||
|
) -> str:
|
||||||
|
if len(input_dict) != 1:
|
||||||
|
raise ValueError("Undefined Operation encountered!")
|
||||||
|
if len(target_patterns) > 1:
|
||||||
|
if len(source_patterns) == 1:
|
||||||
|
return source_patterns[0]
|
||||||
|
raise ValueError("Undefined Operation encountered!")
|
||||||
|
return target_patterns[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reverse_op(self) -> ConversionOps:
|
||||||
|
return ConcatenatedToInterleaved(self.dim)
|
||||||
|
|
||||||
|
|
||||||
|
def register_sonicmoe_weight_converter(model_type: str):
|
||||||
|
"""Override the conversion mapping to add interleave step for gate_up_proj.
|
||||||
|
|
||||||
|
Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
|
||||||
|
converter chain. For example, qwen3_moe's chain becomes:
|
||||||
|
MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)
|
||||||
|
|
||||||
|
The reverse is auto-generated for saving:
|
||||||
|
InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
|
||||||
|
"""
|
||||||
|
from transformers.conversion_mapping import (
|
||||||
|
get_checkpoint_conversion_mapping,
|
||||||
|
register_checkpoint_conversion_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
existing = get_checkpoint_conversion_mapping(model_type)
|
||||||
|
if existing is None:
|
||||||
|
LOG.warning(
|
||||||
|
f"No conversion mapping found for model type '{model_type}'. "
|
||||||
|
"SonicMoE weight interleaving will not be applied during checkpoint loading."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find the gate_up_proj converter and append ConcatenatedToInterleaved
|
||||||
|
patched = False
|
||||||
|
for converter in existing:
|
||||||
|
if hasattr(converter, "operations") and any(
|
||||||
|
"gate_up_proj" in pat for pat in converter.target_patterns
|
||||||
|
):
|
||||||
|
# Guard against double registration (e.g. plugin reloaded)
|
||||||
|
if any(
|
||||||
|
isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
f"SonicMoE weight converter already registered for '{model_type}'"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
converter.operations.append(ConcatenatedToInterleaved(dim=1))
|
||||||
|
patched = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not patched:
|
||||||
|
LOG.warning(
|
||||||
|
f"Could not find gate_up_proj converter for model type '{model_type}'. "
|
||||||
|
"SonicMoE weight interleaving will not be applied during checkpoint loading."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
|
||||||
|
LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
|
||||||
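The interleaving convention is easy to verify on a tiny tensor. A small round-trip sketch with toy sizes, using the same einops patterns as `interleave_gate_up` / `deinterleave_gate_up`:

```python
import torch
from einops import rearrange

n_experts, n_inter, hidden = 2, 3, 4
gate = torch.arange(n_experts * n_inter * hidden, dtype=torch.float32).reshape(n_experts, n_inter, hidden)
up = gate + 1000.0
concatenated = torch.cat([gate, up], dim=1)  # [E, 2*I, H]: gate rows first, up rows last

# interleave_gate_up: rows become [g0, u0, g1, u1, ...]
interleaved = rearrange(concatenated, "... (two out) h -> ... (out two) h", two=2)
assert torch.equal(interleaved[:, 0], gate[:, 0]) and torch.equal(interleaved[:, 1], up[:, 0])

# deinterleave_gate_up restores the concatenated layout exactly
restored = rearrange(interleaved, "... (out two) h -> ... (two out) h", two=2)
assert torch.equal(restored, concatenated)
```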
@@ -8,9 +8,6 @@ import sys
|
|||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .models.base import patch_lce_forward
|
|
||||||
from .utils import patch_with_compile_disable
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -23,10 +20,18 @@ class LigerPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.liger.LigerArgs"
|
return "axolotl.integrations.liger.LigerArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
|
# shim: liger-kernel 0.7.0 imports ORPOTrainer from old trl path
|
||||||
|
import trl.trainer
|
||||||
|
from trl.experimental.orpo import ORPOTrainer
|
||||||
|
|
||||||
|
trl.trainer.ORPOTrainer = ORPOTrainer
|
||||||
|
|
||||||
if cfg.torch_compile:
|
if cfg.torch_compile:
|
||||||
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
||||||
import liger_kernel.ops.fused_linear_cross_entropy
|
import liger_kernel.ops.fused_linear_cross_entropy
|
||||||
|
|
||||||
|
from .utils import patch_with_compile_disable
|
||||||
|
|
||||||
patch_with_compile_disable(
|
patch_with_compile_disable(
|
||||||
liger_kernel.ops.fused_linear_cross_entropy,
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
"fused_linear_cross_entropy_forward",
|
"fused_linear_cross_entropy_forward",
|
||||||
@@ -35,6 +40,7 @@ class LigerPlugin(BasePlugin):
|
|||||||
liger_kernel.ops.fused_linear_cross_entropy,
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
"fused_linear_cross_entropy_backward",
|
"fused_linear_cross_entropy_backward",
|
||||||
)
|
)
|
||||||
|
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
@@ -192,6 +198,8 @@ class LigerPlugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
elif cfg.liger_fused_linear_cross_entropy:
|
elif cfg.liger_fused_linear_cross_entropy:
|
||||||
try:
|
try:
|
||||||
|
from .models.base import patch_lce_forward
|
||||||
|
|
||||||
patch_lce_forward(cfg.model_config_type)
|
patch_lce_forward(cfg.model_config_type)
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
|
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
|
||||||
|
|||||||
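For context, `patch_with_compile_disable` comes from the plugin's `utils` module; its exact implementation is not shown in this diff. A minimal sketch of what such a helper could look like (an assumption for illustration, not the actual axolotl code) is to rebind the module attribute to a `torch.compiler.disable`-wrapped callable:

```python
import torch

def patch_with_compile_disable(module, fn_name: str) -> None:
    """Hypothetical sketch: wrap module.<fn_name> so torch.compile/Dynamo skips it."""
    original = getattr(module, fn_name)
    setattr(module, fn_name, torch.compiler.disable(original))
```

Applied as in the hunk above, the Triton-backed fused_linear_cross_entropy_forward/backward functions stay outside any torch.compile graph.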
@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
if isinstance(param, Params4bit):
|
if isinstance(param, Params4bit) and param.quant_state is not None:
|
||||||
param.quant_state._orig_to = param.quant_state.to
|
param.quant_state._orig_to = param.quant_state.to
|
||||||
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
||||||
|
|
||||||
|
|||||||
@@ -172,7 +172,10 @@ class ModelLoader:
|
|||||||
# Build the model
|
# Build the model
|
||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
||||||
|
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
|
self.patch_manager.apply_post_model_build_patches(self.model)
|
||||||
|
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|
||||||
# Post-build model configuration
|
# Post-build model configuration
|
||||||
@@ -671,8 +674,8 @@ class ModelLoader:
|
|||||||
del self.model_kwargs["device_map"]
|
del self.model_kwargs["device_map"]
|
||||||
|
|
||||||
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
|
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
|
||||||
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
|
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = lambda: (
|
||||||
lambda: True
|
True
|
||||||
)
|
)
|
||||||
|
|
||||||
return hf_ds_cfg
|
return hf_ds_cfg
|
||||||
@@ -860,6 +863,10 @@ class ModelLoader:
|
|||||||
# Make sure everything is in the same dtype
|
# Make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
|
if getattr(self.model, "_moe_experts_quantized", False):
|
||||||
|
# Parametrized expert tensors dequantize on access — would OOM.
|
||||||
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ class PatchManager:
|
|||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
|
self._apply_moe_expert_quantization_patch()
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
@@ -135,6 +136,10 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_prepare_context_parallel_inputs()
|
patch_prepare_context_parallel_inputs()
|
||||||
|
|
||||||
|
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
||||||
|
"""Apply patches right after model build, before post-load setup."""
|
||||||
|
self._finalize_moe_expert_quantization(model)
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
self._apply_llama_flash_attn_patches(model)
|
self._apply_llama_flash_attn_patches(model)
|
||||||
@@ -161,6 +166,13 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_fsdp_patches(self):
|
def _apply_fsdp_patches(self):
|
||||||
"""Apply patches for FSDP configurations."""
|
"""Apply patches for FSDP configurations."""
|
||||||
|
if self.cfg.fsdp_config:
|
||||||
|
from axolotl.monkeypatch.accelerate.fsdp2 import (
|
||||||
|
patch_initialize_missing_keys_for_fsdp,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_initialize_missing_keys_for_fsdp()
|
||||||
|
|
||||||
if self.cfg.context_parallel_size > 1 or (
|
if self.cfg.context_parallel_size > 1 or (
|
||||||
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
|
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
|
||||||
):
|
):
|
||||||
@@ -170,9 +182,14 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_parallelism_config()
|
patch_parallelism_config()
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
from axolotl.monkeypatch.accelerate.fsdp2 import (
|
||||||
|
patch_accelerate_fsdp2,
|
||||||
|
patch_tied_keys_for_meta_device,
|
||||||
|
)
|
||||||
|
|
||||||
patch_accelerate_fsdp2()
|
patch_accelerate_fsdp2()
|
||||||
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
|
patch_tied_keys_for_meta_device()
|
||||||
if self.cfg.rl:
|
if self.cfg.rl:
|
||||||
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
||||||
|
|
||||||
@@ -229,6 +246,31 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_qwen3_next_modeling_packing()
|
patch_qwen3_next_modeling_packing()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.models.qwen3_5.modeling import (
|
||||||
|
patch_qwen3_5_modeling_packing,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_qwen3_5_modeling_packing()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.models.qwen3_5.modeling import (
|
||||||
|
patch_qwen3_5_moe_modeling_packing,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_qwen3_5_moe_modeling_packing()
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
|
||||||
|
and self.cfg.is_multimodal
|
||||||
|
and self.cfg.flash_attention
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.models.qwen3_5.modeling import (
|
||||||
|
patch_qwen3_5_vlm_flash_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_qwen3_5_vlm_flash_attention()
|
||||||
|
|
||||||
if self.cfg.model_config_type == "kimi_linear":
|
if self.cfg.model_config_type == "kimi_linear":
|
||||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||||
patch_kimi_model,
|
patch_kimi_model,
|
||||||
@@ -352,15 +394,54 @@ class PatchManager:
|
|||||||
if (
|
if (
|
||||||
self.cfg.fsdp_config
|
self.cfg.fsdp_config
|
||||||
and str(self.cfg.fsdp_version) == "2"
|
and str(self.cfg.fsdp_version) == "2"
|
||||||
and self.cfg.adapter == "qlora"
|
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
|
apply_init_dtype_attrs_patch,
|
||||||
apply_init_sharded_param_patch,
|
apply_init_sharded_param_patch,
|
||||||
apply_init_unsharded_param_patch,
|
apply_init_unsharded_param_patch,
|
||||||
|
apply_linear8bitlt_save_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
apply_init_sharded_param_patch()
|
apply_init_sharded_param_patch()
|
||||||
apply_init_unsharded_param_patch()
|
apply_init_unsharded_param_patch()
|
||||||
|
apply_init_dtype_attrs_patch()
|
||||||
|
if self.cfg.load_in_8bit:
|
||||||
|
apply_linear8bitlt_save_patch()
|
||||||
|
|
||||||
|
def _apply_moe_expert_quantization_patch(self):
|
||||||
|
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
|
||||||
|
if not self.cfg.quantize_moe_experts:
|
||||||
|
return
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.moe_quant import (
|
||||||
|
patch_moe_quantization_on_load,
|
||||||
|
patch_peft_target_parameters_matching,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_moe_quantization_on_load(self.cfg)
|
||||||
|
patch_peft_target_parameters_matching()
|
||||||
|
|
||||||
|
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
|
||||||
|
"""Log quantization results and set model flag for downstream use."""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model._moe_experts_quantized = False
|
||||||
|
if self.cfg.quantize_moe_experts:
|
||||||
|
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
|
||||||
|
|
||||||
|
count = get_moe_quantized_count()
|
||||||
|
if count > 0:
|
||||||
|
import gc
|
||||||
|
|
||||||
|
model._moe_experts_quantized = True
|
||||||
|
LOG.info(
|
||||||
|
"Quantized %d MoE expert parameter(s) to %s during model loading",
|
||||||
|
count,
|
||||||
|
"4-bit" if self.cfg.load_in_4bit else "8-bit",
|
||||||
|
)
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
|
|||||||
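Many of the monkeypatches orchestrated here guard against double application with a function or class attribute (see the `_axolotl_patched` / `_axolotl_fsdp2_patched` checks in later hunks). A minimal sketch of that idempotent-patch idiom, with a hypothetical patch name:

```python
def apply_example_patch():
    """Hypothetical guarded patch: safe to call more than once."""
    if getattr(apply_example_patch, "_axolotl_patched", False):
        return  # already applied, e.g. plugin reloaded or PatchManager re-run

    # ... perform the actual monkeypatching here ...

    apply_example_patch._axolotl_patched = True
```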
@@ -201,7 +201,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # nosec B105
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
# Mistral's official FA implementation requires left padding
|
# Mistral's official FA implementation requires left padding
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
|
|||||||
self,
|
self,
|
||||||
save_directory: Union[str, os.PathLike],
|
save_directory: Union[str, os.PathLike],
|
||||||
state_dict: Optional[dict] = None,
|
state_dict: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|||||||
@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
)
|
)
|
||||||
elif self.is_fsdp2:
|
elif self.is_fsdp2:
|
||||||
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
sharded_state_dict = model.state_dict()
|
sharded_state_dict = model.state_dict()
|
||||||
for param_name, param in sharded_state_dict.items():
|
for param_name, param in sharded_state_dict.items():
|
||||||
if param.is_cpu:
|
if param.is_cpu:
|
||||||
param = param.to(torch.device("cuda"))
|
param = param.to(torch.device("cuda"))
|
||||||
|
|
||||||
param = param.full_tensor()
|
if isinstance(param, DTensor):
|
||||||
|
param = param.full_tensor()
|
||||||
|
|
||||||
if torch.distributed.get_rank() == 0:
|
if torch.distributed.get_rank() == 0:
|
||||||
state_dict[param_name] = param.cpu()
|
state_dict[param_name] = param.cpu()
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
@@ -182,10 +186,56 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def patch_peft_param_wrapper_for_fsdp2():
|
||||||
|
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
|
||||||
|
|
||||||
|
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
|
||||||
|
delta_weight to the base weight W inside _LoraParameterProxy.forward().
|
||||||
|
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
|
||||||
|
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
|
||||||
|
|
||||||
|
This patch promotes the non-DTensor operand to match the DTensor's spec
|
||||||
|
using DTensor.from_local(), which is free for Replicate placement (just
|
||||||
|
metadata wrapping, no communication).
|
||||||
|
"""
|
||||||
|
from peft.tuners.lora.layer import _LoraParameterProxy
|
||||||
|
|
||||||
|
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
_original_forward = _LoraParameterProxy.forward
|
||||||
|
|
||||||
|
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
|
||||||
|
def _patched_forward(self, W):
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
delta = self.delta_weight
|
||||||
|
w_is_dt = isinstance(W, DTensor)
|
||||||
|
d_is_dt = isinstance(delta, DTensor)
|
||||||
|
|
||||||
|
with torch.nn.utils.parametrize.cached():
|
||||||
|
if w_is_dt == d_is_dt:
|
||||||
|
return W + delta
|
||||||
|
if w_is_dt:
|
||||||
|
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
|
||||||
|
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
|
||||||
|
|
||||||
|
_LoraParameterProxy.forward = _patched_forward
|
||||||
|
_LoraParameterProxy._axolotl_fsdp2_patched = True
|
||||||
|
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
|
||||||
|
|
||||||
|
|
||||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||||
"""Helper function to process LoRA modules for FSDP2."""
|
"""Helper function to process LoRA modules for FSDP2."""
|
||||||
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
from torch.distributed.fsdp import fully_shard
|
from torch.distributed.fsdp import fully_shard
|
||||||
|
|
||||||
|
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
|
||||||
|
# The parent decoder layer's FSDP wrapper handles unsharding them.
|
||||||
|
# TODO: review if we even need to shard them separately in the first place.
|
||||||
|
if isinstance(module, ParamWrapper):
|
||||||
|
return False
|
||||||
|
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
|
|
||||||
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
|
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
@@ -202,12 +252,20 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
|||||||
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
|
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
|
||||||
if module.lora_B:
|
if module.lora_B:
|
||||||
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
|
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
|
||||||
if module.lora_embedding_A:
|
|
||||||
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
|
|
||||||
if module.lora_embedding_B:
|
|
||||||
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
|
|
||||||
if module.lora_magnitude_vector:
|
if module.lora_magnitude_vector:
|
||||||
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
|
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
|
||||||
|
|
||||||
|
# lora_embedding_A/B are ParameterDicts containing nn.Parameter (Tensors), not
# nn.Module. fully_shard() only accepts nn.Module, so we cannot shard individual
# embedding Parameters. Instead, shard the entire LoraLayer module. fully_shard()
# can be used hierarchically: it does not override groups already assigned by an
# earlier fully_shard() call, so modules that were already sharded are unaffected
# [see https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html]
|
||||||
|
if module.lora_embedding_A or module.lora_embedding_B:
|
||||||
|
from torch.distributed.fsdp import FSDPModule
|
||||||
|
|
||||||
|
if not isinstance(module, FSDPModule):
|
||||||
|
fully_shard(module, **fsdp2_kwargs)
|
||||||
|
|
||||||
return log_bias_dtype_mismatch
|
return log_bias_dtype_mismatch
|
||||||
|
|
||||||
|
|
||||||
@@ -327,6 +385,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
|
|
||||||
is_peft_model = isinstance(model, PeftModel)
|
is_peft_model = isinstance(model, PeftModel)
|
||||||
|
|
||||||
|
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
|
||||||
|
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
|
||||||
|
if is_peft_model:
|
||||||
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
|
|
||||||
|
if any(isinstance(m, ParamWrapper) for m in model.modules()):
|
||||||
|
patch_peft_param_wrapper_for_fsdp2()
|
||||||
|
|
||||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
if auto_wrap_policy is not None:
|
if auto_wrap_policy is not None:
|
||||||
@@ -376,6 +442,83 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def patch_tied_keys_for_meta_device():
|
||||||
|
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
|
||||||
|
|
||||||
|
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
|
||||||
|
grouped as "tied". Skipping them is safe since they have no real storage.
|
||||||
|
"""
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
|
||||||
|
param_pointers = defaultdict(list)
|
||||||
|
for param_name, param_value in self.state_dict().items():
|
||||||
|
if param_value.is_meta:
|
||||||
|
continue
|
||||||
|
param_pointers[param_value.data_ptr()].append(param_name)
|
||||||
|
|
||||||
|
tied_param_names = [
|
||||||
|
names
|
||||||
|
for names in param_pointers.values()
|
||||||
|
if len(names) > 1
|
||||||
|
and not any(name in self.all_tied_weights_keys.keys() for name in names)
|
||||||
|
and not all(name in missing_keys for name in names)
|
||||||
|
]
|
||||||
|
|
||||||
|
tied_weights_keys_by_pointers = {
|
||||||
|
param_name: group[0]
|
||||||
|
for group in tied_param_names
|
||||||
|
for param_name in group[1:]
|
||||||
|
}
|
||||||
|
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
|
||||||
|
|
||||||
|
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
|
||||||
|
_patched_adjust_tied_keys_with_tied_pointers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_initialize_missing_keys_for_fsdp():
|
||||||
|
"""Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0.
|
||||||
|
|
||||||
|
When using cpu_ram_efficient_loading, non-rank-0 processes load weights on
|
||||||
|
meta device and move them to CPU as empty tensors. Without this patch,
|
||||||
|
initialize_weights() re-initializes ALL parameters (via guarded init
|
||||||
|
functions), which is slow and uses extra RAM per process.
|
||||||
|
|
||||||
|
The fix marks all params/buffers with _is_hf_initialized=True before calling
|
||||||
|
the original method, so guarded init functions (init.normal_, init.zeros_,
|
||||||
|
etc.) become no-ops on non-rank-0 processes. The real weights arrive later
|
||||||
|
via FSDP broadcast from rank 0.
|
||||||
|
|
||||||
|
Upstream fix: https://github.com/huggingface/transformers/pull/44473
|
||||||
|
Remove this patch once transformers includes the fix in a stable release.
|
||||||
|
"""
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
|
||||||
|
|
||||||
|
if getattr(PreTrainedModel._initialize_missing_keys, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
_original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
|
||||||
|
|
||||||
|
def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
|
||||||
|
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||||
|
for key in self.state_dict():
|
||||||
|
try:
|
||||||
|
param_or_buffer = self.get_parameter_or_buffer(key)
|
||||||
|
param_or_buffer._is_hf_initialized = True
|
||||||
|
except AttributeError:
|
||||||
|
pass # may happen when handling pre-quantized weights
|
||||||
|
self._is_hf_initialized = True
|
||||||
|
|
||||||
|
_original_initialize_missing_keys(self, is_quantized)
|
||||||
|
|
||||||
|
PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
|
||||||
|
PreTrainedModel._initialize_missing_keys._axolotl_patched = True
|
||||||
|
|
||||||
|
|
||||||
def patch_accelerate_fsdp2():
|
def patch_accelerate_fsdp2():
|
||||||
import accelerate
|
import accelerate
|
||||||
|
|
||||||
|
|||||||
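The meta-device situation that `patch_tied_keys_for_meta_device` and `patch_initialize_missing_keys_for_fsdp` deal with is easy to reproduce. A short illustration (not part of the diff) of parameters created on the meta device, which carry no real storage and are flagged via `is_meta`, the exact check the patched method uses to skip them:

```python
import torch
import torch.nn as nn

# Under cpu_ram_efficient_loading, non-rank-0 processes build the model like this:
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

# No backing storage, so pointer-based tied-weight detection cannot distinguish
# genuinely tied weights; the patched method skips these tensors entirely.
assert all(p.is_meta for p in model.parameters())
```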
@@ -1,9 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
|
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
|
||||||
our LoRA / QLoRA Triton kernels to work with FSDP2.
|
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||||
|
|
||||||
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
|
||||||
Params4bit parameters.
|
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
|
||||||
|
metadata through the FSDP2 shard/unshard cycle.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
def apply_init_sharded_param_patch():
|
def apply_init_sharded_param_patch():
|
||||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||||
|
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
# Get original source
|
# Get original source
|
||||||
@@ -41,9 +44,20 @@ def apply_init_sharded_param_patch():
|
|||||||
bnb_quantized=param.bnb_quantized,
|
bnb_quantized=param.bnb_quantized,
|
||||||
)
|
)
|
||||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||||
|
elif isinstance(param, bnb.nn.modules.Int8Params):
|
||||||
|
self.sharded_param = bnb.nn.modules.Int8Params(
|
||||||
|
data=sharded_param,
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
|
has_fp16_weights=param.has_fp16_weights,
|
||||||
|
CB=None,
|
||||||
|
SCB=param.SCB,
|
||||||
|
)
|
||||||
|
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||||
else:
|
else:
|
||||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
self.sharded_param = nn.Parameter(
|
||||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
self.to_sharded_dtensor(sharded_param),
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
|
)"""
|
||||||
|
|
||||||
# Apply the replacement
|
# Apply the replacement
|
||||||
if original_param_creation in original_source:
|
if original_param_creation in original_source:
|
||||||
@@ -73,6 +87,7 @@ def apply_init_sharded_param_patch():
|
|||||||
|
|
||||||
# Replace the method
|
# Replace the method
|
||||||
FSDPParam._init_sharded_param = patched_init_sharded_param
|
FSDPParam._init_sharded_param = patched_init_sharded_param
|
||||||
|
apply_init_sharded_param_patch._axolotl_patched = True
|
||||||
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
||||||
else:
|
else:
|
||||||
LOG.warning("Could not find target code for _init_sharded_param patching")
|
LOG.warning("Could not find target code for _init_sharded_param patching")
|
||||||
@@ -80,6 +95,8 @@ def apply_init_sharded_param_patch():
|
|||||||
|
|
||||||
def apply_init_unsharded_param_patch():
|
def apply_init_unsharded_param_patch():
|
||||||
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
||||||
|
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
# Get original source
|
# Get original source
|
||||||
@@ -105,6 +122,14 @@ def apply_init_unsharded_param_patch():
|
|||||||
module=local_tensor.module,
|
module=local_tensor.module,
|
||||||
bnb_quantized=local_tensor.bnb_quantized,
|
bnb_quantized=local_tensor.bnb_quantized,
|
||||||
)
|
)
|
||||||
|
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
|
||||||
|
self._unsharded_param = bnb.nn.modules.Int8Params(
|
||||||
|
data=unsharded_param,
|
||||||
|
requires_grad=self.sharded_param.requires_grad,
|
||||||
|
has_fp16_weights=local_tensor.has_fp16_weights,
|
||||||
|
CB=unsharded_param,
|
||||||
|
SCB=local_tensor.SCB,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._unsharded_param = nn.Parameter(
|
self._unsharded_param = nn.Parameter(
|
||||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||||
@@ -138,6 +163,74 @@ def apply_init_unsharded_param_patch():
|
|||||||
|
|
||||||
# Replace the method
|
# Replace the method
|
||||||
FSDPParam.init_unsharded_param = patched_init_unsharded_param
|
FSDPParam.init_unsharded_param = patched_init_unsharded_param
|
||||||
|
apply_init_unsharded_param_patch._axolotl_patched = True
|
||||||
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
||||||
else:
|
else:
|
||||||
LOG.warning("Could not find target code for patching")
|
LOG.warning("Could not find target code for patching")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_linear8bitlt_save_patch():
|
||||||
|
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
|
||||||
|
|
||||||
|
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
|
||||||
|
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
|
||||||
|
doesn't proxy custom attribute access to its _local_tensor. This patch
|
||||||
|
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
|
||||||
|
"""
|
||||||
|
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
|
||||||
|
|
||||||
|
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
|
||||||
|
weight = self._parameters["weight"]
|
||||||
|
unwrapped = False
|
||||||
|
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
|
||||||
|
self._parameters["weight"] = weight._local_tensor
|
||||||
|
unwrapped = True
|
||||||
|
try:
|
||||||
|
original_save(self, destination, prefix, keep_vars)
|
||||||
|
finally:
|
||||||
|
if unwrapped:
|
||||||
|
self._parameters["weight"] = weight
|
||||||
|
|
||||||
|
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
|
||||||
|
apply_linear8bitlt_save_patch._axolotl_patched = True
|
||||||
|
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_init_dtype_attrs_patch():
|
||||||
|
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
|
||||||
|
|
||||||
|
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
|
||||||
|
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
|
||||||
|
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
|
||||||
|
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
|
||||||
|
|
||||||
|
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
|
||||||
|
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
|
||||||
|
without extensions.
|
||||||
|
"""
|
||||||
|
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
|
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
|
||||||
|
|
||||||
|
def patched_init_dtype_attrs(self, mp_policy):
|
||||||
|
original_init_dtype_attrs(self, mp_policy)
|
||||||
|
# Skip casting non-float quantized params (uint8/int8) without FSDP2
|
||||||
|
# extensions — the parametrization chain handles dequantization.
|
||||||
|
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
|
||||||
|
local = self.sharded_param
|
||||||
|
if hasattr(local, "_local_tensor"):
|
||||||
|
local = local._local_tensor
|
||||||
|
if not hasattr(local, "fsdp_pre_all_gather"):
|
||||||
|
self.param_dtype = None
|
||||||
|
|
||||||
|
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
|
||||||
|
apply_init_dtype_attrs_patch._axolotl_patched = True
|
||||||
|
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")
|
||||||
|
|||||||
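Both `apply_init_sharded_param_patch` and `apply_init_unsharded_param_patch` follow the same source-rewriting recipe: fetch the original method source, substitute a fragment, re-exec it, and rebind the method. A condensed sketch of that recipe (a generic helper under assumed names, not the exact axolotl implementation):

```python
import inspect
import textwrap

def patch_method_source(cls, method_name, old_fragment, new_fragment, extra_globals):
    """Rewrite cls.<method_name> by textual substitution on its source, then rebind it."""
    method = getattr(cls, method_name)
    source = textwrap.dedent(inspect.getsource(method))
    if old_fragment not in source:
        return False  # mirrors the "Could not find target code" warning path above
    namespace = {**vars(inspect.getmodule(method)), **extra_globals}
    exec(source.replace(old_fragment, new_fragment), namespace)  # nosec B102
    setattr(cls, method_name, namespace[method_name])
    return True
```

This keeps the rest of the upstream method body intact at the cost of being sensitive to upstream source changes, which is why the real patches log a warning instead of failing hard when the target fragment is missing.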
src/axolotl/monkeypatch/models/qwen3_5/__init__.py (new file, 0 lines)
src/axolotl/monkeypatch/models/qwen3_5/modeling.py (new file, 291 lines)
@@ -0,0 +1,291 @@
|
|||||||
|
"""Monkeypatch for Qwen3_5 and Qwen3_5Moe models to pass position_ids to linear attention."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fla.modules.convolution import (
|
||||||
|
causal_conv1d as fla_causal_conv1d, # FLA >= 0.4.1
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
from fla.modules.conv import causal_conv1d as fla_causal_conv1d # FLA < 0.4.1
|
||||||
|
except ImportError:
|
||||||
|
fla_causal_conv1d = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_cu_seqlens(position_ids):
|
||||||
|
"""
|
||||||
|
Compute cumulative sequence lengths from position_ids for FLA varlen kernels.
|
||||||
|
|
||||||
|
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
|
||||||
|
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
|
||||||
|
|
||||||
|
Qwen3.5 uses MRoPE: position_ids arrive as [axes, B, T]. All axes carry the
|
||||||
|
same temporal positions, so axis 0 is used to recover the [B, T] layout.
|
||||||
|
See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
|
||||||
|
"""
|
||||||
|
if position_ids.ndim == 3:
|
||||||
|
position_ids = position_ids[0]
|
||||||
|
|
||||||
|
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
|
||||||
|
position_ids = position_ids.view(-1)
|
||||||
|
indices_q = (position_ids == 0).nonzero().view(-1)
|
||||||
|
return torch.cat(
|
||||||
|
(
|
||||||
|
indices_q.to(**tensor_kwargs),
|
||||||
|
torch.tensor(position_ids.size(), **tensor_kwargs),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _inject_fla_kernels(module) -> None:
|
||||||
|
"""Inject FLA kernels into a modeling module, bypassing is_flash_linear_attention_available."""
|
||||||
|
try:
|
||||||
|
from fla.modules import FusedRMSNormGated
|
||||||
|
from fla.ops.gated_delta_rule import (
|
||||||
|
chunk_gated_delta_rule,
|
||||||
|
fused_recurrent_gated_delta_rule,
|
||||||
|
)
|
||||||
|
|
||||||
|
module.FusedRMSNormGated = FusedRMSNormGated
|
||||||
|
module.chunk_gated_delta_rule = chunk_gated_delta_rule
|
||||||
|
module.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule
|
||||||
|
module.is_fast_path_available = True
|
||||||
|
except ImportError:
|
||||||
|
module.chunk_gated_delta_rule = None
|
||||||
|
module.fused_recurrent_gated_delta_rule = None
|
||||||
|
module.FusedRMSNormGated = None
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_decoder_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values=None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""Decoder layer forward that passes position_ids through to linear attention."""
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
if self.layer_type == "linear_attention":
|
||||||
|
hidden_states = self.linear_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
cache_params=past_key_values,
|
||||||
|
cache_position=cache_position,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
elif self.layer_type == "full_attention":
|
||||||
|
hidden_states, _ = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits)
|
||||||
|
hidden_states, _ = hidden_states
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def _make_qwen3_5_gated_delta_forward(apply_mask_fn):
|
||||||
|
"""Factory for patched Qwen3_5/Qwen3_5Moe GatedDeltaNet forward with packing support."""
|
    def patched_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params=None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        hidden_states = apply_mask_fn(hidden_states, attention_mask)

        batch_size, seq_len, _ = hidden_states.shape

        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )

        cu_seqlens = None
        if not use_precomputed_states and position_ids is not None:
            cu_seqlens = get_cu_seqlens(position_ids=position_ids)

        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        # mixed_qkv stays [B, T, D]; only transposed inside paths that require [B, D, T]
        mixed_qkv = self.in_proj_qkv(hidden_states)  # [B, T, D]

        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)

        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

        if use_precomputed_states:
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv.transpose(1, 2),
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            ).transpose(1, 2)
        else:
            if cache_params is not None:
                mixed_qkv_t = mixed_qkv.transpose(1, 2)
                cache_params.conv_states[self.layer_idx] = F.pad(
                    mixed_qkv_t,
                    (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
                )

            if fla_causal_conv1d is not None and cu_seqlens is not None:
                # FLA varlen kernel for packed sequences; input must be contiguous [B, T, D]
                mixed_qkv, _ = fla_causal_conv1d(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    cu_seqlens=cu_seqlens,
                )
            else:
                if cu_seqlens is not None and fla_causal_conv1d is None:
                    raise RuntimeError(
                        "Packed sequences require fla.modules.convolution.causal_conv1d "
                        "(cu_seqlens support). Install flash-linear-attention or disable packing."
                    )
                mixed_qkv = F.silu(
                    self.conv1d(mixed_qkv.transpose(1, 2))[:, :, :seq_len]
                ).transpose(1, 2)

        query, key, value = torch.split(
            mixed_qkv,
            [self.key_dim, self.key_dim, self.value_dim],
            dim=-1,
        )
        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

        beta = b.sigmoid()
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        if not use_precomputed_states:
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g.to(dtype=query.dtype),
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
                # torch_chunk_gated_delta_rule fallback does not accept cu_seqlens
                **({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
            )
        else:
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g.to(dtype=query.dtype),
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

        return self.out_proj(core_attn_out)

    return patched_forward


def _apply_packing_patches(model_type: str, cls_prefix: str, forward_factory) -> None:
    module_name = f"transformers.models.{model_type}.modeling_{model_type}"

    try:
        module = importlib.import_module(module_name)
    except ImportError:
        LOG.warning(f"{model_type} not found in transformers, skipping packing patches")
        return

    _inject_fla_kernels(module)
    getattr(module, f"{cls_prefix}DecoderLayer").forward = _patched_decoder_forward
    gated_cls = getattr(module, f"{cls_prefix}GatedDeltaNet")
    gated_cls.forward = forward_factory(module.apply_mask_to_padding_states)

    LOG.info(
        f"Applied {cls_prefix} packing patch "
        f"(fla_causal_conv1d={'available' if fla_causal_conv1d else 'unavailable'})"
    )


def patch_qwen3_5_modeling_packing():
    _apply_packing_patches("qwen3_5", "Qwen3_5", _make_qwen3_5_gated_delta_forward)


def patch_qwen3_5_moe_modeling_packing():
    _apply_packing_patches(
        "qwen3_5_moe", "Qwen3_5Moe", _make_qwen3_5_gated_delta_forward
    )


def patch_qwen3_5_vlm_flash_attention():
    """
    Patch _is_packed_sequence to handle Qwen3.5's 3-D MRoPE position_ids.

    transformers passes position_ids as [axes, B, T] to decoder layers, but
    _is_packed_sequence only handles 2-D tensors and mis-classifies the 3-D
    shape as a packed-sequence indicator, causing CUDA errors in the varlen path.
    """
    try:
        import transformers.modeling_flash_attention_utils as fa_utils

        _original = fa_utils._is_packed_sequence

        def _patched(position_ids, batch_size):
            if position_ids is not None and position_ids.ndim != 2:
                return False
            return _original(position_ids, batch_size)

        fa_utils._is_packed_sequence = _patched
        LOG.info("Applied Qwen3.5 VLM flash-attention patch (3-D MRoPE position_ids)")
    except Exception as exc:  # pragma: no cover
        LOG.warning(f"Failed to apply Qwen3.5 VLM flash-attention patch: {exc}")
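The packed-sequence path above hinges on `get_cu_seqlens` turning packed `position_ids` into the cumulative sequence lengths that the varlen kernels expect. A minimal sketch of that conversion, assuming position_ids restart at 0 for every packed sequence; the function name here is illustrative, not the exact helper in this patch:

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: [B, T], restarting at 0 for each packed sequence.
    flat = position_ids.flatten()
    # Indices where a new sequence begins (position id resets to 0).
    starts = torch.nonzero(flat == 0, as_tuple=True)[0]
    # cu_seqlens = [0, len_1, len_1 + len_2, ..., total]; int32 as varlen kernels expect.
    return torch.cat(
        [starts, torch.tensor([flat.numel()], device=flat.device)]
    ).to(torch.int32)


# Example: two packed sequences of lengths 3 and 2 in one row.
pos = torch.tensor([[0, 1, 2, 0, 1]])
print(cu_seqlens_from_position_ids(pos))  # tensor([0, 3, 5], dtype=torch.int32)
```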
@@ -9,6 +9,11 @@ from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)

+try:
+    from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
+except ImportError:
+    fla_causal_conv1d = None
+

 def get_cu_seqlens(position_ids):
     """
@@ -137,6 +142,11 @@ def patch_qwen3_next_gateddelta_layer():
             and cache_position is not None
         )

+        # Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
+        cu_seqlens = None
+        if not use_precomputed_states and position_ids is not None:
+            cu_seqlens = get_cu_seqlens(position_ids=position_ids)
+
         # getting projected states from cache if it exists
         if cache_params is not None:
             conv_state = cache_params.conv_states[self.layer_idx]
@@ -151,12 +161,11 @@ def patch_qwen3_next_gateddelta_layer():
             x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
         )

-        mixed_qkv = torch.cat((query, key, value), dim=-1)
-        mixed_qkv = mixed_qkv.transpose(1, 2)
+        mixed_qkv = torch.cat((query, key, value), dim=-1)  # [B, T, D]

         if use_precomputed_states:
-            # 2. Convolution sequence transformation
-            # NOTE: the conv state is updated in `causal_conv1d_update`
+            # Inference single-token path: causal_conv1d_update expects [B, D, T]
+            mixed_qkv = mixed_qkv.transpose(1, 2)
             mixed_qkv = self.causal_conv1d_update(
                 mixed_qkv,
                 conv_state,
@@ -164,24 +173,41 @@ def patch_qwen3_next_gateddelta_layer():
                 self.conv1d.bias,
                 self.activation,
             )
+            mixed_qkv = mixed_qkv.transpose(1, 2)
         else:
             if cache_params is not None:
+                # Cache state expects [B, D, T] for the inference update path
+                mixed_qkv_t = mixed_qkv.transpose(1, 2)
                 conv_state = F.pad(
-                    mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
+                    mixed_qkv_t,
+                    (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
                 )
                 cache_params.conv_states[self.layer_idx] = conv_state
-            if self.causal_conv1d_fn is not None:
-                mixed_qkv = self.causal_conv1d_fn(
+            if fla_causal_conv1d is not None:
+                # FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
+                mixed_qkv, _ = fla_causal_conv1d(
                     x=mixed_qkv,
                     weight=self.conv1d.weight.squeeze(1),
                     bias=self.conv1d.bias,
                     activation=self.activation,
-                    seq_idx=None,
+                    cu_seqlens=cu_seqlens,
                 )
             else:
+                # PyTorch fallback (no cu_seqlens support)
+                if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
+                    raise RuntimeError(
+                        "Packed sequences require fla.modules.convolution.causal_conv1d "
+                        "(cu_seqlens support). Install flash-linear-attention or disable packing."
+                    )
+                LOG.warning_once(
+                    "FLA causal_conv1d not available. Falling back to PyTorch conv1d."
+                )
+                mixed_qkv = mixed_qkv.transpose(1, 2)
                 mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
+                mixed_qkv = mixed_qkv.transpose(1, 2)

-        mixed_qkv = mixed_qkv.transpose(1, 2)
+        # mixed_qkv is [B, T, D] in all paths
         query, key, value = torch.split(
             mixed_qkv,
             [
@@ -203,7 +229,6 @@ def patch_qwen3_next_gateddelta_layer():
             key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

         if not use_precomputed_states:
-            cu_seqlens = get_cu_seqlens(position_ids=position_ids)
             core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                 query,
                 key,
188	src/axolotl/monkeypatch/moe_quant.py	Normal file
@@ -0,0 +1,188 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.

In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""

import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

# Module-level state for the loading-time quantization patch.
_moe_load_state = {
    "count": 0,
    "mode": "4bit",
    "quant_type": "nf4",
    "compress_statistics": True,
    "patched": False,
}


class Bnb8bitParametrization(torch.nn.Module):
    """Parametrization that dequantizes int8 row-wise quantized data on access."""

    def __init__(self, row_stats: torch.Tensor):
        super().__init__()
        self.register_buffer("row_stats", row_stats)

    @torch.no_grad()
    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        # Flatten 3D+ to 2D for BnB's dequant, then reshape back.
        orig_shape = quantized_param.shape
        if quantized_param.ndim > 2:
            quantized_param = quantized_param.reshape(-1, orig_shape[-1])
        result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
        return result.reshape(orig_shape)


def _enable_parametrization_cache(module, inputs):
    P._cache_enabled += 1


def _disable_parametrization_cache(module, inputs, output):
    P._cache_enabled -= 1
    if not P._cache_enabled:
        P._cache = {}


def replace_parameter_8bit(module, param_name):
    """Replace a module parameter with an 8-bit quantized version using parametrization."""
    original_param = getattr(module, param_name)
    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
        original_param.data.to(torch.float16)
    )

    setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
    del original_param

    P.register_parametrization(
        module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
    )

    # Cache dequantized values during forward to avoid redundant dequantization.
    if not getattr(module, "_axolotl_8bit_hooks_registered", False):
        module.register_forward_pre_hook(_enable_parametrization_cache)
        module.register_forward_hook(_disable_parametrization_cache)
        module._axolotl_8bit_hooks_registered = True


def patch_moe_quantization_on_load(cfg):
    """Patch transformers' weight loading to quantize MoE expert params on-the-fly.

    Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
    name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
    """
    mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
    _moe_load_state["mode"] = mode
    _moe_load_state["count"] = 0

    if _moe_load_state["patched"]:
        LOG.debug("MoE loading-time quantization patch already active")
        return

    import transformers.core_model_loading
    import transformers.modeling_utils

    if mode == "4bit":
        from bitsandbytes.nn.parametrize import replace_parameter_4bit

        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
        if compress_statistics is None:
            compress_statistics = True

        _moe_load_state["quant_type"] = quant_type
        _moe_load_state["compress_statistics"] = compress_statistics

    # Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
    # size for all params, defeating our on-load quantization VRAM savings.
    def _noop_warmup(*args, **kwargs):
        pass

    transformers.modeling_utils.caching_allocator_warmup = _noop_warmup

    original_set_param = transformers.core_model_loading.set_param_for_module

    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
        original_set_param(model, target_name, param_value, *args, **kwargs)

        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
        if param_value.ndim >= 3 and param_value.is_cuda:
            mod_path, _, pname = target_name.rpartition(".")
            mod = model.get_submodule(mod_path) if mod_path else model
            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
                if "expert" not in target_name.lower():
                    LOG.debug(
                        "Skipping non-expert 3D param: %s (shape=%s)",
                        target_name,
                        list(param_value.shape),
                    )
                    return

                if _moe_load_state["mode"] == "4bit":
                    replace_parameter_4bit(
                        mod,
                        pname,
                        compress_statistics=_moe_load_state["compress_statistics"],
                        quant_type=_moe_load_state["quant_type"],
                    )
                else:
                    replace_parameter_8bit(mod, pname)
                _moe_load_state["count"] += 1

                # Release the bf16 tensor so CUDA memory is freed immediately.
                param_value.data = torch.empty(0, device="cpu")
                torch.cuda.empty_cache()

    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
    _moe_load_state["patched"] = True


def get_moe_quantized_count():
    """Return the number of expert parameters quantized during loading."""
    return _moe_load_state["count"]


def patch_peft_target_parameters_matching():
    """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
    if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
        return
    from peft.tuners.tuners_utils import BaseTuner

    original_inject = BaseTuner._inject_parameters

    def _patched_inject_parameters(
        self, peft_config, model, adapter_name, low_cpu_mem_usage
    ):
        # Patch target_parameters to use full paths for parametrized modules
        original_targets = list(peft_config.target_parameters)
        expanded = set(original_targets)

        for module_name, module in model.named_modules():
            if not hasattr(module, "parametrizations"):
                continue
            for target in original_targets:
                mod_path, _, param_name = target.rpartition(".")
                if (
                    module_name == mod_path or module_name.endswith("." + mod_path)
                ) and hasattr(module, param_name):
                    expanded.add(f"{module_name}.{param_name}")

        peft_config.target_parameters = sorted(expanded)
        try:
            return original_inject(
                self, peft_config, model, adapter_name, low_cpu_mem_usage
            )
        finally:
            peft_config.target_parameters = original_targets

    BaseTuner._inject_parameters = _patched_inject_parameters
    patch_peft_target_parameters_matching._axolotl_patched = True
    LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
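The 8-bit path above relies on `torch.nn.utils.parametrize`: the stored parameter stays int8 and a float view is re-materialized only on access. A self-contained toy illustrating that mechanism with simple scale-based quantization (not the bitsandbytes kernels used in the new module):

```python
import torch
import torch.nn.utils.parametrize as parametrize


class Int8Dequant(torch.nn.Module):
    """Dequantize an int8 tensor back to float on every access."""

    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.register_buffer("scale", scale)

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        return q.float() * self.scale


linear = torch.nn.Linear(4, 4, bias=False)
with torch.no_grad():
    scale = linear.weight.abs().max() / 127.0
    q = torch.round(linear.weight / scale).to(torch.int8)
    # requires_grad=False is mandatory for integer-dtype parameters.
    linear.weight = torch.nn.Parameter(q, requires_grad=False)

# `linear.weight` now returns the dequantized float tensor on access,
# while the stored parameter stays int8.
parametrize.register_parametrization(linear, "weight", Int8Dequant(scale), unsafe=True)
print(linear.weight.dtype)                              # torch.float32
print(linear.parametrizations.weight.original.dtype)    # torch.int8
```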
@@ -22,6 +22,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
+    "qwen3_5",
+    "qwen3_5_moe",
     "falcon",
     "phi",
     "phi3",
@@ -37,6 +39,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "deepseek_v3",
     "glm",
     "glm4",
+    "glm4_moe",
     "smollm3",
     "granite",
     "granitemoe",
@@ -258,6 +258,32 @@ class Qwen2VLProcessingStrategy(ProcessingStrategy):
         )


+class Qwen3_5ProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for Qwen3.5 (early-fusion VLM)"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        chat_template: Optional[str] = None,
+        image_size: int | tuple[int, int] | None = None,
+        image_resize_algorithm: Resampling | None = None,
+    ):
+        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        self.image_token = "<|image_pad|>"  # nosec
+        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
+            self.image_token
+        )
+        self.video_token = "<|video_pad|>"  # nosec
+        self.video_token_id = processor.tokenizer.convert_tokens_to_ids(
+            self.video_token
+        )
+
+    def process_labels(self, input_ids):
+        labels = super().process_labels(input_ids)
+        labels[labels == self.video_token_id] = -100
+        return labels
+
+
 class Gemma3ProcessingStrategy(ProcessingStrategy):
     """Processing Strategy class for Gemma3"""

@@ -562,6 +588,10 @@ def get_processing_strategy(
         return Qwen2VLProcessingStrategy(
             **processing_kwargs,
         )
+    if chat_template_type in ["qwen3_5", "qwen3_5_moe"]:
+        return Qwen3_5ProcessingStrategy(
+            **processing_kwargs,
+        )
     if chat_template_type == "gemma3":
         return Gemma3ProcessingStrategy(
             **processing_kwargs,
@@ -48,9 +48,9 @@ class ChatTemplatePrompter(Prompter):
     ):
         # check if message_property_mappings is None or empty dict
         if message_property_mappings is None or (not message_property_mappings):
-            default_message_property_mappings_keys = ["role", "content", "tool"]
             message_property_mappings = {
-                prop: prop for prop in default_message_property_mappings_keys
+                "role": "role",
+                "content": "content",
             }
             if template_thinking_key and field_thinking:
                 message_property_mappings[template_thinking_key] = field_thinking
@@ -86,9 +86,21 @@ def setup_model_and_tokenizer(
     if model.generation_config is not None:
         model.generation_config.do_sample = True

-    TELEMETRY_MANAGER.send_event(
-        event_type="model-load", properties=model.config.to_dict()
-    )
+    model_properties = model.config.to_dict()
+    try:
+        model_properties["num_parameters"] = model.num_parameters()
+    except Exception:  # pylint: disable=broad-exception-caught
+        model_properties["num_parameters"] = sum(p.numel() for p in model.parameters())
+    # if the num_parameters is less than 2B, let's round to nearest 100M, else round to nearest 1B
+    if model_properties["num_parameters"] < 2e9:
+        model_properties["num_parameters_est"] = (
+            f"{round(model_properties['num_parameters'] / 1e8) * 100}M"
+        )
+    else:
+        model_properties["num_parameters_est"] = (
+            f"{round(model_properties['num_parameters'] / 1e9)}B"
+        )
+    TELEMETRY_MANAGER.send_event(event_type="model-load", properties=model_properties)
     if peft_config:
         TELEMETRY_MANAGER.send_event(
             event_type="peft-config-load", properties=peft_config.to_dict()
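The telemetry change buckets the raw parameter count into a coarse estimate before sending it. A quick worked example of the rounding rule (nearest 100M under 2B, otherwise nearest 1B); the function name is illustrative:

```python
def estimate_param_bucket(num_parameters: int) -> str:
    # Mirrors the rounding rule above: <2B -> nearest 100M, otherwise nearest 1B.
    if num_parameters < 2e9:
        return f"{round(num_parameters / 1e8) * 100}M"
    return f"{round(num_parameters / 1e9)}B"


print(estimate_param_bucket(1_340_000_000))  # 1300M
print(estimate_param_bucket(7_241_000_000))  # 7B
```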
123	src/axolotl/utils/chat_templates/templates/qwen3_5.jinja	Normal file
@@ -0,0 +1,123 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{#- Determine the real last index: use provided value or default to messages length - 1 #}
{%- if real_last_index is defined and real_last_index is not none %}
{%- set ns.real_last_index = real_last_index %}
{%- else %}
{%- set ns.real_last_index = messages|length - 1 %}
{%- endif %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if message['content'] is string %}
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- else %}
{%- if ns.multi_step_tool and message.role == "user" %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' }}
{%- if message['content'] is string %}
{{- message.content }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif content['type'] == 'video' or 'video' in content %}
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in content %}
{{- content['text'] }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "assistant" %}
{%- if message['content'] is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- for item in message['content'] %}
{%- if 'text' in item %}
{%- set content = content + item['text'] %}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is defined and message.reasoning_content is not none %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- else %}
{{- '<think>\n\n' }}
{%- endif %}
{%- endif %}
@@ -6,7 +6,10 @@ from typing import Optional

 import torch
 from transformers.utils import is_torch_bf16_gpu_available
-from transformers.utils.import_utils import is_torch_npu_available
+from transformers.utils.import_utils import (
+    is_torch_greater_or_equal,
+    is_torch_npu_available,
+)

 from axolotl.integrations.base import PluginManager
 from axolotl.integrations.config import merge_input_args
@@ -81,8 +84,15 @@ def resolve_dtype(cfg):
             cfg.fp16 = True
             cfg.bf16 = False
     else:
-        torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
-        torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
+        if cfg.tf32:
+            torch.set_float32_matmul_precision("high")
+            if is_torch_greater_or_equal("2.9.0"):
+                torch.backends.fp32_precision = "tf32"
+                torch.backends.cuda.matmul.fp32_precision = "tf32"
+                torch.backends.cudnn.fp32_precision = "tf32"
+            else:
+                torch.backends.cuda.matmul.allow_tf32 = True
+                torch.backends.cudnn.allow_tf32 = True
         if cfg.bf16:
             cfg.fp16 = False

@@ -119,7 +129,12 @@ def normalize_config(cfg):
     if cfg.world_size != 1:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
         if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
-            cfg.batch_size = cfg.batch_size * cfg.world_size
+            effective_world_size = (
+                cfg.world_size
+                // (cfg.context_parallel_size or 1)
+                // (cfg.tensor_parallel_size or 1)
+            )
+            cfg.batch_size = cfg.batch_size * effective_world_size

     if not cfg.use_ray:
         # delay resolving dtype until on worker node when launching with ray
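The batch-size normalization above now divides out the context- and tensor-parallel group sizes, since ranks within a CP/TP group share one data-parallel replica rather than contributing extra data. A small worked example of the adjusted formula; the function and argument names are illustrative, not the config object itself:

```python
def effective_global_batch_size(
    micro_batch_size: int,
    world_size: int,
    context_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
) -> int:
    # Only world_size / (cp * tp) ranks hold distinct data-parallel replicas.
    effective_world_size = world_size // context_parallel_size // tensor_parallel_size
    return micro_batch_size * effective_world_size


# 8 GPUs with tensor_parallel_size=2 -> 4 data-parallel replicas -> 4 * 4 = 16.
print(effective_global_batch_size(micro_batch_size=4, world_size=8, tensor_parallel_size=2))
```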
@@ -189,7 +189,7 @@ def _get_remote_filesystem(
     try:
         import gcsfs

-        storage_options = {"token": None}  # type: ignore
+        storage_options = {"token": None}  # type: ignore # nosec B105
         return gcsfs.GCSFileSystem(**storage_options), storage_options
     except ImportError as exc:
         raise ImportError(
@@ -5,6 +5,7 @@ Utilities for quantization including QAT and PTQ using torchao.
 import torch
 from packaging import version
 from torchao.core.config import AOBaseConfig
+from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import quantize_
 from torchao.quantization.qat import (
     QATConfig,
@@ -40,6 +41,13 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
     except:
         pass

+try:
+    from torchao.prototype.qat import MXFakeQuantizeConfig
+
+    quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
+except ImportError:
+    pass
+

 def get_quantization_config(
     weight_dtype: TorchAOQuantDType,
@@ -109,6 +117,19 @@ def get_quantization_config(
         if group_size is not None and group_size != 16:
             raise ValueError("NVFP4 quantization must use a group_size of 16")
         return NVFP4InferenceConfig()

+    if weight_dtype == TorchAOQuantDType.mxfp4:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        # MXFP4 uses block_size=32 by default (vs NVFP4's 16)
+        block_size = group_size if group_size is not None else 32
+        if block_size != 32:
+            raise ValueError(
+                "MXFP4 quantization must use a block_size (group_size) of 32"
+            )
+
+        return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
+
     raise ValueError(
         f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
     )
@@ -179,7 +200,13 @@ def prepare_model_for_qat(
         activation_dtype=activation_dtype,
         group_size=group_size,
     )
-    qat_config = QATConfig(base_config)
+    if isinstance(base_config, MXFakeQuantizeConfig):
+        qat_config = QATConfig(
+            activation_config=base_config,
+            weight_config=base_config,
+        )
+    else:
+        qat_config = QATConfig(base_config)
     quantize_(model, qat_config)
     if quantize_embedding:
         # activation fake quantization is not supported for embedding layers
@@ -188,7 +215,12 @@ def prepare_model_for_qat(
             activation_dtype=None,
             group_size=group_size,
         )
-        embedding_qat_config = QATConfig(embedding_base_config)
+        if isinstance(embedding_base_config, MXFakeQuantizeConfig):
+            embedding_qat_config = QATConfig(
+                weight_config=embedding_base_config,
+            )
+        else:
+            embedding_qat_config = QATConfig(embedding_base_config)
         quantize_(
             model,
             embedding_qat_config,
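The MXFP4 branch differs from the other dtypes in that `MXFakeQuantizeConfig` is passed as explicit activation and weight configs rather than wrapped as a single base config. A hedged sketch of that assembly in isolation, assuming the torchao prototype import succeeds as in the guarded import above (APIs are mirrored from the patch, not independently verified):

```python
import torch
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

# MXFP4 uses 32-element blocks; NVFP4 (handled separately above) uses 16.
mx_config = MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=32)

# Fake-quantize both activations and weights during QAT, then apply to the model.
qat_config = QATConfig(activation_config=mx_config, weight_config=mx_config)
# quantize_(model, qat_config)  # `model` would be the prepared nn.Module
```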
@@ -173,7 +173,6 @@ class AxolotlInputConfig(
             "description": "Whether to perform weighting in DPO trainer"
         },
     )
-    dpo_use_logits_to_keep: bool | None = None
     dpo_label_smoothing: float | None = None
     dpo_norm_loss: bool | None = None

@@ -183,7 +182,6 @@ class AxolotlInputConfig(
     )

     dpo_padding_free: bool | None = None
-    dpo_generate_during_eval: bool | None = None

     datasets: (
         Annotated[
@@ -629,6 +627,17 @@ class AxolotlInputConfig(
         },
     )

+    quantize_moe_experts: bool = Field(
+        default=False,
+        json_schema_extra={
+            "description": "Quantize MoE expert weights on load to reduce VRAM. "
+            "Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
+            "Requires CUDA (not compatible with ROCm or other backends). "
+            "Note: total parameter count may be reported incorrectly when enabled "
+            "(trainable param count is correct)."
+        },
+    )
+
     scaling_softmax: bool | None = Field(
         default=None,
         json_schema_extra={
@@ -1289,6 +1298,31 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_quantize_moe_experts(cls, data):
+        if data.get("quantize_moe_experts"):
+            if data.get("lora_target_linear"):
+                raise ValueError(
+                    "lora_target_linear is not compatible with quantize_moe_experts. "
+                    "Use lora_target_parameters to target expert weights instead."
+                )
+            if data.get("adapter") not in ("lora", "qlora"):
+                raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
+            if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
+                raise ValueError(
+                    "quantize_moe_experts requires load_in_4bit or load_in_8bit"
+                )
+            if (
+                data.get("capabilities")
+                and data["capabilities"].get("compute_capability")
+                and not data["capabilities"]["compute_capability"].startswith("sm_")
+            ):
+                raise ValueError(
+                    "quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_auto_enable_lora_kernels(cls, data):
@@ -19,6 +19,8 @@ class DeprecatedParameters(BaseModel):
     evaluation_strategy: str | None = None
     eval_table_size: int | None = None
     eval_max_new_tokens: int | None = None
+    dpo_use_logits_to_keep: bool | None = None
+    dpo_generate_during_eval: bool | None = None

     @field_validator("max_packed_sequence_len")
     @classmethod
@@ -78,6 +80,26 @@ class DeprecatedParameters(BaseModel):
             )
         return eval_max_new_tokens

+    @field_validator("dpo_use_logits_to_keep")
+    @classmethod
+    def validate_dpo_use_logits_to_keep(cls, dpo_use_logits_to_keep):
+        if dpo_use_logits_to_keep is not None:
+            raise DeprecationWarning(
+                "`dpo_use_logits_to_keep` is no longer supported, "
+                "it has been removed in TRL >= 0.29.0"
+            )
+        return dpo_use_logits_to_keep
+
+    @field_validator("dpo_generate_during_eval")
+    @classmethod
+    def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval):
+        if dpo_generate_during_eval is not None:
+            raise DeprecationWarning(
+                "`dpo_generate_during_eval` is no longer supported, "
+                "it has been removed in TRL >= 0.29.0"
+            )
+        return dpo_generate_during_eval
+

 class RemappedParameters(BaseModel):
     """Parameters that have been remapped to other names"""
@@ -10,6 +10,7 @@ class TorchAOQuantDType(Enum):
     int8 = torch.int8
     float8_e4m3fn = torch.float8_e4m3fn
     nvfp4 = "nvfp4"
+    mxfp4 = "mxfp4"

     def from_string(str):
         if str == "int4":
@@ -20,6 +21,8 @@ class TorchAOQuantDType(Enum):
             return TorchAOQuantDType.float8_e4m3fn
         if str == "nvfp4":
             return TorchAOQuantDType.nvfp4
+        if str == "mxfp4":
+            return TorchAOQuantDType.mxfp4


 class RLType(str, Enum):
@@ -56,6 +59,7 @@ class ChatTemplate(str, Enum):
     jinja = "jinja"
     qwen_25 = "qwen_25"
     qwen3 = "qwen3"
+    qwen3_5 = "qwen3_5"
     falcon_h1 = "falcon_h1"
     tokenizer_default = "tokenizer_default"
     exaone = "exaone"
@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
             data["lora_dropout"] = 0.0
         return data

+    @model_validator(mode="after")
+    def validate_lora_target_parameters_dropout(self):
+        if (
+            self.lora_target_parameters
+            and self.lora_dropout
+            and self.lora_dropout != 0.0
+        ):
+            raise ValueError(
+                "lora_dropout must be 0 when lora_target_parameters is set. "
+                "PEFT's ParamWrapper does not support lora_dropout != 0."
+            )
+        return self
+

 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
@@ -20,6 +20,9 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
         return TorchAOQuantDType.float8_e4m3fn
     if v == "nvfp4":
         return TorchAOQuantDType.nvfp4
+    if v == "mxfp4":
+        return TorchAOQuantDType.mxfp4
+
     raise ValueError(
         f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
     )
@@ -986,23 +986,6 @@ class OptimizationValidationMixin:

         return self

-    @model_validator(mode="after")
-    def lr_groups_ao_optimizer(self):
-        if (
-            self.loraplus_lr_ratio is not None
-            or self.embedding_lr_scale is not None
-            or self.embedding_lr is not None
-            or self.lr_groups is not None
-        ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
-            # TODO(wing): remove this once ao>0.12.0
-            # requires https://github.com/pytorch/ao/pull/2606 in an ao release
-            raise ValueError(
-                "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
-                "supported with ao low-bit optimizers until ao>0.12.0. "
-                "Please refer to https://github.com/pytorch/ao/pull/2606."
-            )
-        return self
-
     @model_validator(mode="before")
     @classmethod
     def check_tensor_parallel_size_update_ds_json(cls, data):
@@ -457,8 +457,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 - 1
             )
             * cfg.num_epochs
-            * cfg.context_parallel_size
-            * cfg.tensor_parallel_size
         )
         LOG.debug(
             f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -497,14 +495,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             LOG.debug(f"data_loader_len: {data_loader_len}")
             # FIXME: is there a bug here somewhere? the total num steps depends
             # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(
-                math.floor(
-                    data_loader_len
-                    * cfg.num_epochs
-                    * cfg.context_parallel_size
-                    * cfg.tensor_parallel_size
-                )
-            )
+            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
             if cfg.dataloader_drop_last:
                 # drop the last batch for each epoch
                 total_num_steps -= int(math.ceil(cfg.num_epochs))
@@ -525,13 +516,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
     else:
         total_num_steps = int(
-            math.ceil(
-                len(train_dataset)
-                * cfg.num_epochs
-                * cfg.context_parallel_size
-                * cfg.tensor_parallel_size
-                / cfg.batch_size
-            )
+            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )
     LOG.debug(f"total_num_steps: {total_num_steps}")
     return total_num_steps
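With the CP/TP factors dropped, the non-packing estimate reduces to dataset length times epochs over the global batch size. A quick numeric check of that simplified formula; the standalone function is illustrative only:

```python
import math


def estimated_total_num_steps(dataset_len: int, num_epochs: float, batch_size: int) -> int:
    # Simplified estimate used when sample packing is disabled.
    return int(math.ceil(dataset_len * num_epochs / batch_size))


print(estimated_total_num_steps(dataset_len=10_000, num_epochs=3, batch_size=64))  # 469
```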
@@ -2,7 +2,7 @@

 import sys
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import pytest

@@ -94,7 +94,6 @@ def fixture_dpo_cfg(base_cfg):
         {
             "rl": RLType.DPO,
             "dpo_use_weighting": True,
-            "dpo_use_logits_to_keep": True,
             "dpo_label_smoothing": 0.1,
             "beta": 0.1,  # DPO beta
         }
@@ -148,9 +147,16 @@ def fixture_grpo_cfg(base_cfg):
             ),
             # Must be evenly divisible by num_generations
             "micro_batch_size": 4,
+            "datasets": [
+                {
+                    "path": "openai/gsm8k",
+                    "name": "main",
+                    "split": "train[:1%]",
+                }
+            ],
         }
     )
-    return cfg
+    return DictDefault(cfg)


 @pytest.fixture(name="ipo_cfg")
@@ -334,6 +340,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
         try:
             builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
             training_arguments, _ = builder._build_training_arguments(100)
+            builder.train_dataset = MagicMock()

             self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
             # GRPO specific
@@ -363,7 +370,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
         self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
         # IPO specific
         assert training_arguments.beta == 0.1
-        assert training_arguments.loss_type == "ipo"
+        assert training_arguments.loss_type == ["ipo"]
         assert training_arguments.label_smoothing == 0

     def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
@@ -529,13 +536,11 @@ class TestHFCausalTrainerBuilder:
         "cfg_string",
         [
             "sft_cfg",
-            "rm_cfg",
+            # "rm_cfg",  # TODO fix for num_labels = 2 vs 1
             "prm_cfg",
         ],
     )
-    def test_custom_optimizer_cls_and_kwargs(
-        self, request, cfg_string, model, tokenizer
-    ):
+    def test_builder_w_rm_trainers(self, request, cfg_string, model, tokenizer):
         cfg = request.getfixturevalue(cfg_string)
         builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
         cfg["optimizer"] = "muon"
288	tests/e2e/integrations/test_sonicmoe.py	Normal file
@@ -0,0 +1,288 @@
"""
End-to-end gradient and convergence tests for SonicMoE integration.

Requires:
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
- sonicmoe package installed
- transformers with Qwen3MoE support

Usage:
    pytest tests/e2e/integrations/test_sonicmoe.py -v -s
"""

import importlib.util
import math

import pytest
import torch

_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)

pytestmark = [
    pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
    pytest.mark.skipif(
        not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
    ),
    pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
]


def _create_tiny_qwen3_config():
    """Create a minimal Qwen3MoE config for fast testing."""
    from transformers import AutoConfig

    config = AutoConfig.for_model("qwen3_moe")
    config.hidden_size = 512
    config.intermediate_size = 1024
    config.moe_intermediate_size = 64
    config.num_attention_heads = 16
    config.num_key_value_heads = 2
    config.head_dim = 32
    config.num_hidden_layers = 2
    config.num_experts = 8
    config.num_experts_per_tok = 2
    config.vocab_size = 1000
    config.max_position_embeddings = 128
    config.norm_topk_prob = True
    config.torch_dtype = torch.bfloat16
    return config


def _interleave_gate_up_weights(model):
    """Interleave all gate_up_proj parameters in-place for SonicMoE."""
    from axolotl.integrations.kernels.sonicmoe.weight_converter import (
        interleave_gate_up,
    )

    with torch.no_grad():
        for name, param in model.named_parameters():
            if "gate_up_proj" in name:
                param.copy_(interleave_gate_up(param))


def _unpatch_sonicmoe():
    """Restore original forward on the MoE block class if it was patched."""
    from axolotl.integrations.kernels.constants import resolve_moe_block_classes

    for moe_cls in resolve_moe_block_classes("qwen3_moe"):
        if hasattr(moe_cls, "_original_forward"):
            moe_cls.forward = moe_cls._original_forward
            del moe_cls._original_forward


class TestSonicMoEForwardCorrectness:
    """Verify SonicMoE-patched model produces same output as original."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_forward_output_matches(self):
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # Original model
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()

        with torch.no_grad():
            out_orig = model_orig(input_ids)

        # Patched model (same weights, interleaved for SonicMoE)
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())

        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)

        with torch.no_grad():
            out_patched = model_patched(input_ids)

        max_diff = (out_orig.logits - out_patched.logits).abs().max().item()
        assert torch.allclose(
            out_orig.logits, out_patched.logits, atol=1e-1, rtol=1e-1
        ), f"Output mismatch: max diff={max_diff:.6f}"


class TestSonicMoEGradientCorrectness:
    """Compare gradients between original HuggingFace and SonicMoE-patched forward."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_gradients_match(self):
        """Verify all parameter gradients match between original and patched."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
        from axolotl.integrations.kernels.sonicmoe.weight_converter import (
            deinterleave_gate_up,
        )

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # ---------- Original model ----------
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        out_orig = model_orig(input_ids, labels=input_ids)
        out_orig.loss.backward()
        grads_orig = {
            n: p.grad.float().clone()
            for n, p in model_orig.named_parameters()
            if p.grad is not None
        }
        loss_orig = out_orig.loss.item()

        # ---------- SonicMoE-patched model (same weights, interleaved) ----------
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())

        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)

        out_patched = model_patched(input_ids, labels=input_ids)
        out_patched.loss.backward()
        grads_patched = {}
        for n, p in model_patched.named_parameters():
            if p.grad is None:
                continue
            g = p.grad.float().clone()
            # gate_up_proj grads are in interleaved layout, de-interleave to match orig
            if "gate_up_proj" in n:
                g = deinterleave_gate_up(g)
            grads_patched[n] = g
        loss_patched = out_patched.loss.item()

        # ---------- Compare ----------
        assert abs(loss_orig - loss_patched) < 0.5, (
            f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}"
        )

        # All parameters with gradients in original should have them in patched
        missing = set(grads_orig.keys()) - set(grads_patched.keys())
        assert not missing, f"Missing gradients in patched model: {missing}"

        # Compare gradient values
        # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge,
        # so use generous tolerance: flag only if both rel >10% AND abs >1e-2
        mismatches = []
        for name in grads_orig:
            if name not in grads_patched:
                continue
            g_orig = grads_orig[name]
            g_patched = grads_patched[name]
            max_diff = (g_orig - g_patched).abs().max().item()
            rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8)

            if rel_diff > 0.1 and max_diff > 1e-2:
                mismatches.append(
                    f"  {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}"
                )

        assert not mismatches, (
            "Gradient mismatches (rel_diff > 10% and abs_diff > 1e-2):\n"
            + "\n".join(mismatches)
        )

    def test_router_weights_receive_gradients(self):
        """Verify that router (gate) weights get non-zero gradients."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        out = model(input_ids, labels=input_ids)
        out.loss.backward()

        gate_grads_found = False
        for name, param in model.named_parameters():
            if "gate" in name and "weight" in name:
                gate_grads_found = True
                assert param.grad is not None, f"No gradient for router: {name}"
                assert param.grad.abs().max() > 0, f"Zero gradient for router: {name}"

        assert gate_grads_found, "No gate.weight parameters found in model"


class TestSonicMoETrainingConvergence:
    """Verify loss decreases during training with SonicMoE."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_loss_decreases(self):
        """Run 30 training steps, verify loss decreases and no NaN/Inf."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")

        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        losses = []

        for step in range(30):
            out = model(input_ids, labels=input_ids)
            loss = out.loss
            assert not math.isnan(loss.item()), f"NaN loss at step {step}"
            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
            losses.append(loss.item())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        assert losses[-1] < losses[0], (
            f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
        )

    def test_expert_weights_update(self):
        """Verify expert weights change during training (not frozen)."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")

        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        # Snapshot expert weights before training
        expert_weights_before = {}
        for name, param in model.named_parameters():
            if "experts" in name:
                expert_weights_before[name] = param.data.clone()

        assert expert_weights_before, "No expert parameters found"
|
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
|
||||||
|
for _ in range(5):
|
||||||
|
out = model(input_ids, labels=input_ids)
|
||||||
|
out.loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Check that expert weights changed
|
||||||
|
changed = 0
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if name in expert_weights_before:
|
||||||
|
if not torch.equal(param.data, expert_weights_before[name]):
|
||||||
|
changed += 1
|
||||||
|
|
||||||
|
assert changed > 0, "No expert weights changed after 5 training steps"
|
||||||
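The gradient-comparison test in this file flags a parameter only when its error is large both relatively and absolutely. Distilled into a standalone helper (a sketch mirroring the test's own arithmetic, not part of the integration):

```python
import torch

def grads_disagree(g_ref: torch.Tensor, g_test: torch.Tensor) -> bool:
    """Dual-threshold check: flag only when rel_diff > 10% AND abs_diff > 1e-2."""
    max_abs = (g_ref - g_test).abs().max().item()
    rel = max_abs / (g_ref.abs().max().item() + 1e-8)
    return rel > 0.1 and max_abs > 1e-2
```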
@@ -8,6 +8,8 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.schemas.enums import TorchAOQuantDType
+from axolotl.utils.schemas.quantization import QATConfig, validate_ao_dtype

 from .utils import check_model_output_exists, check_tensorboard

@@ -130,3 +132,32 @@ class TestQATLlama:
             loss_threshold,
             "Train Loss (%s) is too high",
         )
+
+
+class TestMXFP4Schema:
+    """Test MXFP4 schema validation"""
+
+    def test_validate_mxfp4_dtype(self):
+        result = validate_ao_dtype("mxfp4")
+        assert result == TorchAOQuantDType.mxfp4
+
+    def test_qat_config_with_mxfp4(self):
+        """Test QATConfig accepts mxfp4 weight_dtype"""
+        config = QATConfig(
+            weight_dtype="mxfp4",
+            group_size=32,
+            quantize_embedding=False,
+        )
+        assert config.weight_dtype == TorchAOQuantDType.mxfp4
+        assert config.group_size == 32
+
+    def test_qat_config_mxfp4_invalid_group_size(self):
+        """Test that invalid group_size raises appropriate error during quantization"""
+        # Note: Schema validation doesn't check group_size compatibility,
+        # that happens in get_quantization_config
+        config = QATConfig(
+            weight_dtype="mxfp4",
+            group_size=16,  # Invalid for mxfp4, but schema allows it
+        )
+        assert config.group_size == 16  # Schema accepts it
+        # Actual validation happens at runtime in get_quantization_config
@@ -5,6 +5,7 @@ Tests for axolotl.utils.quantization
 import pytest
 import torch
 from torch import nn
+from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
 from torchao.quantization.qat.linear import FakeQuantizedLinear
@@ -117,6 +118,21 @@ class TestQuantization:
         config = get_quantization_config(weight_dtype, activation_dtype, group_size)
         assert isinstance(config, expected_type)
+
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_get_ptq_config_mxfp4(self):
+        config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
+        assert isinstance(config, MXFakeQuantizeConfig)
+        assert config.block_size == 32
+
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_get_ptq_config_mxfp4_invalid_group_size(self):
+        with pytest.raises(
+            ValueError, match="MXFP4 quantization must use a block_size"
+        ):
+            get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)

     @requires_cuda_ge_8_9
     @require_torch_2_8_0
     def test_get_ptq_config_int4_weight_only(self):
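For readers skimming the two MXFP4 tests above: the configured group_size maps onto the torchao block size, and the only value the runtime guard accepts here is 32. A minimal sketch of such a guard, assuming only the error message the test matches against (not axolotl's actual function):

```python
def check_mxfp4_group_size(group_size: int) -> int:
    """Hypothetical guard matching the behaviour asserted above: 32 passes, 16 raises."""
    if group_size != 32:
        raise ValueError(
            f"MXFP4 quantization must use a block_size of 32, got {group_size}"
        )
    return group_size
```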
@@ -262,6 +278,35 @@ class TestQuantization:
             else:
                 assert child.activation_fake_quantizer is None
+
+    @pytest.mark.parametrize(
+        "weight_dtype,activation_dtype,group_size,quantize_embedding",
+        [
+            (TorchAOQuantDType.mxfp4, None, 32, False),
+            (TorchAOQuantDType.mxfp4, None, 32, True),
+        ],
+    )
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_prepare_model_for_qat_mxfp4(
+        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
+    ):
+        prepare_model_for_qat(
+            model,
+            weight_dtype,
+            group_size,
+            activation_dtype,
+            quantize_embedding,
+        )
+
+        if quantize_embedding:
+            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
+            assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
+
+        for child in list(model.children()):
+            if isinstance(child, torch.nn.Linear):
+                assert isinstance(child, FakeQuantizedLinear)
+                assert hasattr(child, "weight_fake_quantizer")

     @require_torch_2_8_0
     @requires_cuda_ge_8_9
     def test_convert_qat_model(self, model):
@@ -180,6 +180,7 @@ def check_tensorboard(
     lt_val: float,
     assertion_err: str,
     rtol: float = 0.02,
+    gt_zero: bool = True,
 ) -> None:
     """
     helper function to parse and check tensorboard logs
@@ -194,6 +195,8 @@ def check_tensorboard(
         assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
     else:
         assert df.value.values[-1] < lt_val, assertion_err
+    if gt_zero:
+        assert df.value.values[-1] > 1e-5, "Expected loss to be greater than zero"


 def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
@@ -6,7 +6,7 @@
 Unit tests for scattermoe-lora code-review fixes.

 Tests cover:
-- KernelsArgs validator: disable_mlp_kernel_scattermoe
+- KernelsArgs validator: disable_mlp_kernel
 - CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
 - ParallelExperts: scaling=0.0 not treated as falsy
 - single2scatter: non-aligned K/N dimensions
@@ -20,12 +20,12 @@ import pytest
 import torch

 # ============================================================================
-# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
+# 1. KernelsArgs: disable_mlp_kernel validator
 # ============================================================================


 class TestKernelsArgsValidator:
-    """Test that disable_mlp_kernel_scattermoe sets both flags correctly.
+    """Test that disable_mlp_kernel sets both flags correctly.

     These tests call the validator classmethod directly on raw dicts,
     since lora_mlp_kernel / mlp_kernel are not declared model fields.
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator:
             "use_scattermoe": True,
             "lora_mlp_kernel": True,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["lora_mlp_kernel"] is False
         assert result["mlp_kernel"] is False

@@ -52,7 +52,7 @@ class TestKernelsArgsValidator:
             "use_kernels": True,
             "use_scattermoe": True,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["mlp_kernel"] is False
         # lora_mlp_kernel was not in data, should not be added
         assert "lora_mlp_kernel" not in result
@@ -66,7 +66,7 @@ class TestKernelsArgsValidator:
             "use_scattermoe": True,
             "lora_mlp_kernel": False,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["lora_mlp_kernel"] is False

     def test_no_change_when_scattermoe_disabled(self):
@@ -78,7 +78,7 @@ class TestKernelsArgsValidator:
             "use_scattermoe": False,
             "lora_mlp_kernel": True,
         }
-        result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
+        result = KernelsArgs.disable_mlp_kernel(data)
         assert result["lora_mlp_kernel"] is True

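Taken together, the cases above pin down the validator's contract. A minimal sketch consistent with those assertions (hypothetical code, not the plugin's implementation):

```python
def disable_mlp_kernel(data: dict) -> dict:
    """Force fused-MLP kernels off when a MoE kernel (scattermoe/sonicmoe) is enabled."""
    if data.get("use_scattermoe") or data.get("use_sonicmoe"):
        data["mlp_kernel"] = False        # always forced off, added if missing
        if "lora_mlp_kernel" in data:     # only overridden when the key is present
            data["lora_mlp_kernel"] = False
    return data
```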
428 tests/integrations/test_sonicmoe.py (Normal file)
@@ -0,0 +1,428 @@
|
|||||||
|
"""Unit tests for the SonicMoE integration."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.args import KernelsArgs
|
||||||
|
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||||
|
sigmoid_topk_routing,
|
||||||
|
softmax_topk_routing,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
|
||||||
|
ConcatenatedToInterleaved,
|
||||||
|
InterleavedToConcatenated,
|
||||||
|
register_sonicmoe_weight_converter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestKernelsArgs:
|
||||||
|
def test_mutual_exclusivity_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot use both"):
|
||||||
|
KernelsArgs.model_validate({"use_scattermoe": True, "use_sonicmoe": True})
|
||||||
|
|
||||||
|
def test_sonicmoe_only(self):
|
||||||
|
result = KernelsArgs.model_validate({"use_sonicmoe": True})
|
||||||
|
assert result.use_sonicmoe is True
|
||||||
|
assert result.use_scattermoe is None
|
||||||
|
|
||||||
|
def test_scattermoe_only(self):
|
||||||
|
result = KernelsArgs.model_validate({"use_scattermoe": True})
|
||||||
|
assert result.use_scattermoe is True
|
||||||
|
assert result.use_sonicmoe is None
|
||||||
|
|
||||||
|
def test_neither_set(self):
|
||||||
|
result = KernelsArgs.model_validate({})
|
||||||
|
assert result.use_scattermoe is None
|
||||||
|
assert result.use_sonicmoe is None
|
||||||
|
|
||||||
|
def test_disables_mlp_kernel_when_sonicmoe(self):
|
||||||
|
data = {"use_sonicmoe": True, "lora_mlp_kernel": True}
|
||||||
|
result = KernelsArgs.disable_mlp_kernel(data)
|
||||||
|
assert result["lora_mlp_kernel"] is False
|
||||||
|
assert result["mlp_kernel"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcatenatedToInterleaved:
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_tensor(self):
|
||||||
|
"""Create a test tensor [E=2, 2*I=4, H=3] with distinct gate/up values."""
|
||||||
|
E, I, H = 2, 2, 3 # noqa: E741
|
||||||
|
gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)
|
||||||
|
up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)
|
||||||
|
return torch.cat([gate, up], dim=1)
|
||||||
|
|
||||||
|
def test_interleave_rows_alternate(self, sample_tensor):
|
||||||
|
op = ConcatenatedToInterleaved(dim=1)
|
||||||
|
result = op.convert(
|
||||||
|
{"test": sample_tensor},
|
||||||
|
source_patterns=["test"],
|
||||||
|
target_patterns=["test"],
|
||||||
|
)
|
||||||
|
interleaved = result["test"]
|
||||||
|
|
||||||
|
# For expert 0: even rows should be gate, odd rows should be up
|
||||||
|
E, two_I, H = sample_tensor.shape
|
||||||
|
I = two_I // 2 # noqa: E741
|
||||||
|
gate_orig = sample_tensor[:, :I, :]
|
||||||
|
up_orig = sample_tensor[:, I:, :]
|
||||||
|
|
||||||
|
assert torch.equal(interleaved[:, 0::2, :], gate_orig)
|
||||||
|
assert torch.equal(interleaved[:, 1::2, :], up_orig)
|
||||||
|
|
||||||
|
def test_interleave_handles_list_input(self, sample_tensor):
|
||||||
|
op = ConcatenatedToInterleaved(dim=1)
|
||||||
|
result = op.convert(
|
||||||
|
{"test": [sample_tensor]},
|
||||||
|
source_patterns=["test"],
|
||||||
|
target_patterns=["test"],
|
||||||
|
)
|
||||||
|
assert result["test"].shape == sample_tensor.shape
|
||||||
|
|
||||||
|
def test_reverse_op_type(self):
|
||||||
|
op = ConcatenatedToInterleaved(dim=1)
|
||||||
|
assert isinstance(op.reverse_op, InterleavedToConcatenated)
|
||||||
|
assert op.reverse_op.dim == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestInterleavedToConcatenated:
|
||||||
|
@pytest.fixture
|
||||||
|
def interleaved_tensor(self):
|
||||||
|
"""Create an interleaved tensor [E=2, 2*I=4, H=3]."""
|
||||||
|
E, I, H = 2, 2, 3 # noqa: E741
|
||||||
|
gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)
|
||||||
|
up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)
|
||||||
|
interleaved = torch.empty(E, 2 * I, H)
|
||||||
|
interleaved[:, 0::2, :] = gate
|
||||||
|
interleaved[:, 1::2, :] = up
|
||||||
|
return interleaved
|
||||||
|
|
||||||
|
def test_deinterleave_gate_up_separated(self, interleaved_tensor):
|
||||||
|
op = InterleavedToConcatenated(dim=1)
|
||||||
|
result = op.convert(
|
||||||
|
{"test": interleaved_tensor},
|
||||||
|
source_patterns=["test"],
|
||||||
|
target_patterns=["test"],
|
||||||
|
)
|
||||||
|
concatenated = result["test"]
|
||||||
|
|
||||||
|
E, two_I, H = concatenated.shape
|
||||||
|
I = two_I // 2 # noqa: E741
|
||||||
|
|
||||||
|
# First half should be gate (even rows from interleaved)
|
||||||
|
assert torch.equal(concatenated[:, :I, :], interleaved_tensor[:, 0::2, :])
|
||||||
|
# Second half should be up (odd rows from interleaved)
|
||||||
|
assert torch.equal(concatenated[:, I:, :], interleaved_tensor[:, 1::2, :])
|
||||||
|
|
||||||
|
def test_reverse_op_type(self):
|
||||||
|
op = InterleavedToConcatenated(dim=1)
|
||||||
|
assert isinstance(op.reverse_op, ConcatenatedToInterleaved)
|
||||||
|
assert op.reverse_op.dim == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoundTrip:
|
||||||
|
@pytest.fixture
|
||||||
|
def concat_tensor(self):
|
||||||
|
E, I, H = 4, 8, 16 # noqa: E741
|
||||||
|
gate = torch.randn(E, I, H)
|
||||||
|
up = torch.randn(E, I, H)
|
||||||
|
return torch.cat([gate, up], dim=1)
|
||||||
|
|
||||||
|
def test_interleave_then_deinterleave_is_identity(self, concat_tensor):
|
||||||
|
fwd = ConcatenatedToInterleaved(dim=1)
|
||||||
|
rev = InterleavedToConcatenated(dim=1)
|
||||||
|
|
||||||
|
interleaved = fwd.convert(
|
||||||
|
{"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"]
|
||||||
|
)["k"]
|
||||||
|
recovered = rev.convert(
|
||||||
|
{"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
|
||||||
|
)["k"]
|
||||||
|
|
||||||
|
assert torch.equal(concat_tensor, recovered)
|
||||||
|
|
||||||
|
def test_reverse_op_chain_is_identity(self, concat_tensor):
|
||||||
|
"""Verify that op.reverse_op produces an exact inverse."""
|
||||||
|
op = ConcatenatedToInterleaved(dim=1)
|
||||||
|
rev = op.reverse_op
|
||||||
|
|
||||||
|
interleaved = op.convert(
|
||||||
|
{"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"]
|
||||||
|
)["k"]
|
||||||
|
recovered = rev.convert(
|
||||||
|
{"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
|
||||||
|
)["k"]
|
||||||
|
|
||||||
|
assert torch.equal(concat_tensor, recovered)
|
||||||
|
|
||||||
|
def test_various_shapes(self):
|
||||||
|
"""Test with different expert counts and dimensions."""
|
||||||
|
fwd = ConcatenatedToInterleaved(dim=1)
|
||||||
|
rev = InterleavedToConcatenated(dim=1)
|
||||||
|
|
||||||
|
for E, I, H in [(1, 4, 8), (8, 16, 32), (16, 128, 256)]: # noqa: E741
|
||||||
|
concat = torch.randn(E, 2 * I, H)
|
||||||
|
interleaved = fwd.convert(
|
||||||
|
{"k": concat}, source_patterns=["k"], target_patterns=["k"]
|
||||||
|
)["k"]
|
||||||
|
recovered = rev.convert(
|
||||||
|
{"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
|
||||||
|
)["k"]
|
||||||
|
assert torch.equal(concat, recovered), (
|
||||||
|
f"Failed for shape ({E}, {2 * I}, {H})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWeightConverterRegistration:
|
||||||
|
def test_register_appends_interleave_op(self):
|
||||||
|
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||||
|
|
||||||
|
register_sonicmoe_weight_converter("qwen3_moe")
|
||||||
|
|
||||||
|
modified = get_checkpoint_conversion_mapping("qwen3_moe")
|
||||||
|
# Find the gate_up_proj converter
|
||||||
|
gate_up_converter = None
|
||||||
|
for conv in modified:
|
||||||
|
if hasattr(conv, "operations") and any(
|
||||||
|
"gate_up_proj" in pat for pat in conv.target_patterns
|
||||||
|
):
|
||||||
|
gate_up_converter = conv
|
||||||
|
break
|
||||||
|
|
||||||
|
assert gate_up_converter is not None
|
||||||
|
assert isinstance(gate_up_converter.operations[-1], ConcatenatedToInterleaved)
|
||||||
|
|
||||||
|
def test_double_registration_is_idempotent(self):
|
||||||
|
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||||
|
|
||||||
|
register_sonicmoe_weight_converter("qwen3_moe")
|
||||||
|
register_sonicmoe_weight_converter("qwen3_moe")
|
||||||
|
|
||||||
|
modified = get_checkpoint_conversion_mapping("qwen3_moe")
|
||||||
|
for conv in modified:
|
||||||
|
if hasattr(conv, "operations") and any(
|
||||||
|
"gate_up_proj" in pat for pat in conv.target_patterns
|
||||||
|
):
|
||||||
|
interleave_count = sum(
|
||||||
|
isinstance(op, ConcatenatedToInterleaved) for op in conv.operations
|
||||||
|
)
|
||||||
|
assert interleave_count == 1, (
|
||||||
|
f"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_register_unsupported_model_type_warns(self):
|
||||||
|
# A model type with no conversion mapping should warn but not raise
|
||||||
|
register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
|
||||||
|
"""Create a mock qwen-style MoE block for routing tests."""
|
||||||
|
gate = SimpleNamespace(
|
||||||
|
weight=torch.randn(E, H),
|
||||||
|
top_k=K,
|
||||||
|
num_experts=E,
|
||||||
|
norm_topk_prob=True,
|
||||||
|
)
|
||||||
|
return SimpleNamespace(gate=gate), T, H, E, K
|
||||||
|
|
||||||
|
|
||||||
|
def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1):
|
||||||
|
"""Create a mock GLM5-style MoE block for routing tests."""
|
||||||
|
gate = SimpleNamespace(
|
||||||
|
weight=torch.randn(E, H),
|
||||||
|
e_score_correction_bias=torch.zeros(E),
|
||||||
|
)
|
||||||
|
moe_block = SimpleNamespace(
|
||||||
|
gate=gate,
|
||||||
|
top_k=K,
|
||||||
|
n_routed_experts=E,
|
||||||
|
n_group=n_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
norm_topk_prob=True,
|
||||||
|
routed_scaling_factor=1.0,
|
||||||
|
)
|
||||||
|
return moe_block, T, H, E, K
|
||||||
|
|
||||||
|
|
||||||
|
def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4):
|
||||||
|
"""Create a mock minimax_m2-style MoE block for routing tests.
|
||||||
|
|
||||||
|
minimax_m2 uses sigmoid->topk WITHOUT group selection:
|
||||||
|
- e_score_correction_bias is on the moe_block (not on gate)
|
||||||
|
- No n_group / topk_group attributes
|
||||||
|
- Always normalizes (norm_topk_prob defaults to True)
|
||||||
|
- No routed_scaling_factor (defaults to 1.0)
|
||||||
|
"""
|
||||||
|
gate = SimpleNamespace(
|
||||||
|
weight=torch.randn(E, H),
|
||||||
|
top_k=K,
|
||||||
|
)
|
||||||
|
moe_block = SimpleNamespace(
|
||||||
|
gate=gate,
|
||||||
|
top_k=K,
|
||||||
|
e_score_correction_bias=torch.zeros(E),
|
||||||
|
)
|
||||||
|
return moe_block, T, H, E, K
|
||||||
|
|
||||||
|
|
||||||
|
class TestSoftmaxTopkRouting:
|
||||||
|
def test_output_shapes(self):
|
||||||
|
moe_block, T, H, E, K = _make_qwen_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
scores, token_idx, expert_idx, logits = softmax_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
assert scores.shape == (T * K,)
|
||||||
|
assert token_idx.shape == (T * K,)
|
||||||
|
assert expert_idx.shape == (T * K,)
|
||||||
|
assert logits.shape == (T, E)
|
||||||
|
|
||||||
|
def test_scores_are_float32(self):
|
||||||
|
moe_block, T, H, E, K = _make_qwen_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
assert scores.dtype == torch.float32
|
||||||
|
|
||||||
|
def test_token_indices_sorted_ascending(self):
|
||||||
|
moe_block, T, H, E, K = _make_qwen_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
_, token_idx, _, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
# Token indices must be sorted ascending (SonicMoE requirement)
|
||||||
|
diffs = token_idx[1:] - token_idx[:-1]
|
||||||
|
assert (diffs >= 0).all()
|
||||||
|
|
||||||
|
def test_expert_indices_in_range(self):
|
||||||
|
moe_block, T, H, E, K = _make_qwen_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
_, _, expert_idx, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
assert (expert_idx >= 0).all()
|
||||||
|
assert (expert_idx < E).all()
|
||||||
|
|
||||||
|
def test_renormalized_scores_sum_to_one(self):
|
||||||
|
moe_block, T, H, E, K = _make_qwen_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
per_token_sums = scores.reshape(T, K).sum(dim=-1)
|
||||||
|
assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSigmoidTopkRouting:
|
||||||
|
def test_output_shapes(self):
|
||||||
|
moe_block, T, H, E, K = _make_glm_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
assert scores.shape == (T * K,)
|
||||||
|
assert token_idx.shape == (T * K,)
|
||||||
|
assert expert_idx.shape == (T * K,)
|
||||||
|
assert logits.shape == (T, E)
|
||||||
|
|
||||||
|
def test_scores_are_float32(self):
|
||||||
|
moe_block, T, H, E, K = _make_glm_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
assert scores.dtype == torch.float32
|
||||||
|
|
||||||
|
def test_token_indices_sorted_ascending(self):
|
||||||
|
moe_block, T, H, E, K = _make_glm_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
_, token_idx, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
diffs = token_idx[1:] - token_idx[:-1]
|
||||||
|
assert (diffs >= 0).all()
|
||||||
|
|
||||||
|
def test_expert_indices_in_range(self):
|
||||||
|
moe_block, T, H, E, K = _make_glm_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
assert (expert_idx >= 0).all()
|
||||||
|
assert (expert_idx < E).all()
|
||||||
|
|
||||||
|
def test_scores_are_nonnegative(self):
|
||||||
|
"""Sigmoid outputs are in [0, 1], so scores should be non-negative."""
|
||||||
|
moe_block, T, H, E, K = _make_glm_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
assert (scores >= 0).all()
|
||||||
|
|
||||||
|
def test_scaling_factor_applied(self):
|
||||||
|
moe_block, T, H, E, K = _make_glm_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
# Get scores with scaling_factor=1.0
|
||||||
|
scores_1x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
# Get scores with scaling_factor=2.0
|
||||||
|
moe_block.routed_scaling_factor = 2.0
|
||||||
|
scores_2x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
assert torch.allclose(scores_2x, scores_1x * 2.0, atol=1e-5)
|
||||||
|
|
||||||
|
def test_group_selection_restricts_experts(self):
|
||||||
|
"""With n_group=4 and topk_group=1, only 1/4 of experts should be selectable."""
|
||||||
|
moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1)
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
# Each token's experts should all fall within a single group (size E//n_group=4)
|
||||||
|
expert_idx_2d = expert_idx.reshape(T, K)
|
||||||
|
for t in range(T):
|
||||||
|
experts = expert_idx_2d[t]
|
||||||
|
groups = experts // (E // moe_block.n_group)
|
||||||
|
# All selected experts should be from the same group
|
||||||
|
assert (groups == groups[0]).all()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiniMaxM2SigmoidRouting:
|
||||||
|
"""Tests for minimax_m2 routing: sigmoid->topk without group selection."""
|
||||||
|
|
||||||
|
def test_output_shapes(self):
|
||||||
|
"""Validates getattr defaults work: n_group=1, E from gate.weight.shape[0]."""
|
||||||
|
moe_block, T, H, E, K = _make_minimax_m2_moe_block()
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
assert scores.shape == (T * K,)
|
||||||
|
assert token_idx.shape == (T * K,)
|
||||||
|
assert expert_idx.shape == (T * K,)
|
||||||
|
assert logits.shape == (T, E)
|
||||||
|
|
||||||
|
def test_bias_on_block_not_gate(self):
|
||||||
|
"""Verify that e_score_correction_bias on the block (not gate) is used."""
|
||||||
|
T, H, E, K = 8, 16, 8, 2
|
||||||
|
gate = SimpleNamespace(
|
||||||
|
weight=torch.randn(E, H),
|
||||||
|
top_k=K,
|
||||||
|
)
|
||||||
|
# Large positive bias on expert 0 should make it selected more often
|
||||||
|
bias = torch.zeros(E)
|
||||||
|
bias[0] = 100.0
|
||||||
|
moe_block = SimpleNamespace(
|
||||||
|
gate=gate,
|
||||||
|
top_k=K,
|
||||||
|
e_score_correction_bias=bias,
|
||||||
|
)
|
||||||
|
hidden = torch.randn(T, H)
|
||||||
|
|
||||||
|
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
# Expert 0 should appear for every token due to the large bias
|
||||||
|
expert_idx_2d = expert_idx.reshape(T, K)
|
||||||
|
for t in range(T):
|
||||||
|
assert 0 in expert_idx_2d[t]
|
||||||
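The converter tests above all revolve around one layout change: HF checkpoints store each expert's gate_up_proj as gate rows followed by up rows, while the fused kernel expects them interleaved. A minimal standalone sketch of that round trip (assumed helper bodies; the integration's own ops live in weight_converter):

```python
import torch

def interleave_gate_up(concat: torch.Tensor) -> torch.Tensor:
    """[E, 2I, H] with [gate; up] stacked on dim 1 -> gate0, up0, gate1, up1, ..."""
    experts, two_i, hidden = concat.shape
    half = two_i // 2
    inter = torch.empty_like(concat)
    inter[:, 0::2, :] = concat[:, :half, :]  # gate rows -> even slots
    inter[:, 1::2, :] = concat[:, half:, :]  # up rows   -> odd slots
    return inter

def deinterleave_gate_up(inter: torch.Tensor) -> torch.Tensor:
    """Exact inverse: even slots back to the first half, odd slots to the second."""
    return torch.cat([inter[:, 0::2, :], inter[:, 1::2, :]], dim=1)
```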
158 tests/integrations/test_sonicmoe_gradients.py (Normal file)
@@ -0,0 +1,158 @@
|
|||||||
|
"""
|
||||||
|
Gradient correctness tests for SonicMoE routing functions (CPU-only).
|
||||||
|
|
||||||
|
Uses torch.autograd.gradcheck with float32 inputs to match the production
|
||||||
|
code path where routing happens in float32.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||||
|
sigmoid_topk_routing,
|
||||||
|
softmax_topk_routing,
|
||||||
|
)
|
||||||
|
|
||||||
|
_GC_EPS = 1e-3
|
||||||
|
_GC_ATOL = 1e-3
|
||||||
|
_GC_RTOL = 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
def _make_softmax_moe_block(weight):
|
||||||
|
gate = torch.nn.Module()
|
||||||
|
gate.weight = weight
|
||||||
|
gate.top_k = 2
|
||||||
|
gate.norm_topk_prob = True
|
||||||
|
|
||||||
|
moe_block = torch.nn.Module()
|
||||||
|
moe_block.gate = gate
|
||||||
|
return moe_block
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sigmoid_moe_block(weight, bias):
|
||||||
|
gate = torch.nn.Module()
|
||||||
|
gate.weight = weight
|
||||||
|
gate.e_score_correction_bias = bias
|
||||||
|
|
||||||
|
moe_block = torch.nn.Module()
|
||||||
|
moe_block.gate = gate
|
||||||
|
moe_block.top_k = 2
|
||||||
|
moe_block.n_routed_experts = weight.shape[0]
|
||||||
|
moe_block.n_group = 1
|
||||||
|
moe_block.norm_topk_prob = True
|
||||||
|
moe_block.routed_scaling_factor = 1.0
|
||||||
|
return moe_block
|
||||||
|
|
||||||
|
|
||||||
|
class TestSoftmaxTopkRoutingGradcheck:
|
||||||
|
"""Numerical gradient verification for softmax_topk_routing."""
|
||||||
|
|
||||||
|
def test_gradcheck_wrt_gate_weight(self):
|
||||||
|
T, H, E = 4, 8, 4
|
||||||
|
|
||||||
|
hidden = torch.randn(T, H, dtype=torch.float32)
|
||||||
|
|
||||||
|
def fn(weight):
|
||||||
|
moe_block = _make_softmax_moe_block(weight)
|
||||||
|
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
|
||||||
|
torch.autograd.gradcheck(
|
||||||
|
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_gradcheck_wrt_hidden_states(self):
|
||||||
|
T, H, E = 4, 8, 4
|
||||||
|
|
||||||
|
weight = torch.randn(E, H, dtype=torch.float32)
|
||||||
|
moe_block = _make_softmax_moe_block(weight)
|
||||||
|
|
||||||
|
def fn(hidden):
|
||||||
|
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
|
||||||
|
torch.autograd.gradcheck(
|
||||||
|
fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_gradcheck_wrt_router_logits(self):
|
||||||
|
T, H, E = 4, 8, 4
|
||||||
|
|
||||||
|
hidden = torch.randn(T, H, dtype=torch.float32)
|
||||||
|
|
||||||
|
def fn(weight):
|
||||||
|
moe_block = _make_softmax_moe_block(weight)
|
||||||
|
_, _, _, router_logits = softmax_topk_routing(hidden, moe_block)
|
||||||
|
return router_logits
|
||||||
|
|
||||||
|
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
|
||||||
|
torch.autograd.gradcheck(
|
||||||
|
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_norm_variant(self):
|
||||||
|
T, H, E = 4, 8, 4
|
||||||
|
|
||||||
|
hidden = torch.randn(T, H, dtype=torch.float32)
|
||||||
|
|
||||||
|
def fn(weight):
|
||||||
|
moe_block = _make_softmax_moe_block(weight)
|
||||||
|
moe_block.gate.norm_topk_prob = False
|
||||||
|
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
|
||||||
|
torch.autograd.gradcheck(
|
||||||
|
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSigmoidTopkRoutingGradcheck:
|
||||||
|
"""Numerical gradient verification for sigmoid_topk_routing."""
|
||||||
|
|
||||||
|
def test_gradcheck_wrt_gate_weight(self):
|
||||||
|
T, H, E = 4, 8, 4
|
||||||
|
|
||||||
|
hidden = torch.randn(T, H, dtype=torch.float32)
|
||||||
|
bias = torch.zeros(E, dtype=torch.float32)
|
||||||
|
|
||||||
|
def fn(weight):
|
||||||
|
moe_block = _make_sigmoid_moe_block(weight, bias)
|
||||||
|
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
|
||||||
|
torch.autograd.gradcheck(
|
||||||
|
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_gradcheck_wrt_hidden_states(self):
|
||||||
|
T, H, E = 4, 8, 4
|
||||||
|
|
||||||
|
weight = torch.randn(E, H, dtype=torch.float32)
|
||||||
|
bias = torch.zeros(E, dtype=torch.float32)
|
||||||
|
moe_block = _make_sigmoid_moe_block(weight, bias)
|
||||||
|
|
||||||
|
def fn(hidden):
|
||||||
|
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
|
||||||
|
torch.autograd.gradcheck(
|
||||||
|
fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_gradcheck_wrt_bias(self):
|
||||||
|
T, H, E = 4, 8, 4
|
||||||
|
|
||||||
|
hidden = torch.randn(T, H, dtype=torch.float32)
|
||||||
|
weight = torch.randn(E, H, dtype=torch.float32)
|
||||||
|
|
||||||
|
def fn(bias):
|
||||||
|
moe_block = _make_sigmoid_moe_block(weight, bias)
|
||||||
|
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)
|
||||||
|
torch.autograd.gradcheck(fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL)
|
||||||
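Both routing variants exercised above share one output contract: float32 scores of shape (T*K,), token indices sorted ascending, expert indices, and the raw router logits. A rough reference sketch of the sigmoid->top-k path under that contract, with GLM-style group selection omitted (assumptions throughout; not the fused kernel's code path):

```python
import torch

def sigmoid_topk_routing_sketch(hidden, weight, bias=None, top_k=2, norm=True, scale=1.0):
    logits = hidden.float() @ weight.float().t()              # [T, E] raw router logits
    probs = torch.sigmoid(logits)
    selection = probs + bias if bias is not None else probs   # bias shifts selection only
    _, expert_idx = selection.topk(top_k, dim=-1)              # [T, K]
    scores = probs.gather(-1, expert_idx)                      # report unbiased scores
    if norm:
        scores = scores / scores.sum(dim=-1, keepdim=True).clamp_min(1e-20)
    scores = scores * scale                                    # routed_scaling_factor
    token_idx = torch.arange(hidden.shape[0]).repeat_interleave(top_k)  # sorted ascending
    return scores.reshape(-1), token_idx, expert_idx.reshape(-1), logits
```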
@@ -18,6 +18,7 @@ Unit tests for SwanLab Integration Plugin.
 Tests conflict detection, configuration validation, and multi-logger warnings.
 """

+import importlib.util
 import logging
 import os
 import time
@@ -25,12 +26,11 @@ from unittest.mock import MagicMock, patch

 import pytest
 from pydantic import ValidationError
-from transformers.utils.import_utils import _is_package_available

 from axolotl.integrations.swanlab.args import SwanLabConfig
 from axolotl.integrations.swanlab.plugins import SwanLabPlugin

-SWANLAB_INSTALLED = _is_package_available("swanlab")
+SWANLAB_INSTALLED = importlib.util.find_spec("swanlab") is not None


 @pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
@@ -52,8 +52,8 @@ def mock_torch():
     mock_torch.cuda.device_count.return_value = 2

     # Mock memory allocated per device (1GB for device 0, 2GB for device 1)
-    mock_torch.cuda.memory_allocated.side_effect = (
-        lambda device: (device + 1) * 1024 * 1024 * 1024
+    mock_torch.cuda.memory_allocated.side_effect = lambda device: (
+        (device + 1) * 1024 * 1024 * 1024
     )

     yield mock_torch
@@ -292,8 +292,8 @@ class TestRuntimeMetricsTracker:
         mock_memory_info = mock_process.memory_info.return_value
         mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024  # 0.5GB

-        mock_torch.cuda.memory_allocated.side_effect = (
-            lambda device: (device + 0.5) * 1024 * 1024 * 1024
+        mock_torch.cuda.memory_allocated.side_effect = lambda device: (
+            (device + 0.5) * 1024 * 1024 * 1024
         )

         # Update memory metrics again
@@ -307,8 +307,8 @@ class TestRuntimeMetricsTracker:
         # Change mocked memory values to be higher
         mock_memory_info.rss = 2 * 1024 * 1024 * 1024  # 2GB

-        mock_torch.cuda.memory_allocated.side_effect = (
-            lambda device: (device + 2) * 1024 * 1024 * 1024
+        mock_torch.cuda.memory_allocated.side_effect = lambda device: (
+            (device + 2) * 1024 * 1024 * 1024
         )

         # Update memory metrics again
56 tests/test_context_parallel_batch_size.py (Normal file)
@@ -0,0 +1,56 @@
|
|||||||
|
"""Tests for batch_size calculation with context parallelism."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="cp_base_cfg")
|
||||||
|
def fixture_cp_base_cfg(min_base_cfg):
|
||||||
|
return (
|
||||||
|
DictDefault(
|
||||||
|
micro_batch_size=2,
|
||||||
|
gradient_accumulation_steps=4,
|
||||||
|
sequence_len=2048,
|
||||||
|
num_epochs=1,
|
||||||
|
flash_attention=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextParallelBatchSize:
|
||||||
|
"""Verify batch_size scales by effective dp world_size when using context parallelism."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"world_size, context_parallel_size, expected_batch_size",
|
||||||
|
[
|
||||||
|
(4, 1, 32), # no CP: 2*4*4 = 32
|
||||||
|
(4, 2, 16), # CP=2: 2*4*(4//2) = 16
|
||||||
|
(4, 4, 8), # CP=4: 2*4*(4//4) = 8
|
||||||
|
(2, 2, 8), # CP=ws: 2*4*(2//2) = 8 (no scaling)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_batch_size_with_context_parallelism(
|
||||||
|
self,
|
||||||
|
cp_base_cfg,
|
||||||
|
monkeypatch,
|
||||||
|
world_size,
|
||||||
|
context_parallel_size,
|
||||||
|
expected_batch_size,
|
||||||
|
):
|
||||||
|
monkeypatch.setenv("WORLD_SIZE", str(world_size))
|
||||||
|
# Mock ring_flash_attn since it's not installable on CPU,
|
||||||
|
# but required by schema validation when context_parallel_size > 1.
|
||||||
|
if "ring_flash_attn" not in sys.modules:
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules, "ring_flash_attn", types.ModuleType("ring_flash_attn")
|
||||||
|
)
|
||||||
|
cp_base_cfg["context_parallel_size"] = context_parallel_size
|
||||||
|
cfg = validate_config(cp_base_cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
assert cfg.batch_size == expected_batch_size
|
||||||
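The parametrizations in this file and in the tensor-parallel file below boil down to the same arithmetic: the effective data-parallel world size is the launch world size divided by the parallelism degree, and batch_size is derived from it. A sketch of the bookkeeping only, not axolotl's config code:

```python
def effective_batch_size(micro_batch_size, grad_accum_steps, world_size, parallel_size=1):
    dp_world_size = world_size // parallel_size
    return micro_batch_size * grad_accum_steps * dp_world_size

# e.g. the CP=2 (and TP=2) rows: 2 * 4 * (4 // 2) == 16
assert effective_batch_size(2, 4, 4, parallel_size=2) == 16
```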
55 tests/test_tensor_parallel_batch_size.py (Normal file)
@@ -0,0 +1,55 @@
|
|||||||
|
"""Tests for batch_size calculation with tensor parallelism."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import addict
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="tp_base_cfg")
|
||||||
|
def fixture_tp_base_cfg(min_base_cfg):
|
||||||
|
return (
|
||||||
|
DictDefault(
|
||||||
|
micro_batch_size=2,
|
||||||
|
gradient_accumulation_steps=4,
|
||||||
|
sequence_len=2048,
|
||||||
|
num_epochs=1,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTensorParallelBatchSize:
|
||||||
|
"""Verify batch_size scales by effective dp world_size when using tensor parallelism."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"world_size, tensor_parallel_size, expected_batch_size",
|
||||||
|
[
|
||||||
|
(4, 1, 32), # no TP: 2*4*4 = 32
|
||||||
|
(4, 2, 16), # TP=2: 2*4*(4//2) = 16
|
||||||
|
(4, 4, 8), # TP=4: 2*4*(4//4) = 8
|
||||||
|
(2, 2, 8), # TP=ws: 2*4*(2//2) = 8 (no scaling)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_batch_size_with_tensor_parallelism(
|
||||||
|
self,
|
||||||
|
tp_base_cfg,
|
||||||
|
monkeypatch,
|
||||||
|
world_size,
|
||||||
|
tensor_parallel_size,
|
||||||
|
expected_batch_size,
|
||||||
|
):
|
||||||
|
monkeypatch.setenv("WORLD_SIZE", str(world_size))
|
||||||
|
tp_base_cfg["tensor_parallel_size"] = tensor_parallel_size
|
||||||
|
cfg = validate_config(tp_base_cfg)
|
||||||
|
# Mock load_model_config to avoid downloading the model and to bypass
|
||||||
|
# the tie_word_embeddings validation that blocks TP > 1.
|
||||||
|
with patch(
|
||||||
|
"axolotl.utils.config.load_model_config",
|
||||||
|
return_value=addict.Dict({"model_type": "llama"}),
|
||||||
|
):
|
||||||
|
normalize_config(cfg)
|
||||||
|
assert cfg.batch_size == expected_batch_size
|
||||||
@@ -84,7 +84,8 @@ class TestTokenizers:
             }
         )
         tokenizer = load_tokenizer(cfg)
-        assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
+        assert "LlamaTokenizer" in tokenizer.__class__.__name__
+        assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
         assert len(tokenizer) == 32001

         # ensure reloading the tokenizer again from cfg results in same vocab length
156 tests/utils/schemas/validation/test_moe_quant.py (Normal file)
@@ -0,0 +1,156 @@
|
|||||||
|
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def gpu_caps():
|
||||||
|
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def env_caps():
|
||||||
|
return {"torch_version": "2.7.0"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuantizeMoeExpertsValidation:
|
||||||
|
"""Test suite for quantize_moe_experts config validator."""
|
||||||
|
|
||||||
|
def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts without adapter should fail."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="requires adapter"):
|
||||||
|
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
|
||||||
|
def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts without load_in_4bit/8bit should fail."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
|
||||||
|
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
|
||||||
|
def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts with qlora + 4bit should pass."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is True
|
||||||
|
|
||||||
|
def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts with lora + 8bit should pass."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
adapter="lora",
|
||||||
|
load_in_8bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is True
|
||||||
|
|
||||||
|
def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts=false should not check adapter/quantization."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=False,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is False
|
||||||
|
|
||||||
|
def test_rejects_lora_target_linear(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts with lora_target_linear should fail."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
lora_target_linear=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="lora_target_linear is not compatible"):
|
||||||
|
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
|
||||||
|
def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts should default to false."""
|
||||||
|
cfg = DictDefault({}) | min_base_cfg
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoraTargetParametersDropout:
|
||||||
|
"""Test that lora_dropout must be 0 when lora_target_parameters is set."""
|
||||||
|
|
||||||
|
def test_rejects_nonzero_dropout(self, min_base_cfg):
|
||||||
|
"""lora_dropout > 0 with lora_target_parameters should fail."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
adapter="lora",
|
||||||
|
lora_target_parameters=["mlp.experts.gate_up_proj"],
|
||||||
|
lora_dropout=0.1,
|
||||||
|
load_in_8bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="lora_dropout must be 0"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_zero_dropout_passes(self, min_base_cfg):
|
||||||
|
"""lora_dropout=0 with lora_target_parameters should pass."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
adapter="lora",
|
||||||
|
lora_target_parameters=["mlp.experts.gate_up_proj"],
|
||||||
|
lora_dropout=0.0,
|
||||||
|
load_in_8bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg)
|
||||||
|
assert result["lora_dropout"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPeftPatchIdempotency:
|
||||||
|
"""Test that patch_peft_target_parameters_matching is idempotent."""
|
||||||
|
|
||||||
|
def test_double_call_does_not_stack_wrappers(self):
|
||||||
|
"""Calling patch twice should not double-wrap _inject_parameters."""
|
||||||
|
from peft.tuners.tuners_utils import BaseTuner
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.moe_quant import (
|
||||||
|
patch_peft_target_parameters_matching,
|
||||||
|
)
|
||||||
|
|
||||||
|
original = BaseTuner._inject_parameters
|
||||||
|
try:
|
||||||
|
patch_peft_target_parameters_matching()
|
||||||
|
first_patched = BaseTuner._inject_parameters
|
||||||
|
patch_peft_target_parameters_matching()
|
||||||
|
second_patched = BaseTuner._inject_parameters
|
||||||
|
# Should be same function, not double-wrapped
|
||||||
|
assert first_patched is second_patched
|
||||||
|
finally:
|
||||||
|
BaseTuner._inject_parameters = original
|
||||||
|
patch_peft_target_parameters_matching._axolotl_patched = False
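The idempotency test above relies on a guard flag. A hedged sketch of the general pattern (the _axolotl_patched flag name comes from the test; the body is an assumption, not the real monkeypatch):

```python
def patch_once(target_cls, method_name, make_wrapper):
    """Apply a monkeypatch exactly once; repeated calls are no-ops."""
    if getattr(patch_once, "_axolotl_patched", False):
        return
    original = getattr(target_cls, method_name)
    setattr(target_cls, method_name, make_wrapper(original))
    patch_once._axolotl_patched = True
```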