Compare commits: v0.15.0...fix/cp-was (35 commits)
| SHA1 |
|---|
| 255c5b90ca |
| 038ffe3f26 |
| c13cb7c853 |
| b3823cc6b0 |
| 113d275bd9 |
| 7920fe74ec |
| 1fc86d5295 |
| bb483ad4c4 |
| 163bd4dd5a |
| f291ac029c |
| 5ef3f28340 |
| 999b3fec2e |
| 8f3fb517b3 |
| 830e9f7eaf |
| d230cbbde3 |
| a098df527b |
| 7da5f94379 |
| 4a5876df7a |
| defee62d99 |
| f56efdb4ab |
| d8a646c80d |
| a806704e94 |
| d8a05744d7 |
| ff77fa2488 |
| e1ff756245 |
| 083c5a0421 |
| 79908b3c6e |
| 819b157c7b |
| fccc712dae |
| 23ad40bdd5 |
| cf4d550c88 |
| 43b1c80aa6 |
| a36aaa70ce |
| 80f7088ad1 |
| 46b9f40f2a |
.github/CONTRIBUTING.md (4 changes)

@@ -68,7 +68,7 @@ You can skip certain CI checks by including specific keywords in your commit mes
 ### Code Style

-axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
+axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.

 Use the pre-commit linter to ensure that your code is formatted consistently.
 ```bash
@@ -83,6 +83,6 @@ Write clear and concise commit messages that briefly describe the changes made i

 - [GitHub Help](https://help.github.com/)
 - [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
-- [{codestyle}]({URLofCodestyle})
+- [Ruff](https://docs.astral.sh/ruff/)

 Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!
.github/workflows/base.yml (19 changes)

@@ -15,6 +15,9 @@ on:
       - '.github/workflows/base.yml'
   workflow_dispatch:

+permissions:
+  contents: read
+
 jobs:
   build-base:
     if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -124,7 +127,7 @@ jobs:
           images: |
             axolotlai/axolotl-base
       - name: Login to Docker Hub
-        uses: docker/login-action@v2
+        uses: docker/login-action@v3
         if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -132,7 +135,7 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
       - name: Build
-        uses: docker/build-push-action@v4
+        uses: docker/build-push-action@v5
         with:
           context: .
           file: ./docker/${{ matrix.dockerfile }}
@@ -173,6 +176,14 @@ jobs:
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
            dockerfile: "Dockerfile-uv-base"
            platforms: "linux/amd64,linux/arm64"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.12"
+            pytorch: 2.9.1
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-uv-base"
+            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -239,7 +250,7 @@ jobs:
           images: |
             axolotlai/axolotl-base-uv
       - name: Login to Docker Hub
-        uses: docker/login-action@v2
+        uses: docker/login-action@v3
         if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -247,7 +258,7 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
       - name: Build
-        uses: docker/build-push-action@v4
+        uses: docker/build-push-action@v5
         with:
           context: .
           file: ./docker/${{ matrix.dockerfile }}
.github/workflows/lint.yml (3 changes)

@@ -13,6 +13,9 @@ on:
       - ".pre-commit-config.yaml"
   workflow_dispatch:

+permissions:
+  contents: read
+
 jobs:
   pre-commit:
     name: pre-commit
.github/workflows/main.yml (18 changes)

@@ -8,6 +8,9 @@ on:
       - "v*"
   workflow_dispatch:

+permissions:
+  contents: read
+
 jobs:
   build-axolotl:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -110,6 +113,12 @@ jobs:
            pytorch: 2.9.1
            axolotl_extras:
            platforms: "linux/amd64,linux/arm64"
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.12"
            pytorch: 2.9.1
            axolotl_extras:
            platforms: "linux/amd64,linux/arm64"
            is_latest: true
          - cuda: 128
            cuda_version: 12.8.1
@@ -174,6 +183,7 @@ jobs:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
     # this job needs to be run on self-hosted GPU runners...
     strategy:
       fail-fast: false
       matrix:
         include:
           - cuda: 128
@@ -259,6 +269,7 @@ jobs:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
     # this job needs to be run on self-hosted GPU runners...
     strategy:
       fail-fast: false
       matrix:
         include:
           - cuda: 128
@@ -266,6 +277,12 @@ jobs:
            python_version: "3.11"
            pytorch: 2.9.1
            axolotl_extras:
            platforms: "linux/amd64,linux/arm64"
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.12"
            pytorch: 2.9.1
            axolotl_extras:
            is_latest: true
            platforms: "linux/amd64,linux/arm64"
          - cuda: 128
@@ -326,6 +343,7 @@ jobs:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
     # this job needs to be run on self-hosted GPU runners...
     strategy:
       fail-fast: false
       matrix:
         include:
           - cuda: 128
.github/workflows/multi-gpu-e2e.yml (21 changes)

@@ -8,6 +8,7 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
       - 'scripts/cutcrossentropy_install.py'
       - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
       - 'src/axolotl/utils/distributed.py'
   workflow_dispatch:
@@ -19,6 +20,9 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
   cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

+permissions:
+  contents: read
+
 env:
   MODAL_IMAGE_BUILDER_VERSION: "2025.06"

@@ -35,13 +39,13 @@ jobs:
            pytorch: 2.8.0
            axolotl_extras: fbgemm-gpu
            num_gpus: 2
-          - cuda: 129
-            cuda_version: 12.9.1
-            python_version: "3.12"
-            pytorch: 2.9.1
-            axolotl_extras: "fbgemm-gpu"
-            num_gpus: 2
-            dockerfile: "Dockerfile-uv.jinja"
+          # - cuda: 129
+          #   cuda_version: 12.9.1
+          #   python_version: "3.12"
+          #   pytorch: 2.9.1
+          #   axolotl_extras: "fbgemm-gpu"
+          #   num_gpus: 2
+          #   dockerfile: "Dockerfile-uv.jinja"
           - cuda: 130
             cuda_version: 13.0.0
             python_version: "3.11"
@@ -77,8 +81,9 @@ jobs:
           echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
-          echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
           echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
       - name: Run tests job on Modal
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
         run: |
           modal run -m cicd.multigpu
.github/workflows/nightlies.yml (3 changes)

@@ -5,6 +5,9 @@ on:
   schedule:
     - cron: '0 0 * * *' # Runs at 00:00 UTC every day

+permissions:
+  contents: read
+
 jobs:
   build-axolotl:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
.github/workflows/precommit-autoupdate.yml (2 changes)

@@ -5,6 +5,8 @@ on:
     - cron: '0 0 1 * *' # Run monthly
   workflow_dispatch: # Manual kickoff

+permissions: {}
+
 jobs:
   auto-update:
     runs-on: ubuntu-latest
.github/workflows/preview-docs.yml (8 changes)

@@ -14,14 +14,8 @@ on:
       - .github/workflows/preview-docs.yml

 permissions:
-  checks: write
-  contents: write
-  deployments: write
-  issues: write
-  discussions: write
-  pages: write
+  contents: read
   pull-requests: write
   statuses: write

 jobs:
   preview:
.github/workflows/pypi.yml (9 changes)

@@ -3,9 +3,11 @@ name: publish pypi
 on:
   push:
     tags:
-      - 'v*'
+      - "v*"
   workflow_dispatch:

+permissions: {}
+
 jobs:
   setup_release:
     name: Create Release
@@ -28,7 +30,8 @@ jobs:
       name: pypi
       url: https://pypi.org/p/axolotl
     permissions:
-      id-token: write  # IMPORTANT: this permission is mandatory for trusted publishing
+      contents: read
+      id-token: write  # IMPORTANT: this permission is mandatory for trusted publishing
     steps:
       - name: Check out repository code
         uses: actions/checkout@v4
@@ -46,7 +49,7 @@ jobs:

       - name: Extract tag name
         id: tag
-        run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
+        run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"

       - name: Update version in VERSION file
         run: |
.github/workflows/tests-nightly.yml (18 changes)

@@ -3,6 +3,13 @@ on:
   workflow_dispatch:
   schedule:
     - cron: '0 0 * * *' # Runs at 00:00 UTC every day
+  pull_request:
+    types: [opened, synchronize, reopened, ready_for_review]
+    paths:
+      - '.github/workflows/tests-nightly.yml'
+
+permissions:
+  contents: read

 jobs:
   pre-commit:
@@ -27,7 +34,7 @@ jobs:
       - name: Restore Cache from S3
         id: hf-cache-restore-s3
         run: |
-          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
+          curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null

   pytest:
     name: PyTest
@@ -35,7 +42,6 @@ jobs:
     needs: [prime-cdn-s3-cache]
     strategy:
       fail-fast: false
-      max-parallel: 2
       matrix:
         python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
         pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
@@ -60,7 +66,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
+          pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel

       - name: Install PyTorch
         run: |
@@ -153,8 +159,9 @@ jobs:
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
           echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
           echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
-          echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
         run: |
           modal run cicd.e2e_tests
   docker-e2e-multigpu-tests:
@@ -195,7 +202,8 @@ jobs:
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
           echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
-          echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
         run: |
           modal run cicd.multigpu
.github/workflows/tests.yml (12 changes)

@@ -28,6 +28,9 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
   cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

+permissions:
+  contents: read
+
 env:
   TRANSFORMERS_IS_CI: "yes"

@@ -55,7 +58,7 @@ jobs:
       - name: Restore Cache from S3
         id: hf-cache-restore-s3
         run: |
-          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
+          curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null

   pytest:
     name: PyTest
@@ -303,9 +306,10 @@ jobs:
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
-          echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
           echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
       - name: Run tests job on Modal
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
         run: |
           modal run cicd.e2e_tests

@@ -371,9 +375,10 @@ jobs:
           echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
           echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
-          echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
           echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
       - name: Run tests job on Modal
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
         run: |
           modal run cicd.e2e_tests

@@ -413,7 +418,6 @@ jobs:
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
-          echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
           modal run cicd.cleanup
@@ -30,7 +30,7 @@

 ## 🎉 Latest Updates

 - 2026/03:
-  - New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
+  - New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
   - [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
 - 2026/02:
   - [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
@@ -75,7 +75,7 @@ Features:
 - **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
 - **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
 - **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
-- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
+- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
 - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
 - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
benchmarks/bench_entropy.py (new file, 208 lines)

@@ -0,0 +1,208 @@
|
||||
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
|
||||
"""Original chunked implementation (reference)."""
|
||||
original_shape = logits.shape[:-1]
|
||||
num_classes = logits.shape[-1]
|
||||
flat_logits = logits.reshape(-1, num_classes)
|
||||
entropies = []
|
||||
for chunk in flat_logits.split(chunk_size, dim=0):
|
||||
logps = F.log_softmax(chunk, dim=-1)
|
||||
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
|
||||
entropies.append(chunk_entropy)
|
||||
return torch.cat(entropies, dim=0).reshape(original_shape)
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, logits, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
out = fn(logits, chunk_size=128)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
del out
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, logits, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(logits, chunk_size=128)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_contiguous():
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(1, 16384),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
t_orig = profile_time(entropy_from_logits_original, logits)
|
||||
t_triton = profile_time(entropy_from_logits, logits)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(entropy_from_logits_original, logits)
|
||||
m_triton = profile_memory(entropy_from_logits, logits)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_noncontiguous():
|
||||
print("\n" + "=" * 60)
|
||||
print(
|
||||
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(4, 2048, "transpose"),
|
||||
(4, 8192, "transpose"),
|
||||
(8, 2048, "transpose"),
|
||||
(4, 4096, "slice_batch"),
|
||||
]
|
||||
|
||||
for B, L, method in configs:
|
||||
torch.manual_seed(42)
|
||||
|
||||
if method == "transpose":
|
||||
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw.transpose(0, 1)
|
||||
raw_gb = L * B * V * 2 / 1e9
|
||||
elif method == "slice_batch":
|
||||
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw[::2]
|
||||
raw_gb = B * 2 * L * V * 2 / 1e9
|
||||
else:
|
||||
continue
|
||||
|
||||
if raw_gb > 28:
|
||||
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
|
||||
del raw, logits_nc
|
||||
torch.cuda.empty_cache()
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
def original_with_copy(logits, chunk_size=128):
|
||||
return entropy_from_logits_original(
|
||||
logits.contiguous(), chunk_size=chunk_size
|
||||
)
|
||||
|
||||
t_orig = profile_time(original_with_copy, logits_nc)
|
||||
t_triton = profile_time(entropy_from_logits, logits_nc)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" orig+copy: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton-strided:{fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(original_with_copy, logits_nc)
|
||||
m_triton = profile_memory(entropy_from_logits, logits_nc)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" orig+copy: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton-strided:{fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del raw, logits_nc
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_contiguous()
|
||||
benchmark_noncontiguous()
|
||||
benchmarks/bench_scattermoe_lora.py (new file, 284 lines)

@@ -0,0 +1,284 @@
|
||||
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||
and full fwd+bwd autograd throughput.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
WARMUP = 5
|
||||
ITERS = 20
|
||||
|
||||
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||
|
||||
BUILTIN_CONFIGS = {
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_config(spec):
|
||||
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||
key = spec.lower().replace("/", "-")
|
||||
for name, cfg in BUILTIN_CONFIGS.items():
|
||||
if key in name.lower() or name.lower() in key:
|
||||
return name, cfg
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||
tc = hf_cfg.get_text_config()
|
||||
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||
hf_cfg = tc
|
||||
hidden = hf_cfg.hidden_size
|
||||
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
experts = (
|
||||
getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None)
|
||||
)
|
||||
top_k = (
|
||||
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||
or 2
|
||||
)
|
||||
name = spec.split("/")[-1]
|
||||
return name, (experts, hidden, inter, top_k)
|
||||
|
||||
|
||||
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _clean():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - t0) * 1000)
|
||||
times.sort()
|
||||
return times[len(times) // 2]
|
||||
|
||||
|
||||
def _setup(num_experts, K, N, T, top_k, R):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||
|
||||
|
||||
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||
|
||||
|
||||
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _call_base(x, W, sei, ssi, top_k):
|
||||
return base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
|
||||
|
||||
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||
return lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=num_experts,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
"-m",
|
||||
nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)",
|
||||
)
|
||||
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||
args = parser.parse_args()
|
||||
|
||||
T = args.seq_len
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"T={T}, ranks={args.ranks}\n")
|
||||
|
||||
if args.models:
|
||||
configs = [_resolve_config(m) for m in args.models]
|
||||
else:
|
||||
configs = list(BUILTIN_CONFIGS.items())
|
||||
|
||||
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||
print(f"{'=' * 70}")
|
||||
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
for R in args.ranks:
|
||||
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||
num_experts, K, N, T, top_k, R
|
||||
)
|
||||
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = (
|
||||
"split"
|
||||
if (
|
||||
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
)
|
||||
else "fused"
|
||||
)
|
||||
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||
|
||||
total = t_fwd + t_dx + t_bwd
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(
|
||||
f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead * 100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms"
|
||||
)
|
||||
|
||||
# Full autograd fwd+bwd with memory measurement
|
||||
x_ag = x.clone().requires_grad_(True)
|
||||
lA_ag = lA.clone().requires_grad_(True)
|
||||
lB_ag = lB.clone().requires_grad_(True)
|
||||
|
||||
def _run_autograd(
|
||||
_x=x_ag,
|
||||
_W=W,
|
||||
_k=top_k,
|
||||
_sei=sei,
|
||||
_ssi=ssi,
|
||||
_eo=eo,
|
||||
_lA=lA_ag,
|
||||
_lB=lB_ag,
|
||||
):
|
||||
out = ScatterMoELoRA.apply(
|
||||
_x,
|
||||
_W,
|
||||
_k,
|
||||
_sei,
|
||||
_ssi,
|
||||
_eo,
|
||||
_lA,
|
||||
_lB,
|
||||
2.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
_x.grad = None
|
||||
_lA.grad = None
|
||||
_lB.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(
|
||||
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
benchmarks/bench_selective_logsoftmax.py (new file, 191 lines)

@@ -0,0 +1,191 @@
|
||||
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import (
|
||||
selective_log_softmax,
|
||||
selective_log_softmax_original,
|
||||
)
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, args, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
fn(*args)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, args, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(*args)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_forward():
|
||||
print("=" * 60)
|
||||
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(selective_log_softmax_original, (logits, index))
|
||||
t_triton = profile_time(selective_log_softmax, (logits, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
|
||||
m_triton = profile_memory(selective_log_softmax, (logits, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_backward():
|
||||
print("\n" + "=" * 60)
|
||||
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
def fwd_bwd_original(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax_original(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
def fwd_bwd_triton(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 20:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits_orig = torch.randn(
|
||||
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
|
||||
)
|
||||
logits_tri = logits_orig.detach().clone().requires_grad_(True)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
|
||||
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" FWD+BWD TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
|
||||
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" FWD+BWD MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits_orig, logits_tri, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_forward()
|
||||
benchmark_backward()
|
||||
@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
 ENV HF_HOME="{{ HF_HOME }}"

 RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
+    apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm

 WORKDIR /workspace

@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
     fi

-RUN uv pip install packaging==26.0 setuptools==75.8.0
+RUN uv pip install packaging==26.0 setuptools==78.1.1
 RUN uv pip install torchvision
 RUN uv pip uninstall causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
 ENV AXOLOTL_DATASET_NUM_PROC="8"

 RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
+    apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm

 WORKDIR /workspace

@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
     fi

-RUN pip install packaging==26.0 setuptools==75.8.0 psutil
+RUN pip install packaging==26.0 setuptools==78.1.1 psutil
 RUN pip uninstall -y causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
     pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
cicd/cicd.sh (11 changes)

@@ -3,11 +3,12 @@ set -e

 python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

-# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
-hf download "NousResearch/Meta-Llama-3-8B"
-hf download "NousResearch/Meta-Llama-3-8B-Instruct"
-hf download "microsoft/Phi-4-reasoning"
-hf download "microsoft/Phi-3.5-mini-instruct"
+curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
+# hf download "NousResearch/Meta-Llama-3-8B"
+# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
+# hf download "microsoft/Phi-4-reasoning"
+# hf download "microsoft/Phi-3.5-mini-instruct"
+# hf download "microsoft/Phi-3-medium-128k-instruct"

 # Run unit tests with initial coverage report
 pytest -v --durations=10 -n8 \
@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
     sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"

     # Propagate errors from subprocess.
-    try:
-        exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env)  # nosec
-        if exit_code:
-            print(f"Command '{cmd}' failed with exit code {exit_code}")
-            return exit_code
-    except Exception as e:  # pylint: disable=broad-except
-        print(f"Command '{cmd}' failed with exception {e}")
+    exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env)  # nosec
+    if exit_code:
+        raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
@@ -13,9 +13,10 @@ sdp_attention: true

 For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)

-## Flash Attention 2
+## Flash Attention

-Uses efficient kernels to compute attention.
+Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
+based on your installed packages and GPU.

 ```yaml
 flash_attention: true
@@ -23,11 +24,9 @@ flash_attention: true

 For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)

-### Nvidia
+### Flash Attention 2

-Requirements: Ampere, Ada, or Hopper GPUs
-
-Note: For Turing GPUs or lower, please use other attention methods.
+Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)

 ```bash
 pip install flash-attn --no-build-isolation
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation

 ::: {.callout-tip}

-If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
+If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
+Alternatively, try reinstall or downgrade a version.

 :::

-#### Flash Attention 3
+### Flash Attention 3

 Requirements: Hopper only and CUDA 12.8 (recommended)

@@ -50,6 +50,44 @@ cd flash-attention/hopper
 python setup.py install
 ```

+### Flash Attention 4
+
+Requirements: Hopper or Blackwell GPUs
+
+```bash
+pip install flash-attn-4
+```
+
+Or from source:
+
+```bash
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention/flash_attn/cute
+
+pip install -e .
+
+# FA2's flash_attn package includes a cute/ stub that shadows FA4.
+# Remove it so Python can find the real FA4 module:
+rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
+```
+
+::: {.callout-note}
+
+**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
+for training on Hopper, install from source using the instructions above.
+
+:::
+
+::: {.callout-warning}
+
+FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
+also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
+and falls back to FA2/3.
+
+:::
+
+For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
+
 ### AMD

 Requirements: ROCm 6.0 and above.
@@ -13,12 +13,14 @@ format:
 - [Pixtral](#sec-pixtral)
 - [Llava-1.5](#sec-llava-15)
 - [Mistral-Small-3.1](#sec-mistral-small-31)
+- [Mistral-Small-4](#sec-mistral-small-4)
 - [Magistral-Small-2509](#sec-magistral-small-2509)
 - [Voxtral](#sec-voxtral)
 - [Gemma-3](#sec-gemma-3)
 - [Gemma-3n](#sec-gemma-3n)
 - [Qwen2-VL](#sec-qwen2-vl)
 - [Qwen2.5-VL](#sec-qwen25-vl)
+- [Qwen3.5](#sec-qwen3-5)
 - [GLM-4.6V](#sec-glm-4-6v)
 - [SmolVLM2](#sec-smolvlm2)
 - [LFM2-VL](#sec-lfm2-vl)
@@ -108,6 +110,12 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
 base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
 ```

+### Mistral-Small-4 {#sec-mistral-small-4}
+
+```yaml
+base_model: mistralai/Mistral-Small-4-119B-2603
+```
+
 ### Magistral-Small-2509 {#sec-magistral-small-2509}

 ::: {.callout-tip}
@@ -184,6 +192,14 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
 chat_template: qwen2_vl # same as qwen2-vl
 ```

+### Qwen3.5 {#sec-qwen3-5}
+
+```yaml
+base_model: Qwen/Qwen3.5-9B
+
+chat_template: qwen3_5
+```
+
 ### GLM-4.6V {#sec-glm-4-6v}

 Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
docs/rlhf.qmd (207 changes)

@@ -721,6 +721,213 @@ trl:

For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).

#### Async GRPO

Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.

```yaml
trl:
  use_data_producer: true       # Enable data producer protocol
  use_vllm: true
  async_prefetch: true          # Generate rollouts in background thread
  prefetch_depth: 1             # Number of rollouts to prefetch
  vllm_sync_interval: 2         # Sync weights to vLLM every N steps
```
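Conceptually, this is a small producer/consumer pattern: a background thread keeps a bounded queue of rollouts filled while the main thread trains. The sketch below is purely illustrative (it is not Axolotl's implementation; `generate_rollouts`, `train_step`, and `sync_weights_to_vllm` are stand-in stubs):

```python
# Illustrative sketch of async rollout prefetching (not Axolotl's actual code).
import queue
import threading

def generate_rollouts() -> list:      # stand-in for vLLM generation
    return ["rollout"]

def train_step(batch: list) -> None:  # stand-in for the optimizer step
    pass

def sync_weights_to_vllm() -> None:   # stand-in for the weight / LoRA adapter sync
    pass

prefetch_depth, sync_interval, num_steps = 1, 2, 10
rollouts: queue.Queue = queue.Queue(maxsize=prefetch_depth)
stop = threading.Event()

def producer() -> None:
    while not stop.is_set():
        rollouts.put(generate_rollouts())  # runs while the main thread trains

threading.Thread(target=producer, daemon=True).start()
for step in range(num_steps):
    train_step(rollouts.get())             # next batch is already waiting
    if (step + 1) % sync_interval == 0:
        sync_weights_to_vllm()             # corresponds to vllm_sync_interval
stop.set()
```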
::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::

##### vLLM LoRA Sync

By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true  # Enable native LoRA sync
```

When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```

Then start training on a separate GPU:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::

##### Streaming Partial Batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

```yaml
trl:
  streaming_partial_batch: true
```

##### Importance Sampling Correction

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.

```yaml
trl:
  vllm_importance_sampling_correction: true  # Enable IS correction
  importance_sampling_level: token           # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5             # Mask sequences with IS ratio below this
```

- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
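As a rough sketch of what token-level IS correction computes, consider the snippet below. It is illustrative only; the tensor names and the exact masking rule are assumptions, not Axolotl's implementation:

```python
# Illustrative token-level importance sampling correction (not Axolotl's exact code).
import torch

def is_corrected_loss(
    logps_train: torch.Tensor,     # log-probs under current training weights, [B, T]
    logps_behavior: torch.Tensor,  # log-probs recorded at vLLM generation time, [B, T]
    per_token_loss: torch.Tensor,  # policy-gradient loss per token, [B, T]
    mask: torch.Tensor,            # completion mask, [B, T]
    threshold: float = 0.5,        # plays the role of off_policy_mask_threshold
) -> torch.Tensor:
    # Per-token IS ratio between the current policy and the slightly stale behavior policy.
    ratio = torch.exp(logps_train - logps_behavior)
    # Drop sequences whose average ratio says they are too far off-policy.
    seq_ratio = (ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    keep = (seq_ratio > threshold).unsqueeze(-1)
    weights = (ratio * keep).detach()  # correction weights carry no gradient
    return (weights * per_token_loss * mask).sum() / mask.sum().clamp(min=1)
```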
||||
##### Replay Buffer
|
||||
|
||||
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
replay_buffer_size: 100 # Max cached groups (0 = disabled)
|
||||
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
|
||||
:::
|
||||
|
||||
##### Deferred Re-rolling
|
||||
|
||||
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
|
||||
reroll_max_groups: 1 # Max groups to replace per batch
|
||||
```
|
||||
|
||||
##### Zero-Advantage Batch Skipping
|
||||
|
||||
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
skip_zero_advantage_batches: true # default
|
||||
```
|
||||
|
||||
##### Parallel Reward Workers
|
||||
|
||||
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
|
||||
```
|
||||
|
||||
##### Full Async GRPO Example
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
gpu_memory_utilization: 0.35
|
||||
dtype: auto
|
||||
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
|
||||
rl: grpo
|
||||
trl:
|
||||
use_data_producer: true
|
||||
use_vllm: true
|
||||
async_prefetch: true
|
||||
prefetch_depth: 1
|
||||
vllm_sync_interval: 2
|
||||
vllm_lora_sync: true
|
||||
streaming_partial_batch: true
|
||||
vllm_importance_sampling_correction: true
|
||||
off_policy_mask_threshold: 0.5
|
||||
importance_sampling_level: token
|
||||
num_generations: 8
|
||||
max_completion_length: 512
|
||||
reward_funcs:
|
||||
- rewards.accuracy_reward
|
||||
reroll_start_fraction: 0.5
|
||||
replay_buffer_size: 100
|
||||
reward_num_workers: 4
|
||||
skip_zero_advantage_batches: true
|
||||
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-TIR
|
||||
type: rewards.prompt_transform
|
||||
split: train
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
max_steps: 500
|
||||
learning_rate: 1e-5
|
||||
bf16: true
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM on GPU 0
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
|
||||
# Terminal 2: Train on GPU 1
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
|
||||
```

##### Multi-GPU Async GRPO

Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.

**FSDP:**

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

**DeepSpeed ZeRO-3:**

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true # Required for ZeRO-3
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```

::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
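
A sketch of that broadcast pattern, using `torch.distributed.broadcast_object_list` (illustrative; it assumes an initialized process group and is not the trainer's actual code):

```python
import torch.distributed as dist

def get_batch_for_all_ranks(produce_fn, rank: int):
    # Only rank 0 talks to vLLM / runs the background producer; every other
    # rank waits for the result, so collectives only run on the main thread.
    payload = [produce_fn()] if rank == 0 else [None]
    dist.broadcast_object_list(payload, src=0)  # main-thread collective
    return payload[0]
```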

### GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
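
A toy sketch of the idea (assumed tensor shapes, not the GDPO implementation): each reward column is normalized within its prompt group before summation, so a high-variance reward cannot drown out a low-variance one.

```python
import torch

def decoupled_advantages(rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
    """rewards: (num_prompts * num_generations, num_reward_funcs)."""
    grouped = rewards.view(-1, num_generations, rewards.size(-1))
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    # Normalize each reward function independently within its prompt group,
    # then combine; summing raw rewards first would let one scale dominate.
    normalized = (grouped - mean) / (std + 1e-4)
    return normalized.sum(dim=-1).view(-1)
```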
@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
]
},
{
@@ -1,8 +1,5 @@
base_model: google/gemma-3-1b-it

model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out

# Freeze vision tower
unfrozen_parameters:
  - ^model\.language_model\..*
  - ^lm_head\..*

adapter: qlora
lora_r: 32
lora_alpha: 16

@@ -1,8 +1,5 @@
base_model: google/gemma-3-270m-it

model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out

# Freeze vision tower
unfrozen_parameters:
  - ^model\.language_model\..*
  - ^lm_head\..*

adapter: qlora
lora_r: 32
lora_alpha: 16

@@ -1,9 +1,5 @@
base_model: google/gemma-3-4b-it

# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig

load_in_4bit: true

# gemma3 doesn't seem to play nice with ddp

@@ -24,6 +20,11 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

# Freeze vision tower
unfrozen_parameters:
  - ^model\.language_model\..*
  - ^lm_head\..*

adapter: qlora
lora_model_dir:
examples/mistral4/README.md (new file, 85 lines)
@@ -0,0 +1,85 @@
# Finetune Mistral Small 4 with Axolotl

Mistral Small 4 is a 119B-parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities in a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).

Thanks to the team at MistralAI for giving us early access to prepare for this release.

## Getting started

Note: Training this model requires the weights in BF16, which we will link to later. In the meantime, users interested in training can convert / descale the existing FP8 weights.

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

3. Install transformers from main:

```bash
pip install git+https://github.com/huggingface/transformers.git
```

4. Run one of the example configs:

```bash
# text-only
axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB
axolotl train examples/mistral4/fft-text.yml

# text + vision
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB
axolotl train examples/mistral4/fft-vision.yml
```

Note: The FFT configs are provided as a reference. Please adjust hyperparameters as needed.
## Reasoning Effort

The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:

- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps

Pass it via `chat_template_kwargs` under your dataset config:

```yaml
datasets:
  - path: your/dataset
    type: chat_template
    chat_template_kwargs:
      reasoning_effort: high
```
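
For intuition, the kwarg is forwarded to the chat template at render time. A rough equivalent with a Hugging Face tokenizer (a sketch under that assumption, not taken from the axolotl code):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-4-119B-2603")
messages = [{"role": "user", "content": "Prove that 17 is prime."}]

# reasoning_effort is consumed as a template variable by the model's chat
# template; "high" renders the reasoning-mode preamble, "none" the instruct one.
text = tok.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    reasoning_effort="high",
)
print(text)
```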

## Thinking Support

The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).

To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:

```yaml
datasets:
  - path: your/thinking-dataset
    type: chat_template
    message_property_mappings:
      role: role
      content: content
      thinking: thinking
    chat_template_kwargs:
      reasoning_effort: high
```

See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
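
As a concrete illustration, a row of such a dataset might look like the following (field names are assumed to match the mappings above; check the Magistral guide for the exact schema):

```python
example_row = {
    "messages": [
        {"role": "user", "content": "What is 12 * 13?"},
        {
            "role": "assistant",
            "content": "12 * 13 = 156.",
            # rendered into a [THINK]...[/THINK] block by the chat template
            "thinking": "12 * 13 = 12 * 10 + 12 * 3 = 120 + 36 = 156.",
        },
    ]
}
```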

## Tips

- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires the multimodal dataset format documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format); a sketch of a row follows below.
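
A minimal multimodal row in that format might look like this (illustrative only; see the multimodal docs linked above for the exact schema):

```python
example_row = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "African_elephant.jpg"},
                {"type": "text", "text": "What animal is in this picture?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "An African elephant."}],
        },
    ]
}
```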
## Related Resources

- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
examples/mistral4/fft-text.yml (new file, 58 lines)
@@ -0,0 +1,58 @@
base_model: mistralai/Mistral-Small-4-119B-2603

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true

# only train language model layers, freeze vision tower
unfrozen_parameters:
  - model.language_model.*
  - lm_head
  - embed_tokens

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: false
  state_dict_type: FULL_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Mistral4DecoderLayer
  reshard_after_forward: true
  activation_checkpointing: true
examples/mistral4/fft-vision.yml (new file, 57 lines)
@@ -0,0 +1,57 @@
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true

# vision requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

datasets:
  - path: Nanobit/text-vision-2k-test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

sequence_len: 2048

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: false
  state_dict_type: FULL_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Mistral4DecoderLayer
  reshard_after_forward: true
  activation_checkpointing: true
examples/mistral4/qlora-text.yml (new file, 58 lines)
@@ -0,0 +1,58 @@
base_model: mistralai/Mistral-Small-4-119B-2603

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

adapter: qlora

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

# uncomment to train on expert layers
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
tf32: true

gradient_checkpointing: true
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
examples/mistral4/qlora-vision.yml (new file, 63 lines)
@@ -0,0 +1,63 @@
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true
quantize_moe_experts: true

# vision chat template requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

datasets:
  - path: Nanobit/text-vision-2k-test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

adapter: qlora

sequence_len: 2048

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

# uncomment to train on expert layers
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
tf32: true

gradient_checkpointing: true
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
examples/nemotron/nemotron-mini-4b-qlora.yaml (new file, 57 lines)
@@ -0,0 +1,57 @@
base_model: nvidia/Nemotron-Mini-4B-Instruct

load_in_8bit: false
load_in_4bit: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/nemotron-mini-4b-qlora

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - up_proj
  - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

special_tokens:
examples/qwen3.5/27b-fft.yaml (new file, 59 lines)
@@ -0,0 +1,59 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false

chat_template: qwen3_5
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 2048
sample_packing: true

# Freeze vision encoder
unfrozen_parameters:
  - model\.language_model\..*
  - lm_head\..*

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
examples/qwen3.5/9b-fft-vision.yaml (new file, 49 lines)
@@ -0,0 +1,49 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor

# Required for multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: qwen3_5
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
@@ -1,10 +1,6 @@
base_model: Qwen/Qwen3.5-7B
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor

# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.

# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
@@ -1,15 +1,20 @@
# Finetune Qwen3.5 with Axolotl

[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.

Vision and text tokens are processed through the same transformer stack. The configs below train on text-only data unless noted otherwise. See `9b-lora-vision.yaml` for a multimodal example.

Available configs:

| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
| Config | Model | Type | Peak VRAM |
|---|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only QLoRA | ~47 GiB |
| `27b-fft.yaml` | Qwen3.5-27B | Dense VLM, text-only FFT (vision frozen) | ~53 GiB |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |

## Getting started

@@ -29,23 +34,31 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml

# Dense 27B text-only FFT with vision encoder frozen (~53 GiB, single 80 GiB GPU)
axolotl train examples/qwen3.5/27b-fft.yaml

# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml

# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml

# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
# 9B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/9b-lora-vision.yaml

# 9B vision+text FFT, single 80 GiB GPU (~61 GiB peak)
axolotl train examples/qwen3.5/9b-fft-vision.yaml

```

### TIPS

- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- For **text-only FFT** on 27B, use `27b-fft.yaml` which sets `unfrozen_parameters` to freeze the vision encoder (`model.visual.*`) — this avoids wasting optimizer state on parameters that receive no gradient from text-only data.
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.

## Optimization Guides
@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11

mistral-common==1.8.8
mistral-common==1.10.0
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
|
||||
)
|
||||
|
||||
@@ -90,9 +90,8 @@ class ModalCloud(Cloud):
|
||||
# grab the sha256 hash from docker hub for this image+tag
|
||||
# this ensures that we always get the latest image for this tag, even if it's already cached
|
||||
try:
|
||||
manifest = subprocess.check_output( # nosec B602
|
||||
f"docker manifest inspect {docker_image}",
|
||||
shell=True,
|
||||
manifest = subprocess.check_output( # nosec
|
||||
["docker", "manifest", "inspect", docker_image],
|
||||
).decode("utf-8")
|
||||
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
|
||||
except subprocess.CalledProcessError:
|
||||
|
||||
@@ -11,7 +11,7 @@ from urllib.parse import urlparse
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
@@ -300,7 +300,7 @@ def load_cfg(
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
except:
|
||||
except (RuntimeError, AssertionError):
|
||||
gpu_version = None
|
||||
|
||||
prepare_plugins(cfg)
|
||||
@@ -310,6 +310,7 @@ def load_cfg(
|
||||
capabilities={
|
||||
"bf16": is_torch_bf16_gpu_available(),
|
||||
"fp8": compute_supports_fp8(),
|
||||
"tf32": is_torch_tf32_available(),
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
|
||||
@@ -196,12 +196,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
state.wait_for_everyone()
|
||||
LOG.info(
|
||||
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
|
||||
main_process_only=True,
|
||||
)
|
||||
LOG.info(
|
||||
"Merged weights are only the safetensors and doesn't include the model configuration "
|
||||
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -38,7 +38,18 @@ def do_vllm_serve(
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
|
||||
serve_module = cli_args.get("serve_module") or getattr(
|
||||
cfg.vllm, "serve_module", None
|
||||
)
|
||||
if (
|
||||
serve_module is None
|
||||
and getattr(cfg, "trl", None)
|
||||
and getattr(cfg.trl, "vllm_lora_sync", False)
|
||||
):
|
||||
serve_module = "axolotl.scripts.vllm_serve_lora"
|
||||
if serve_module is None:
|
||||
serve_module = "trl.scripts.vllm_serve"
|
||||
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
|
||||
tensor_parallel_size = 1
|
||||
data_parallel_size = 1
|
||||
@@ -68,7 +79,7 @@ def do_vllm_serve(
|
||||
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
|
||||
)
|
||||
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
base_kwargs = dict(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
data_parallel_size=data_parallel_size,
|
||||
@@ -78,7 +89,21 @@ def do_vllm_serve(
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
reasoning_parser=reasoning_parser,
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
|
||||
# Use LoRAScriptArguments when serving with native LoRA support
|
||||
if serve_module == "axolotl.scripts.vllm_serve_lora":
|
||||
from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments
|
||||
|
||||
lora_kwargs = {}
|
||||
if hasattr(cfg, "lora_r") and cfg.lora_r:
|
||||
lora_kwargs["max_lora_rank"] = cfg.lora_r
|
||||
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
|
||||
else:
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
**base_kwargs,
|
||||
reasoning_parser=reasoning_parser,
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
|
||||
vllm_serve_main(vllm_script_args)
|
||||
|
||||
@@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
"mistral4": "Mistral4MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
"afmoe": "AfmoeMoE",
|
||||
|
||||
@@ -67,7 +67,7 @@ class JsonToJsonlConverter:
|
||||
self.json_parser = json_parser
|
||||
self.jsonl_serializer = jsonl_serializer
|
||||
|
||||
def convert(self, input_file_path, output_file_path):
|
||||
def convert(self, input_file_path):
|
||||
content = self.file_reader.read(input_file_path)
|
||||
data = self.json_parser.parse(content)
|
||||
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
||||
|
||||
@@ -250,7 +250,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def _configure_precision_settings(self, training_args_kwargs: dict):
|
||||
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
||||
training_args_kwargs["tf32"] = self.cfg.tf32
|
||||
training_args_kwargs["tf32"] = True if self.cfg.tf32 is True else False
|
||||
if self.cfg.bf16 == "full":
|
||||
training_args_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
|
||||
@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
|
||||
# config reflects this regardless of how the model was instantiated.
|
||||
if (
|
||||
self.cfg.reward_model
|
||||
and getattr(self.model.config, "num_labels", None) != 1
|
||||
):
|
||||
self.model.config.num_labels = 1
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
|
||||
@@ -54,8 +54,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
async_grpo = bool(
|
||||
self.cfg.trl
|
||||
and (
|
||||
getattr(self.cfg.trl, "async_prefetch", False)
|
||||
or getattr(self.cfg.trl, "use_data_producer", False)
|
||||
)
|
||||
)
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1,
|
||||
async_grpo=async_grpo,
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
@@ -151,7 +159,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
async_grpo = bool(
|
||||
self.cfg.trl
|
||||
and (
|
||||
getattr(self.cfg.trl, "async_prefetch", False)
|
||||
or getattr(self.cfg.trl, "use_data_producer", False)
|
||||
)
|
||||
)
|
||||
training_args_cls = GRPOStrategy.get_training_args_class(
|
||||
async_grpo=async_grpo
|
||||
)
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
if self.cfg.rl is RLType.GDPO:
|
||||
@@ -217,13 +234,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
|
||||
trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
# Allow FP8-quantized models to be fine-tuned with LoRA adapters.
|
||||
# transformers' validate_quantization_for_training blocks FP8 because
|
||||
# hf_quantizer.is_trainable is False, but LoRA only trains the adapters
|
||||
# (base weights stay frozen in FP8).
|
||||
_orig_validate_quant = None
|
||||
if (
|
||||
self.cfg.adapter
|
||||
and hasattr(self.model, "is_quantized")
|
||||
and self.model.is_quantized
|
||||
):
|
||||
import transformers.trainer as _trainer_module
|
||||
|
||||
_orig_validate_quant = _trainer_module.validate_quantization_for_training
|
||||
_trainer_module.validate_quantization_for_training = lambda model: None
|
||||
|
||||
try:
|
||||
trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
finally:
|
||||
if _orig_validate_quant is not None:
|
||||
import transformers.trainer as _trainer_module
|
||||
|
||||
_trainer_module.validate_quantization_for_training = (
|
||||
_orig_validate_quant
|
||||
)
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||
|
||||
@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
|
||||
from requests import HTTPError
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlAsyncGRPOTrainer,
|
||||
AxolotlGRPOSequenceParallelTrainer,
|
||||
AxolotlGRPOTrainer,
|
||||
)
|
||||
@@ -27,14 +28,31 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(
|
||||
cls, sequence_parallel: bool
|
||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
|
||||
cls,
|
||||
sequence_parallel: bool,
|
||||
async_grpo: bool = False,
|
||||
) -> (
|
||||
type[AxolotlGRPOTrainer]
|
||||
| type[AxolotlGRPOSequenceParallelTrainer]
|
||||
| type[AxolotlAsyncGRPOTrainer]
|
||||
):
|
||||
if sequence_parallel and async_grpo:
|
||||
raise ValueError(
|
||||
"sequence_parallel and async_grpo cannot both be enabled. "
|
||||
"Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer."
|
||||
)
|
||||
if sequence_parallel:
|
||||
return AxolotlGRPOSequenceParallelTrainer
|
||||
if async_grpo:
|
||||
return AxolotlAsyncGRPOTrainer
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
|
||||
def get_training_args_class(
|
||||
cls, async_grpo: bool = False
|
||||
) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
|
||||
if async_grpo:
|
||||
return AxolotlAsyncGRPOConfig
|
||||
return AxolotlGRPOConfig
|
||||
|
||||
@classmethod
|
||||
@@ -124,13 +142,63 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
|
||||
|
||||
if trl.use_liger_loss is not None:
|
||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
|
||||
|
||||
if trl.multi_objective_aggregation is not None:
|
||||
grpo_args_kwargs["multi_objective_aggregation"] = (
|
||||
trl.multi_objective_aggregation
|
||||
)
|
||||
|
||||
# Async GRPO fields
|
||||
if getattr(trl, "use_data_producer", None) is not None:
|
||||
grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
|
||||
if getattr(trl, "async_prefetch", None) is not None:
|
||||
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
|
||||
if getattr(trl, "prefetch_depth", None) is not None:
|
||||
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
|
||||
if getattr(trl, "vllm_sync_interval", None) is not None:
|
||||
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
|
||||
if getattr(trl, "streaming_partial_batch", None) is not None:
|
||||
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
|
||||
if getattr(trl, "streaming_min_groups", None) is not None:
|
||||
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
|
||||
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_correction"] = (
|
||||
trl.vllm_importance_sampling_correction
|
||||
)
|
||||
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_mode"] = (
|
||||
trl.vllm_importance_sampling_mode
|
||||
)
|
||||
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_cap"] = (
|
||||
trl.vllm_importance_sampling_cap
|
||||
)
|
||||
if getattr(trl, "off_policy_mask_threshold", None) is not None:
|
||||
grpo_args_kwargs["off_policy_mask_threshold"] = (
|
||||
trl.off_policy_mask_threshold
|
||||
)
|
||||
if getattr(trl, "use_bias_correction_kl", None) is not None:
|
||||
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
|
||||
|
||||
# Fast Async GRPO fields
|
||||
if getattr(trl, "reward_num_workers", None) is not None:
|
||||
grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers
|
||||
if getattr(trl, "replay_buffer_size", None) is not None:
|
||||
grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size
|
||||
if getattr(trl, "replay_recompute_logps", None) is not None:
|
||||
grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps
|
||||
if getattr(trl, "reroll_start_fraction", None) is not None:
|
||||
grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction
|
||||
if getattr(trl, "reroll_max_groups", None) is not None:
|
||||
grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups
|
||||
if getattr(trl, "skip_zero_advantage_batches", None) is not None:
|
||||
grpo_args_kwargs["skip_zero_advantage_batches"] = (
|
||||
trl.skip_zero_advantage_batches
|
||||
)
|
||||
if getattr(trl, "vllm_lora_sync", None) is not None:
|
||||
grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
||||
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):
|
||||
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
|
||||
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
2657
src/axolotl/core/trainers/grpo/async_trainer.py
Normal file
2657
src/axolotl/core/trainers/grpo/async_trainer.py
Normal file
File diff suppressed because it is too large
Load Diff
768
src/axolotl/core/trainers/grpo/fast_async_trainer.py
Normal file
768
src/axolotl/core/trainers/grpo/fast_async_trainer.py
Normal file
@@ -0,0 +1,768 @@
|
||||
# Copyright 2020-2026 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Experimental GRPO extensions: parallel reward workers, replay buffer,
|
||||
deferred re-roll, and zero-advantage skipping.
|
||||
|
||||
These features are built as subclasses of GRPOTrainer and GRPODataProducer,
|
||||
using the hook system (_compute_rewards_for_batch, _post_advantage_hook,
|
||||
_pre_produce_hook) defined in the base classes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from trl import GRPOTrainer
|
||||
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncGRPOConfig,
|
||||
AsyncGRPOTrainer,
|
||||
GRPODataProducer,
|
||||
)
|
||||
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class FastAsyncGRPOConfig(AsyncGRPOConfig):
|
||||
"""GRPOConfig with additional experimental parameters."""
|
||||
|
||||
reward_num_workers: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
|
||||
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
|
||||
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
|
||||
},
|
||||
)
|
||||
replay_buffer_size: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
|
||||
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
|
||||
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
|
||||
},
|
||||
)
|
||||
replay_recompute_logps: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "When True (default), recompute old_per_token_logps for replayed groups using the current "
|
||||
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
|
||||
"Only relevant when replay_buffer_size > 0."
|
||||
},
|
||||
)
|
||||
reroll_start_fraction: float = field(
|
||||
default=0.5,
|
||||
metadata={
|
||||
"help": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
|
||||
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
|
||||
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
|
||||
},
|
||||
)
|
||||
reroll_max_groups: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
|
||||
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
|
||||
},
|
||||
)
|
||||
skip_zero_advantage_batches: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
|
||||
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
|
||||
"logged with skipped_zero_adv_batches=1 for monitoring."
|
||||
},
|
||||
)
|
||||
vllm_lora_sync: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model "
|
||||
"and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. "
|
||||
"Requires vllm_serve_lora serve module (auto-selected when this is True). "
|
||||
"Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended data producer with re-roll injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RerollDataProducer(GRPODataProducer):
|
||||
"""GRPODataProducer that injects re-roll candidates into prompt batches.
|
||||
|
||||
Reads from the trainer's ``_reroll_buffer`` (populated by
|
||||
``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the
|
||||
last N prompt groups with previously-failed prompts.
|
||||
"""
|
||||
|
||||
def _pre_produce_hook(self, inputs: list, global_step: int) -> list:
|
||||
trainer = self._trainer
|
||||
reroll_buf = getattr(trainer, "_reroll_buffer", None)
|
||||
reroll_lock = getattr(trainer, "_reroll_lock", None)
|
||||
if reroll_buf is None or reroll_lock is None:
|
||||
return inputs
|
||||
|
||||
max_steps = getattr(trainer.args, "max_steps", -1)
|
||||
start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0)
|
||||
max_groups = getattr(trainer.args, "reroll_max_groups", 1)
|
||||
reroll_start_step = (
|
||||
max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf")
|
||||
)
|
||||
|
||||
if global_step < reroll_start_step:
|
||||
return inputs
|
||||
|
||||
with reroll_lock:
|
||||
n_to_take = min(max_groups, len(reroll_buf))
|
||||
reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)]
|
||||
|
||||
if reroll_prompts:
|
||||
num_gen = self._num_generations
|
||||
n_groups = len(inputs) // num_gen
|
||||
for i, reroll_prompt in enumerate(reroll_prompts):
|
||||
group_idx = n_groups - 1 - i
|
||||
if group_idx < 0:
|
||||
break
|
||||
start = group_idx * num_gen
|
||||
for j in range(num_gen):
|
||||
inputs[start + j] = reroll_prompt
|
||||
logger.info(
|
||||
f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups "
|
||||
f"with deferred re-roll candidates ({len(reroll_buf)} remaining)"
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistent reward subprocess pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _persistent_reward_worker(conn):
|
||||
"""Long-lived reward worker. Receives work items, returns results."""
|
||||
while True:
|
||||
try:
|
||||
msg = conn.recv()
|
||||
except EOFError:
|
||||
break
|
||||
if msg is None: # Shutdown signal
|
||||
break
|
||||
(
|
||||
reward_funcs,
|
||||
prompts,
|
||||
completions,
|
||||
completion_ids_list,
|
||||
inputs,
|
||||
reward_func_names,
|
||||
) = msg
|
||||
try:
|
||||
keys = [
|
||||
key
|
||||
for key in inputs[0]
|
||||
if key not in ["prompt", "completion", "completion_ids"]
|
||||
]
|
||||
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
|
||||
results = []
|
||||
for reward_func, _reward_func_name in zip(
|
||||
reward_funcs, reward_func_names, strict=True
|
||||
):
|
||||
output = reward_func(
|
||||
prompts=prompts,
|
||||
completions=completions,
|
||||
completion_ids=completion_ids_list,
|
||||
**reward_kwargs,
|
||||
)
|
||||
results.append(
|
||||
[float(r) if r is not None else float("nan") for r in output]
|
||||
)
|
||||
conn.send(results)
|
||||
except Exception:
|
||||
conn.send(None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended trainer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
|
||||
"""GRPOTrainer with experimental extensions.
|
||||
|
||||
Adds:
|
||||
- Parallel reward subprocess workers (``reward_num_workers``)
|
||||
- Replay buffer for high-signal group reuse (``replay_buffer_size``)
|
||||
- Deferred re-roll of failed prompts (``reroll_start_fraction``)
|
||||
- Zero-advantage micro-batch skipping
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# These must be initialized before super().__init__() because
|
||||
# _create_data_producer (called during super().__init__) needs them.
|
||||
self._reroll_buffer: list = []
|
||||
self._reroll_lock = threading.Lock()
|
||||
|
||||
# Temporarily suppress the base class's Liger + OPSM validation check,
|
||||
# since this subclass supports it via a custom compute_liger_loss override.
|
||||
grpo_args = kwargs.get("args")
|
||||
if grpo_args is None:
|
||||
for a in args:
|
||||
if hasattr(a, "off_policy_mask_threshold"):
|
||||
grpo_args = a
|
||||
break
|
||||
saved_threshold = None
|
||||
if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False):
|
||||
saved_threshold = grpo_args.off_policy_mask_threshold
|
||||
grpo_args.off_policy_mask_threshold = None
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if saved_threshold is not None:
|
||||
grpo_args.off_policy_mask_threshold = saved_threshold
|
||||
self.off_policy_mask_threshold = saved_threshold
|
||||
|
||||
# Replay buffer
|
||||
if getattr(self.args, "replay_buffer_size", 0) > 0:
|
||||
self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size)
|
||||
else:
|
||||
self._replay_buffer = None
|
||||
self._replay_recompute_logps = getattr(
|
||||
self.args, "replay_recompute_logps", True
|
||||
)
|
||||
|
||||
# Reward worker pool (lazy-initialized)
|
||||
self._reward_workers = None
|
||||
|
||||
# -- Factory override: use RerollDataProducer ----------------------------
|
||||
|
||||
def _create_data_producer(self, args, train_dataset):
|
||||
"""Override to use RerollDataProducer for re-roll prompt injection."""
|
||||
from axolotl.core.trainers.grpo.async_trainer import (
|
||||
AsyncDataProducer,
|
||||
ProducerConfig,
|
||||
)
|
||||
|
||||
producer_config = ProducerConfig(
|
||||
mini_epochs=args.num_iterations,
|
||||
max_rollouts=None,
|
||||
eval_during_produce=False,
|
||||
empty_cache_before_produce=True,
|
||||
empty_cache_after_produce=True,
|
||||
async_prefetch=args.async_prefetch,
|
||||
prefetch_depth=args.prefetch_depth,
|
||||
)
|
||||
data_producer = RerollDataProducer(
|
||||
config=producer_config,
|
||||
prompt_dataset=train_dataset,
|
||||
num_generations=self.num_generations,
|
||||
generation_batch_size=args.generation_batch_size,
|
||||
train_batch_size=args.per_device_train_batch_size,
|
||||
steps_per_generation=args.steps_per_generation,
|
||||
shuffle_dataset=self.shuffle_dataset,
|
||||
seed=args.seed,
|
||||
)
|
||||
data_producer.set_trainer(self)
|
||||
if args.async_prefetch:
|
||||
data_producer = AsyncDataProducer(
|
||||
data_producer,
|
||||
background_produce_kwargs={"skip_policy_logps": True},
|
||||
)
|
||||
return data_producer
|
||||
|
||||
# -- Reward worker pool --------------------------------------------------
|
||||
|
||||
def _get_reward_workers(self):
|
||||
"""Return a list of persistent reward worker subprocesses (lazy-initialized)."""
|
||||
import multiprocessing as _mp
|
||||
|
||||
num_workers = getattr(self.args, "reward_num_workers", 1)
|
||||
if num_workers < 1:
|
||||
num_workers = 1
|
||||
|
||||
if self._reward_workers is not None:
|
||||
alive = all(proc.is_alive() for conn, proc in self._reward_workers)
|
||||
if alive and len(self._reward_workers) == num_workers:
|
||||
return self._reward_workers
|
||||
self._shutdown_reward_workers()
|
||||
|
||||
workers = []
|
||||
for _ in range(num_workers):
|
||||
parent_conn, child_conn = _mp.Pipe()
|
||||
proc = _mp.Process(
|
||||
target=_persistent_reward_worker, args=(child_conn,), daemon=True
|
||||
)
|
||||
proc.start()
|
||||
child_conn.close()
|
||||
workers.append((parent_conn, proc))
|
||||
|
||||
self._reward_workers = workers
|
||||
return workers
|
||||
|
||||
def _shutdown_reward_workers(self):
|
||||
"""Shut down all persistent reward workers."""
|
||||
if self._reward_workers is None:
|
||||
return
|
||||
for conn, proc in self._reward_workers:
|
||||
try:
|
||||
conn.send(None)
|
||||
proc.join(timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._reward_workers = None
|
||||
|
||||
# -- Hook overrides ------------------------------------------------------
|
||||
|
||||
def _compute_rewards_for_batch(
|
||||
self, inputs, prompts, completions, completion_ids_list
|
||||
):
|
||||
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
|
||||
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
|
||||
return self._collect_reward_workers(
|
||||
inputs, prompts, completions, completion_ids_list
|
||||
)
|
||||
|
||||
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
|
||||
"""Send reward work to subprocess workers (non-blocking).
|
||||
|
||||
Results are collected later by _collect_reward_workers, allowing GPU
|
||||
logprob computation to overlap with CPU reward computation.
|
||||
"""
|
||||
reward_can_bg = all(
|
||||
callable(rf)
|
||||
and not isinstance(rf, nn.Module)
|
||||
and not asyncio.iscoroutinefunction(rf)
|
||||
for rf in self.reward_funcs
|
||||
)
|
||||
num_workers = getattr(self.args, "reward_num_workers", 1)
|
||||
|
||||
if not reward_can_bg or num_workers <= 1:
|
||||
# Can't parallelize — store args for sync fallback in collect
|
||||
self._reward_workers_used = None
|
||||
self._pending_reward_args = (
|
||||
inputs,
|
||||
prompts,
|
||||
completions,
|
||||
completion_ids_list,
|
||||
)
|
||||
return
|
||||
|
||||
workers = self._get_reward_workers()
|
||||
num_generations = self.num_generations
|
||||
num_prompts = len(prompts)
|
||||
num_groups = num_prompts // num_generations
|
||||
|
||||
# Shard by prompt groups across workers
|
||||
groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))
|
||||
workers_used = []
|
||||
for w_idx, (conn, _proc) in enumerate(workers):
|
||||
g_start = w_idx * groups_per_worker
|
||||
g_end = min((w_idx + 1) * groups_per_worker, num_groups)
|
||||
if g_start >= num_groups:
|
||||
break
|
||||
s_start = g_start * num_generations
|
||||
s_end = g_end * num_generations
|
||||
conn.send(
|
||||
(
|
||||
self.reward_funcs,
|
||||
prompts[s_start:s_end],
|
||||
completions[s_start:s_end],
|
||||
completion_ids_list[s_start:s_end],
|
||||
inputs[s_start:s_end],
|
||||
self.reward_func_names,
|
||||
)
|
||||
)
|
||||
workers_used.append(conn)
|
||||
|
||||
self._reward_workers_used = workers_used
|
||||
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
|
||||
|
||||
def _collect_reward_workers(
|
||||
self, inputs, prompts, completions, completion_ids_list
|
||||
):
|
||||
"""Collect reward results from subprocess workers (blocks until done)."""
|
||||
from accelerate.utils import gather
|
||||
|
||||
workers_used = getattr(self, "_reward_workers_used", None)
|
||||
args = getattr(self, "_pending_reward_args", None)
|
||||
self._reward_workers_used = None
|
||||
self._pending_reward_args = None
|
||||
|
||||
if workers_used is None:
|
||||
# Sync fallback — compute on main thread
|
||||
if args is not None:
|
||||
return self._calculate_rewards(*args)
|
||||
return self._calculate_rewards(
|
||||
inputs, prompts, completions, completion_ids_list
|
||||
)
|
||||
|
||||
device = self.accelerator.device
|
||||
num_prompts = len(args[1]) if args else len(prompts)
|
||||
|
||||
# Collect results from workers
|
||||
all_worker_results = []
|
||||
any_failed = False
|
||||
for conn in workers_used:
|
||||
result = conn.recv()
|
||||
if result is None:
|
||||
any_failed = True
|
||||
# Drain remaining workers to prevent stale results in pipes
|
||||
for remaining_conn in workers_used:
|
||||
if remaining_conn is not conn:
|
||||
try:
|
||||
remaining_conn.recv()
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
all_worker_results.append(result)
|
||||
|
||||
if not any_failed:
|
||||
rewards_per_func = torch.zeros(
|
||||
num_prompts, len(self.reward_funcs), device=device
|
||||
)
|
||||
offset = 0
|
||||
for worker_result in all_worker_results:
|
||||
chunk_size = len(worker_result[0])
|
||||
for i, result in enumerate(worker_result):
|
||||
rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(
|
||||
result, dtype=torch.float32, device=device
|
||||
)
|
||||
offset += chunk_size
|
||||
return gather(rewards_per_func)
|
||||
|
||||
# Fallback to main thread on failure
|
||||
if args is not None:
|
||||
return self._calculate_rewards(*args)
|
||||
return self._calculate_rewards(
|
||||
inputs, prompts, completions, completion_ids_list
|
||||
)
|
||||
|
||||
def _post_advantage_hook(
|
||||
self,
|
||||
data: dict,
|
||||
rewards_per_func,
|
||||
advantages,
|
||||
inputs: list,
|
||||
num_generations: int,
|
||||
mode: str,
|
||||
s_start: int | None = None,
|
||||
s_end: int | None = None,
|
||||
is_last_chunk: bool = True,
|
||||
) -> None:
|
||||
"""Replay buffer store/replace + re-roll buffering."""
|
||||
from trl.models.utils import disable_gradient_checkpointing
|
||||
|
||||
# -- Replay buffer: store high-signal groups --
|
||||
if self._replay_buffer is not None:
|
||||
local_grouped = rewards_per_func.view(
|
||||
-1, num_generations, len(self.reward_funcs)
|
||||
)
|
||||
per_group_std = local_grouped.std(dim=1)
|
||||
has_signal = (per_group_std > 0).any(dim=1)
|
||||
offset = s_start or 0
|
||||
|
||||
if has_signal.any():
|
||||
grouped_adv = advantages.view(-1, num_generations)
|
||||
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
|
||||
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
|
||||
gi = group_idx.item()
|
||||
start = offset + gi * num_generations
|
||||
end = start + num_generations
|
||||
group_data = {}
|
||||
for key in data:
|
||||
val = data[key]
|
||||
if (
|
||||
isinstance(val, torch.Tensor)
|
||||
and val.dim() > 0
|
||||
and val.size(0) >= end
|
||||
):
|
||||
group_data[key] = val[start:end].clone()
|
||||
self._replay_buffer.add(replay_scores[gi].item(), group_data)
|
||||
|
||||
# Replace zero-signal groups with high-signal replay buffer entries
|
||||
# Only in non-streaming path (s_start is None) — streaming scores
|
||||
# groups incrementally, so replacement + logprob recompute would be
|
||||
# too expensive per chunk.
|
||||
n_replaced = 0
|
||||
if s_start is None:
|
||||
no_signal = ~has_signal
|
||||
replaced_ranges = []
|
||||
if no_signal.any() and len(self._replay_buffer) > 0:
|
||||
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
|
||||
sampled = self._replay_buffer.sample(1)
|
||||
if sampled is None:
|
||||
break
|
||||
sampled_group = sampled[0]
|
||||
gi = group_idx.item()
|
||||
start = offset + gi * num_generations
|
||||
end = start + num_generations
|
||||
for key, val in sampled_group.items():
|
||||
if key in data and isinstance(data[key], torch.Tensor):
|
||||
src = val.to(data[key].device)
|
||||
tgt_seq_len = (
|
||||
data[key].size(1) if data[key].dim() > 1 else None
|
||||
)
|
||||
if start >= data[key].size(0) or end > data[key].size(
|
||||
0
|
||||
):
|
||||
continue
|
||||
if tgt_seq_len is not None:
|
||||
if src.size(1) <= tgt_seq_len:
|
||||
data[key][start:end] = 0
|
||||
data[key][start:end, : src.size(1)] = src
|
||||
else:
|
||||
data[key][start:end] = src[:, :tgt_seq_len]
|
||||
else:
|
||||
data[key][start:end] = src
|
||||
replaced_ranges.append((start, end))
|
||||
n_replaced += 1
|
||||
|
||||
# Recompute old_per_token_logps for replayed groups
|
||||
if (
|
||||
n_replaced > 0
|
||||
and self._replay_recompute_logps
|
||||
and "old_per_token_logps" in data
|
||||
):
|
||||
with (
|
||||
torch.no_grad(),
|
||||
disable_gradient_checkpointing(
|
||||
self.model, self.args.gradient_checkpointing_kwargs
|
||||
),
|
||||
):
|
||||
for r_start, r_end in replaced_ranges:
|
||||
r_ids = torch.cat(
|
||||
[
|
||||
data["prompt_ids"][r_start:r_end],
|
||||
data["completion_ids"][r_start:r_end],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
r_mask = torch.cat(
|
||||
[
|
||||
data["prompt_mask"][r_start:r_end],
|
||||
data["completion_mask"][r_start:r_end],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
r_logits_to_keep = data["completion_ids"].size(1)
|
||||
r_fwd_kwargs = {}
|
||||
for fk in (
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"pixel_attention_mask",
|
||||
"image_sizes",
|
||||
"token_type_ids",
|
||||
"mm_token_type_ids",
|
||||
):
|
||||
if fk in data:
|
||||
r_fwd_kwargs[fk] = data[fk]
|
||||
r_logps, _ = self._get_per_token_logps_and_entropies(
|
||||
self.model,
|
||||
r_ids,
|
||||
r_mask,
|
||||
r_logits_to_keep,
|
||||
r_end - r_start,
|
||||
**r_fwd_kwargs,
|
||||
)
|
||||
data["old_per_token_logps"][r_start:r_end] = r_logps
|
||||
|
||||
if n_replaced > 0:
|
||||
self._metrics[mode]["replay_buffer_replacements"].append(
|
||||
float(n_replaced)
|
||||
)
|
||||
|
||||
if is_last_chunk:
|
||||
self._metrics[mode]["replay_buffer_size"].append(
|
||||
float(len(self._replay_buffer))
|
||||
)
|
||||
|
||||
# -- Re-roll buffer: store failed prompts --
|
||||
if getattr(self.args, "reroll_start_fraction", 1.0) < 1.0:
|
||||
grouped_rewards = rewards_per_func.view(
|
||||
-1, num_generations, len(self.reward_funcs)
|
||||
)
|
||||
per_group_std = grouped_rewards.std(dim=1)
|
||||
per_group_mean = grouped_rewards.mean(dim=1)
|
||||
zero_signal = (per_group_std == 0).all(dim=1)
|
||||
all_failed = (per_group_mean.abs() < 1e-6).all(dim=1)
|
||||
should_reroll = zero_signal & all_failed
|
||||
_n_buffered = 0
|
||||
with self._reroll_lock:
|
||||
for group_idx in should_reroll.nonzero(as_tuple=True)[0]:
|
||||
idx = group_idx.item() * num_generations
|
||||
if idx >= len(inputs):
|
||||
continue
|
||||
prompt_input = inputs[idx]
|
||||
self._reroll_buffer.append(prompt_input)
|
||||
_n_buffered += 1
|
||||
if _n_buffered > 0:
|
||||
self._metrics[mode]["reroll_buffered"].append(float(_n_buffered))
|
||||
if is_last_chunk:
|
||||
self._metrics[mode]["reroll_buffer_size"].append(
|
||||
float(len(self._reroll_buffer))
|
||||
)
|
||||
|
||||
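For intuition on the replay-buffer scoring in `_post_advantage_hook` above: a group is stored only if at least one reward function varies across its generations, and its score is the sum of absolute advantages multiplied by the summed per-function reward standard deviation. A minimal sketch with made-up numbers (one group of four generations, two reward functions; the advantage values are toy stand-ins):

```python
import torch

# rewards for a single group: [num_generations, num_reward_funcs]
rewards = torch.tensor(
    [[1.0, 0.0],
     [0.0, 0.0],
     [1.0, 1.0],
     [0.0, 0.0]]
)
advantages = torch.tensor([0.5, -0.5, 1.0, -1.0])  # toy per-generation advantages

per_group_std = rewards.std(dim=0)                  # std per reward function
has_signal = (per_group_std > 0).any()              # worth storing at all?
replay_score = advantages.abs().sum() * per_group_std.sum()

print(has_signal.item(), round(replay_score.item(), 3))  # True 3.232
```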
# -- Zero-advantage skipping + Liger OPSM ---------------------------------
|
||||
|
||||
def compute_liger_loss(self, unwrapped_model, inputs):
|
||||
"""Liger loss with zero-adv skipping and off-policy sequence masking (OPSM).
|
||||
|
||||
The base class Liger path doesn't support OPSM because the fused kernel
|
||||
doesn't expose per-token logprobs needed for the KL computation. This
|
||||
override computes them via chunked lm_head matmul (no grad, low memory)
|
||||
and applies the OPSM to the loss mask before calling the kernel.
|
||||
"""
|
||||
if self.args.skip_zero_advantage_batches and torch.all(
|
||||
inputs["advantages"] == 0
|
||||
):
|
||||
mode = "train" if self.model.training else "eval"
|
||||
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
|
||||
return torch.tensor(
|
||||
0.0, device=inputs["advantages"].device, requires_grad=True
|
||||
)
|
||||
|
||||
if self.off_policy_mask_threshold is None:
|
||||
return super().compute_liger_loss(unwrapped_model, inputs)
|
||||
|
||||
# OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide
|
||||
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
||||
completion_ids, completion_mask = (
|
||||
inputs["completion_ids"],
|
||||
inputs["completion_mask"],
|
||||
)
|
||||
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
||||
logits_to_keep = completion_ids.size(1)
|
||||
|
||||
last_hidden_state = self._get_last_hidden_state(
|
||||
unwrapped_model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
inputs.get("pixel_values"),
|
||||
inputs.get("image_grid_thw"),
|
||||
inputs.get("pixel_attention_mask"),
|
||||
inputs.get("image_sizes"),
|
||||
)
|
||||
|
||||
loss_mask = (
|
||||
completion_mask
|
||||
if "tool_mask" not in inputs
|
||||
else completion_mask * inputs["tool_mask"]
|
||||
)
|
||||
|
||||
# Compute per_token_logps via chunked lm_head matmul (no grad, low memory)
|
||||
lm_weight = unwrapped_model.lm_head.weight
|
||||
lm_bias = unwrapped_model.lm_head.bias
|
||||
with torch.no_grad():
|
||||
per_token_logps_chunks = []
|
||||
for i in range(last_hidden_state.size(0)):
|
||||
chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())
|
||||
if lm_bias is not None:
|
||||
chunk_logits = chunk_logits + lm_bias
|
||||
chunk_lps = (
|
||||
chunk_logits.float()
|
||||
.log_softmax(-1)
|
||||
.gather(-1, completion_ids[i : i + 1].unsqueeze(-1))
|
||||
.squeeze(-1)
|
||||
)
|
||||
per_token_logps_chunks.append(chunk_lps)
|
||||
del chunk_logits
|
||||
per_token_logps = torch.cat(per_token_logps_chunks, dim=0)
|
||||
|
||||
advantages = inputs["advantages"]
|
||||
if advantages.dim() == 1:
|
||||
advantages_2d = advantages.unsqueeze(1)
|
||||
else:
|
||||
advantages_2d = advantages
|
||||
|
||||
sampling_per_token_logps = inputs.get("sampling_per_token_logps")
|
||||
if sampling_per_token_logps is None:
|
||||
sampling_per_token_logps = inputs.get("old_per_token_logps")
|
||||
if sampling_per_token_logps is None:
|
||||
sampling_per_token_logps = per_token_logps
|
||||
|
||||
off_policy_mask = GRPOTrainer.get_off_policy_mask(
|
||||
advantages=advantages_2d,
|
||||
per_token_logps=per_token_logps,
|
||||
sampling_per_token_logps=sampling_per_token_logps,
|
||||
mask=loss_mask,
|
||||
off_policy_threshold=self.off_policy_mask_threshold,
|
||||
)
|
||||
loss_mask = loss_mask * off_policy_mask
|
||||
|
||||
# Call the Liger fused kernel with OPSM-modified mask
|
||||
loss, metrics = self.liger_grpo_loss(
|
||||
_input=last_hidden_state,
|
||||
lin_weight=unwrapped_model.lm_head.weight,
|
||||
selected_token_ids=completion_ids,
|
||||
attention_mask=loss_mask,
|
||||
advantages=inputs["advantages"],
|
||||
bias=unwrapped_model.lm_head.bias,
|
||||
old_per_token_logps=inputs.get("old_per_token_logps"),
|
||||
ref_per_token_logps=inputs.get("ref_per_token_logps"),
|
||||
vllm_is_ratio=inputs.get("importance_sampling_ratio"),
|
||||
)
|
||||
|
||||
mean_kl = metrics[0] if self.beta != 0.0 else None
|
||||
clip_ratio = metrics[-1]
|
||||
|
||||
mode = "train" if self.model.training else "eval"
|
||||
if self.beta != 0.0:
|
||||
self._metrics[mode]["kl"].append(
|
||||
self.accelerator.gather(mean_kl).mean().item()
|
||||
)
|
||||
self._metrics[mode]["clip_ratio"].append(
|
||||
self.accelerator.gather(clip_ratio).mean().item()
|
||||
)
|
||||
normalizer = (
|
||||
self.current_gradient_accumulation_steps if mode == "train" else 1.0
|
||||
)
|
||||
return loss / normalizer
|
||||
|
||||
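The chunked lm_head projection in `compute_liger_loss` above exists purely to keep memory bounded: full-vocabulary logits are materialized for one sequence at a time, reduced to per-token logprobs, and discarded. A stripped-down sketch of the same trick with illustrative shapes:

```python
import torch

batch, seq, hidden, vocab = 4, 8, 32, 100
hidden_states = torch.randn(batch, seq, hidden)     # stand-in for last_hidden_state
lm_weight = torch.randn(vocab, hidden)              # stand-in for lm_head.weight
token_ids = torch.randint(0, vocab, (batch, seq))   # stand-in for completion ids

with torch.no_grad():
    chunks = []
    for i in range(batch):                           # one sequence at a time
        logits = hidden_states[i : i + 1] @ lm_weight.t()  # [1, seq, vocab]
        logps = logits.float().log_softmax(-1)
        chunks.append(logps.gather(-1, token_ids[i : i + 1].unsqueeze(-1)).squeeze(-1))
        del logits                                   # free per-row logits immediately
    per_token_logps = torch.cat(chunks, dim=0)       # [batch, seq]

print(per_token_logps.shape)
```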
def _compute_loss(self, model, inputs):
|
||||
if self.args.skip_zero_advantage_batches and torch.all(
|
||||
inputs["advantages"] == 0
|
||||
):
|
||||
mode = "train" if self.model.training else "eval"
|
||||
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
|
||||
# Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.
|
||||
# With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
|
||||
# so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
|
||||
# with a single token to create a proper computation graph.
|
||||
prompt_ids = inputs["prompt_ids"][:1, :1] # (1, 1)
|
||||
attn = torch.ones_like(prompt_ids)
|
||||
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
out = model(input_ids=prompt_ids, attention_mask=attn)
|
||||
return out.logits.sum() * 0
|
||||
return super()._compute_loss(model, inputs)
|
||||
44
src/axolotl/core/trainers/grpo/replay_buffer.py
Normal file
44
src/axolotl/core/trainers/grpo/replay_buffer.py
Normal file
@@ -0,0 +1,44 @@
"""Simple replay buffer for storing and sampling high-signal rollout groups."""

import heapq

import torch


class ReplayBuffer:
    """Min-heap replay buffer that keeps the highest-scoring rollout groups.
    Groups are scored by signal quality (advantage magnitude * reward variance).
    When sampling, groups are drawn proportional to their scores.
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap: list[tuple[float, int, dict]] = []  # min-heap of (score, id, data)
        self._counter = 0  # unique tiebreaker for heap

    def __len__(self):
        return len(self._heap)

    def add(self, score: float, data: dict):
        """Add a group to the buffer. If full, replaces lowest-scoring entry."""
        if self.max_size <= 0:
            return
        self._counter += 1
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, (score, self._counter, data))
        elif score > self._heap[0][0]:
            heapq.heapreplace(self._heap, (score, self._counter, data))

    def sample(self, num_samples: int) -> list[dict] | None:
        """Sample groups weighted by their scores. Returns None if buffer is empty."""
        if self.max_size <= 0 or not self._heap:
            return None

        scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32)
        scores = scores.clamp(min=1e-8)  # avoid zero probabilities
        probs = scores / scores.sum()
        replacement = num_samples > len(self._heap)
        indices = torch.multinomial(
            probs, num_samples, replacement=replacement
        ).tolist()
        return [self._heap[i][2] for i in indices]
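A quick usage sketch of the buffer defined above: the score controls both eviction (the lowest-scoring entry leaves first once the buffer is full) and sampling probability.

```python
import torch

from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

buf = ReplayBuffer(max_size=2)
buf.add(0.1, {"completion_ids": torch.tensor([[1, 2]])})
buf.add(0.9, {"completion_ids": torch.tensor([[3, 4]])})
buf.add(0.5, {"completion_ids": torch.tensor([[5, 6]])})  # evicts the 0.1-score entry

print(len(buf))          # 2
groups = buf.sample(1)   # drawn roughly proportionally to score (0.9 vs 0.5)
print(groups[0]["completion_ids"])
```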
@@ -40,6 +40,7 @@ from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer
|
||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.mixins import (
|
||||
DistributedParallelMixin,
|
||||
@@ -66,6 +67,19 @@ class AxolotlGRPOTrainer(
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
|
||||
class AxolotlAsyncGRPOTrainer(
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
FastAsyncGRPOTrainer,
|
||||
):
|
||||
"""Extend AsyncGRPOTrainer with axolotl helpers"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "async", "axolotl"]
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
|
||||
|
||||
@@ -19,5 +19,4 @@ class CheckpointSaveMixin(Trainer):
|
||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||
"for this training run will not be possible.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -73,8 +73,10 @@ plugins:
|
||||
- ministral3
|
||||
- mistral
|
||||
- mistral3
|
||||
- mistral4
|
||||
- mixtral
|
||||
- mllama
|
||||
- nemotron_h
|
||||
- olmo
|
||||
- olmo2
|
||||
- olmo3
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
120
src/axolotl/integrations/kernels/autotune_callback.py
Normal file
120
src/axolotl/integrations/kernels/autotune_callback.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# Give up looking for autotune data after this many training steps.
|
||||
_MAX_POLL_STEP = 5
|
||||
|
||||
|
||||
def _get_gpu_info() -> dict:
|
||||
"""Return basic GPU identification for the current device."""
|
||||
if not torch.cuda.is_available():
|
||||
return {}
|
||||
try:
|
||||
idx = torch.cuda.current_device()
|
||||
props = torch.cuda.get_device_properties(idx)
|
||||
return {
|
||||
"gpu_name": props.name,
|
||||
"gpu_compute_capability": f"{props.major}.{props.minor}",
|
||||
"gpu_memory_bytes": props.total_memory,
|
||||
}
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return {}
|
||||
|
||||
|
||||
def _get_smem_capacity() -> dict:
|
||||
"""Return shared memory capacity from the runtime lora_ops module."""
|
||||
try:
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
_find_lora_ops_module,
|
||||
)
|
||||
|
||||
lora_ops = _find_lora_ops_module()
|
||||
if lora_ops is None:
|
||||
return {}
|
||||
fn = getattr(lora_ops, "_get_smem_capacity", None)
|
||||
if fn is None:
|
||||
return {}
|
||||
return {"smem_capacity_bytes": fn()}
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return {}
|
||||
|
||||
|
||||
class AutotuneReportCallback(TrainerCallback):
|
||||
"""Reports Triton kernel autotune selections via telemetry.
|
||||
|
||||
Fires **once** after the first training step completes (step 1), at
|
||||
which point the forward and backward passes have both run and the
|
||||
autotuned kernels have populated their caches. If for some reason
|
||||
the caches are still empty (e.g. the kernel was never invoked), the
|
||||
callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
|
||||
then stops polling.
|
||||
|
||||
After reporting (or giving up) every subsequent ``on_step_end``
|
||||
call short-circuits on the ``_reported`` flag — zero hot-path cost.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._reported = False
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if self._reported:
|
||||
return
|
||||
|
||||
# Lazy import — Triton / scattermoe kernels may not be installed.
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
configs = collect_autotune_configs()
|
||||
|
||||
if not configs:
|
||||
if state.global_step >= _MAX_POLL_STEP:
|
||||
LOG.debug(
|
||||
"No autotune data found after %d steps; giving up.",
|
||||
state.global_step,
|
||||
)
|
||||
self._reported = True
|
||||
return
|
||||
|
||||
self._reported = True
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
if not telemetry_manager.enabled:
|
||||
return
|
||||
|
||||
properties = {
|
||||
"kernel_count": len(configs),
|
||||
"kernels": configs,
|
||||
}
|
||||
properties.update(_get_gpu_info())
|
||||
properties.update(_get_smem_capacity())
|
||||
|
||||
telemetry_manager.send_event(
|
||||
event_type="scattermoe-autotune",
|
||||
properties=properties,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"Reported %d scattermoe kernel autotune config(s) to telemetry.",
|
||||
len(configs),
|
||||
)
|
||||
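This hunk does not show where the callback gets registered; as an assumption-laden sketch (the actual registration point inside axolotl may well differ), it would be attached like any other `transformers` callback:

```python
from transformers import Trainer, TrainingArguments

from axolotl.integrations.kernels.autotune_callback import AutotuneReportCallback


def build_trainer(model, train_dataset):
    """Hypothetical wiring; shown only to illustrate how the callback would be attached."""
    args = TrainingArguments(output_dir="out", max_steps=10, report_to=[])
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.add_callback(AutotuneReportCallback())  # fires once after step 1, then short-circuits
    return trainer
```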
114
src/axolotl/integrations/kernels/autotune_collector.py
Normal file
114
src/axolotl/integrations/kernels/autotune_collector.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Collect Triton autotune results from scattermoe-lora kernels.
|
||||
|
||||
This module reads the ``.cache`` attribute from Triton ``@triton.autotune``
|
||||
decorated kernel objects and returns structured dicts describing the selected
|
||||
configurations. It has **no** telemetry dependency — callers decide what to
|
||||
do with the data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# (human-readable name, attribute on the lora_ops module)
|
||||
_KERNEL_REGISTRY: list[tuple[str, str]] = [
|
||||
("scatter2scatter_lora_fwd", "_scatter2scatter_lora"),
|
||||
("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"),
|
||||
("group_bwd_lora", "_group_bwd_lora"),
|
||||
("group_bwd_lora_fused", "_group_bwd_lora_fused"),
|
||||
]
|
||||
|
||||
# The autotune key declared on every kernel: key=["M", "N", "K"]
|
||||
_KEY_NAMES: list[str] = ["M", "N", "K"]
|
||||
|
||||
|
||||
def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:
|
||||
"""Turn the autotune cache key tuple into a labelled dict.
|
||||
|
||||
Triton builds the cache key from the values of the declared ``key``
|
||||
args (``M``, ``N``, ``K``) followed by dtype signature elements.
|
||||
We label the first three and store the rest under ``_extra``.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for i, name in enumerate(_KEY_NAMES):
|
||||
if i < len(key_tuple):
|
||||
result[name] = key_tuple[i]
|
||||
if len(key_tuple) > len(_KEY_NAMES):
|
||||
result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]
|
||||
return result
|
||||
|
||||
|
||||
def _find_lora_ops_module() -> ModuleType | None:
|
||||
"""Locate the *runtime* ``lora_ops`` module in ``sys.modules``.
|
||||
|
||||
The HF ``kernels`` package loads ``scattermoe_lora`` via
|
||||
``import_from_path`` which registers it in ``sys.modules`` under a
|
||||
hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal
|
||||
import (``from axolotl.integrations.kernels...``) would create a
|
||||
*separate* module instance whose kernel objects have empty
|
||||
``.cache`` dicts because autotuning ran on the runtime copy.
|
||||
|
||||
We search ``sys.modules`` for any module whose name contains
|
||||
``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
|
||||
attribute — that is the runtime copy with populated caches.
|
||||
"""
|
||||
for name, module in list(sys.modules.items()):
|
||||
if (
|
||||
module is not None
|
||||
and "lora_ops" in name
|
||||
and hasattr(module, "_scatter2scatter_lora")
|
||||
):
|
||||
return module
|
||||
return None
|
||||
|
||||
|
||||
def collect_autotune_configs() -> list[dict[str, Any]]:
|
||||
"""Read autotune caches from the four scattermoe-lora kernels.
|
||||
|
||||
Returns a (possibly empty) list of dicts, each containing:
|
||||
|
||||
* ``kernel`` – human-readable kernel name
|
||||
* ``key`` – dict with the ``M``/``N``/``K`` problem dimensions
|
||||
* ``config`` – dict with the selected tile sizes, ``num_warps``,
|
||||
and ``num_stages``
|
||||
|
||||
Returns ``[]`` if the kernel module cannot be found or if no
|
||||
autotune cache entries exist yet.
|
||||
"""
|
||||
lora_ops = _find_lora_ops_module()
|
||||
if lora_ops is None:
|
||||
LOG.debug(
|
||||
"lora_ops module not found in sys.modules; skipping autotune collection"
|
||||
)
|
||||
return []
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for friendly_name, attr_name in _KERNEL_REGISTRY:
|
||||
kernel_fn = getattr(lora_ops, attr_name, None)
|
||||
if kernel_fn is None:
|
||||
continue
|
||||
|
||||
cache = getattr(kernel_fn, "cache", None)
|
||||
if not cache:
|
||||
continue
|
||||
|
||||
for key_tuple, config in cache.items():
|
||||
config_dict = dict(config.kwargs)
|
||||
config_dict["num_warps"] = config.num_warps
|
||||
config_dict["num_stages"] = config.num_stages
|
||||
if getattr(config, "num_ctas", None) is not None:
|
||||
config_dict["num_ctas"] = config.num_ctas
|
||||
|
||||
results.append(
|
||||
{
|
||||
"kernel": friendly_name,
|
||||
"key": _parse_key_tuple(key_tuple),
|
||||
"config": config_dict,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
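The shape of each returned entry can be inspected with a small driver. This assumes the scattermoe-lora kernels have already run at least once so their autotune caches are populated; the example values in the comment are illustrative.

```python
from axolotl.integrations.kernels.autotune_collector import collect_autotune_configs

configs = collect_autotune_configs()
for entry in configs:
    # e.g. kernel='scatter2scatter_lora_fwd',
    #      key={'M': 4096, 'N': 14336, 'K': 4096},
    #      config={'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'num_warps': 8, 'num_stages': 4}
    print(entry["kernel"], entry["key"], entry["config"])
```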
@@ -25,6 +25,8 @@ SPARSE_MOE_BLOCK = {
|
||||
"olmoe": "OlmoeSparseMoeBlock",
|
||||
"mixtral": "MixtralSparseMoeBlock",
|
||||
"minimax": "MiniMaxSparseMoeBlock",
|
||||
# softmax -> topk routing (with group-based expert selection)
|
||||
"mistral4": "Mistral4MoE",
|
||||
# sigmoid -> topk routing (with group-based expert selection)
|
||||
"glm_moe_dsa": "GlmMoeDsaMoE",
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
|
||||
@@ -195,6 +195,36 @@ def _estimate_smem_usage(
|
||||
_SMEM_SLACK = 10_000
|
||||
|
||||
|
||||
def _estimate_register_pressure(
|
||||
num_warps: int,
|
||||
*tile_sizes: tuple[int, int],
|
||||
) -> float:
|
||||
"""Rough estimate of per-thread register footprint from live tile sizes.
|
||||
|
||||
This is a heuristic, NOT an accurate register count. Triton uses tensor
|
||||
core MMA fragments that pack multiple elements per register, and can spill
|
||||
to local memory when the hardware limit (255 regs/thread) is exceeded.
|
||||
|
||||
The estimate is used to prune only truly extreme configs that would cause
|
||||
excessive spilling or compilation failures. The threshold is set high
|
||||
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
|
||||
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
|
||||
(est ~520) work fine in practice via spilling.
|
||||
|
||||
Returns estimated registers per thread.
|
||||
"""
|
||||
# Each thread in a warp holds ~1/32 of the tile elements
|
||||
tile_regs = sum(r * c for r, c in tile_sizes) / 32
|
||||
scalar_overhead = 40
|
||||
return tile_regs + scalar_overhead
|
||||
|
||||
|
||||
# Soft limit for register pressure pruning. Only prune configs with extreme
|
||||
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
|
||||
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
|
||||
_MAX_REGS_SOFT_LIMIT = 1024
|
||||
|
||||
|
||||
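Plugging the forward kernel's live tiles into this heuristic reproduces the ~520 figure mentioned in the docstring: with BLOCK_M = BLOCK_N = BLOCK_K = 64 and BLOCK_R = 16 the tile elements sum to 15,360, a thirty-second of that is 480, and the 40-register scalar overhead brings it to 520, comfortably under the 1024 soft limit. A tiny check that mirrors `_estimate_register_pressure`:

```python
def estimate_regs(*tile_sizes):
    # mirrors _estimate_register_pressure: ~1/32 of live tile elements per thread + overhead
    return sum(r * c for r, c in tile_sizes) / 32 + 40


m = n = k = 64
r = 16
# acc, xa_acc, x tile, w tile, a tile, b tile (the live tiles listed in _prune_fwd_configs)
est = estimate_regs((m, n), (m, r), (m, k), (k, n), (r, k), (n, r))
print(est)  # 520.0, kept because the soft limit is 1024
```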
# =============================================================================
|
||||
# Forward Kernel: scatter2scatter with fused LoRA
|
||||
# =============================================================================
|
||||
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
|
||||
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
|
||||
) # [BLOCK_N, BLOCK_R]
|
||||
|
||||
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
|
||||
# Both operands must match; cast to float32 (accumulator type) for precision.
|
||||
b_f32 = b.to(tl.float32)
|
||||
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
|
||||
b_inp = b.to(INPUT_DTYPE)
|
||||
|
||||
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
|
||||
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
|
||||
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
|
||||
|
||||
acc += scaling * lora_out
|
||||
return acc
|
||||
@@ -327,20 +356,21 @@ def _compute_expert_block_lora(
|
||||
def _scatter2scatter_lora_configs():
|
||||
"""Generate forward kernel autotune configs.
|
||||
|
||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
Search space includes BLOCK_M to allow trading token-tile size for
|
||||
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
|
||||
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
|
||||
(4× fewer inner-loop iterations).
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128}
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_K: {32, 64, 128}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
|
||||
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
|
||||
scatter2scatter pattern).
|
||||
"""
|
||||
configs = []
|
||||
for block_n, block_k, warps, stages in product(
|
||||
for block_m, block_n, block_k, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[4, 8], # num_warps
|
||||
@@ -348,7 +378,7 @@ def _scatter2scatter_lora_configs():
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_N": block_n, "BLOCK_K": block_k},
|
||||
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
|
||||
|
||||
|
||||
def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||
"""Prune forward configs based on SMEM capacity.
|
||||
"""Prune forward configs based on SMEM capacity and register pressure.
|
||||
|
||||
The forward kernel inner loop loads three tiles per pipeline stage:
|
||||
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
|
||||
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||
|
||||
scored = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_n = config.kwargs["BLOCK_N"]
|
||||
block_k = config.kwargs["BLOCK_K"]
|
||||
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
|
||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
|
||||
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
|
||||
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
|
||||
smem_lora_loop = config.num_stages * block_r * block_k * 2
|
||||
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
|
||||
smem_lora_epilogue = block_n * block_r * 2
|
||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||
|
||||
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
|
||||
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_m, block_n), # acc
|
||||
(block_m, block_r), # xa_acc
|
||||
(block_m, block_k), # x tile
|
||||
(block_k, block_n), # w tile
|
||||
(block_r, block_k), # a tile
|
||||
(block_n, block_r), # b tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
|
||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def _scatter2scatter_lora_split(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
k: int,
|
||||
lora_A: torch.Tensor,
|
||||
lora_B: torch.Tensor,
|
||||
scaling: float,
|
||||
b: Optional[torch.Tensor] = None,
|
||||
x_grouped: bool = False,
|
||||
y_grouped: bool = False,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
|
||||
|
||||
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
|
||||
because the base kernel runs at full speed without LoRA SMEM overhead,
|
||||
and the LoRA matmuls (R=16) are tiny separate passes.
|
||||
|
||||
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
|
||||
"""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
|
||||
scatter2scatter,
|
||||
)
|
||||
|
||||
E = W.size(0)
|
||||
R = lora_A.size(0) // E
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
|
||||
output = scatter2scatter(
|
||||
X=X,
|
||||
W=W,
|
||||
b=b,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=y_grouped,
|
||||
out=out,
|
||||
)
|
||||
|
||||
# 2. XA = X @ A^T (tiny: output is [M*k, R])
|
||||
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
|
||||
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
|
||||
XA = scatter2scatter(
|
||||
X=X,
|
||||
W=W_A,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=True,
|
||||
)
|
||||
|
||||
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
|
||||
# Reshape B: [N, R*E] → [E, R, N]
|
||||
W_B = lora_B.T.reshape(E, R, N).contiguous()
|
||||
Y_lora = scatter2scatter(
|
||||
X=XA,
|
||||
W=W_B,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
x_grouped=True,
|
||||
y_grouped=y_grouped,
|
||||
)
|
||||
|
||||
# 4. Y = Y_base + scaling * Y_lora
|
||||
output.add_(Y_lora, alpha=scaling)
|
||||
return output
|
||||
|
||||
|
||||
# Threshold for switching from fused to split LoRA forward.
|
||||
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
|
||||
# loads dominate in the fused kernel's inner loop).
|
||||
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
|
||||
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
|
||||
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
|
||||
|
||||
|
||||
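With these two constants, the dispatch in `scatter2scatter_lora` reduces to a single check on expert count and per-expert matmul size. The concrete shapes below are illustrative: a Mixtral-like layout (E=8, K=4096, N=14336) gives K*N of roughly 58.7M and takes the split path, while a many-small-experts layout stays on the fused kernel.

```python
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000
_SPLIT_LORA_FWD_MAX_EXPERTS = 32


def use_split_path(num_experts: int, k: int, n: int) -> bool:
    # same condition as the dispatch inside scatter2scatter_lora
    return num_experts <= _SPLIT_LORA_FWD_MAX_EXPERTS and k * n >= _SPLIT_LORA_FWD_THRESHOLD


print(use_split_path(8, 4096, 14336))  # True: few large experts (Mixtral-like)
print(use_split_path(64, 2048, 1408))  # False: many small experts, fused kernel wins
```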
def scatter2scatter_lora(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
|
||||
Automatically selects between:
|
||||
- Fused kernel: single Triton kernel with LoRA in the inner loop.
|
||||
Best for many small experts (E>=64, small K*N).
|
||||
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
|
||||
Best for few large experts (E<=32, large K*N like Mixtral).
|
||||
|
||||
Args:
|
||||
X: Input [M, K] or [M*k, K] if x_grouped
|
||||
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
|
||||
Returns:
|
||||
Y: Output [M*k, N]
|
||||
"""
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
E = W.size(0)
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# Dispatch: split for few large experts, fused for many small experts
|
||||
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
|
||||
return _scatter2scatter_lora_split(
|
||||
X,
|
||||
W,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
k,
|
||||
lora_A,
|
||||
lora_B,
|
||||
scaling,
|
||||
b,
|
||||
x_grouped,
|
||||
y_grouped,
|
||||
out,
|
||||
)
|
||||
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
R = lora_A.size(0) // E
|
||||
|
||||
# Pad R to power of 2 for Triton tile size
|
||||
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
|
||||
b_ptr,
|
||||
stride_be,
|
||||
stride_bn,
|
||||
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
|
||||
K=K,
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R, # True LoRA rank for weight indexing
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
|
||||
+ (A_expert_offset + R_block)[:, None] * stride_ar
|
||||
+ K_block[None, :] * stride_ak
|
||||
)
|
||||
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
|
||||
|
||||
# Cast to float32 for precision
|
||||
a_f32 = a_e.to(tl.float32)
|
||||
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
|
||||
INPUT_DTYPE
|
||||
)
|
||||
|
||||
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
|
||||
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
|
||||
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
|
||||
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
|
||||
|
||||
acc += scaling * lora_dx
|
||||
return acc
|
||||
@@ -779,17 +939,18 @@ def _scatter2scatter_lora_dX_configs():
|
||||
The inner loop is over N (not K as in forward). The output dimension is K.
|
||||
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
|
||||
|
||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
BLOCK_M is now autotunable (was fixed at 128).
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128} (token tile)
|
||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
"""
|
||||
configs = []
|
||||
for block_k, block_n, warps, stages in product(
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
||||
[4, 8], # num_warps
|
||||
@@ -797,7 +958,7 @@ def _scatter2scatter_lora_dX_configs():
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_K": block_k, "BLOCK_N": block_n},
|
||||
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
|
||||
|
||||
|
||||
def _prune_dX_configs(configs, named_args, **kwargs):
|
||||
"""Prune backward dX configs based on SMEM capacity.
|
||||
"""Prune backward dX configs based on SMEM capacity and register pressure.
|
||||
|
||||
The dX kernel inner loop loads three tiles per pipeline stage:
|
||||
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
|
||||
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
|
||||
|
||||
scored = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_k = config.kwargs["BLOCK_K"]
|
||||
block_n = config.kwargs["BLOCK_N"]
|
||||
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
|
||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
|
||||
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
|
||||
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
|
||||
smem_lora_loop = config.num_stages * block_n * block_r * 2
|
||||
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
|
||||
smem_lora_epilogue = block_r * block_k * 2
|
||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||
|
||||
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
|
||||
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_m, block_k), # acc
|
||||
(block_m, block_r), # dy_b_acc
|
||||
(block_m, block_n), # dy tile
|
||||
(block_n, block_k), # wt tile
|
||||
(block_n, block_r), # b tile
|
||||
(block_r, block_k), # a tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_M=BLOCK_M,
|
||||
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
|
||||
|
||||
|
||||
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||
"""Prune backward configs based on SMEM capacity.
|
||||
"""Prune backward configs based on SMEM capacity and register pressure.
|
||||
|
||||
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
|
||||
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
|
||||
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
|
||||
smem_lora = (block_r * block_k + block_n * block_r) * 2
|
||||
smem = smem_base + smem_lora
|
||||
|
||||
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
|
||||
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_r, block_k), # dA_acc
|
||||
(block_n, block_r), # dB_acc
|
||||
(block_m, block_k), # x tile
|
||||
(block_m, block_n), # dy tile
|
||||
(block_r, block_k), # a tile
|
||||
(block_n, block_r), # b tile
|
||||
(block_m, block_r), # xa intermediate
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
|
||||
)
|
||||
|
||||
|
||||
def _group_bwd_split_configs():
|
||||
"""Autotune configs for split dA/dB kernels."""
|
||||
configs = []
|
||||
for block_m, block_dim, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M (token tile)
|
||||
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def _prune_split_configs(configs, named_args, **kwargs):
|
||||
"""Prune split kernel configs based on SMEM capacity and register pressure."""
|
||||
smem_cap = _get_smem_capacity()
|
||||
block_r = named_args.get("BLOCK_R", 64)
|
||||
|
||||
# Fixed inner tile for reduction dimension
|
||||
BLOCK_INNER = 64
|
||||
|
||||
pruned = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_dim = config.kwargs["BLOCK_DIM"]
|
||||
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
|
||||
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
|
||||
# LoRA weights held in registers: [INNER, R] or [R, DIM]
|
||||
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
|
||||
|
||||
# Register pressure check
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_r, block_dim), # acc
|
||||
(block_m, BLOCK_INNER), # input tile
|
||||
(block_m, block_dim), # other tile
|
||||
(block_r, BLOCK_INNER), # lora weight
|
||||
)
|
||||
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||
continue
|
||||
|
||||
if smem <= smem_cap - _SMEM_SLACK:
|
||||
pruned.append(config)
|
||||
|
||||
if pruned:
|
||||
return pruned
|
||||
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
|
||||
return [configs[0]]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_group_bwd_split_configs(),
|
||||
key=["M", "K", "N"],
|
||||
prune_configs_by={"early_config_prune": _prune_split_configs},
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
"NO_DIM_MASK": lambda args: (
|
||||
(args["K"] % args["BLOCK_DIM"]) == 0
|
||||
if args["COMPUTE_DA"]
|
||||
else (args["N"] % args["BLOCK_DIM"]) == 0
|
||||
),
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _group_bwd_lora_split(
|
||||
# Data tensors (DY and X are always present)
|
||||
DY_ptr,
|
||||
stride_dym,
|
||||
stride_dyn,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
# LoRA weight for the inner reduction (B for dA, A for dB)
|
||||
LW_ptr,
|
||||
stride_lw0,
|
||||
stride_lw1,
|
||||
# Output gradient tensor (dA or dB)
|
||||
OUT_ptr,
|
||||
stride_out0,
|
||||
stride_out1,
|
||||
# Expert offsets
|
||||
expert_offsets_ptr,
|
||||
# Dimensions
|
||||
M,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
ACTUAL_R: tl.constexpr,
|
||||
BLOCK_R: tl.constexpr,
|
||||
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
|
||||
scaling,
|
||||
# Mode flag
|
||||
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
|
||||
# Tile sizes
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DIM: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
NO_DIM_MASK: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Unified split kernel for LoRA gradient computation.
|
||||
|
||||
When COMPUTE_DA=True:
|
||||
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
|
||||
Grid: (E, cdiv(K, BLOCK_DIM))
|
||||
- outer_ptr/stride = X (read [M, K_block])
|
||||
- inner reduction over N using DY and B
|
||||
- output shape [BLOCK_R, BLOCK_DIM]
|
||||
|
||||
When COMPUTE_DA=False:
|
||||
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
|
||||
Grid: (E, cdiv(N, BLOCK_DIM))
|
||||
- outer_ptr/stride = DY (read [M, N_block])
|
||||
- inner reduction over K using X and A
|
||||
- output shape [BLOCK_DIM, BLOCK_R]
|
||||
|
||||
No atomic adds — each (E, dim_block) pair is written by exactly one block.
|
||||
"""
|
||||
E_idx = tl.program_id(0)
|
||||
dim_block_id = tl.program_id(1)
|
||||
|
||||
if E_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
# Output dimension tile (K for dA, N for dB)
|
||||
if COMPUTE_DA:
|
||||
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
|
||||
else:
|
||||
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
|
||||
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||
dim_mask = dim_block < OUT_DIM
|
||||
R_block = tl.arange(0, BLOCK_R)
|
||||
R_mask = R_block < ACTUAL_R
|
||||
lora_offset = E_idx * ACTUAL_R
|
||||
|
||||
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
|
||||
if COMPUTE_DA:
|
||||
out_blk_ptrs = (
|
||||
OUT_ptr
|
||||
+ (lora_offset + R_block)[:, None] * stride_out0
|
||||
+ dim_block[None, :] * stride_out1
|
||||
)
|
||||
out_mask = R_mask[:, None] & dim_mask[None, :]
|
||||
else:
|
||||
out_blk_ptrs = (
|
||||
OUT_ptr
|
||||
+ dim_block[:, None] * stride_out0
|
||||
+ (lora_offset + R_block)[None, :] * stride_out1
|
||||
)
|
||||
out_mask = dim_mask[:, None] & R_mask[None, :]
|
||||
|
||||
if num_tokens > 0:
|
||||
M_block = tl.arange(0, BLOCK_M)
|
||||
INPUT_DTYPE = X_ptr.dtype.element_ty
|
||||
BLOCK_INNER: tl.constexpr = 64
|
||||
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
|
||||
|
||||
if COMPUTE_DA:
|
||||
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
|
||||
else:
|
||||
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
|
||||
|
||||
M_iters = tl.cdiv(num_tokens, BLOCK_M)
|
||||
for i in range(M_iters):
|
||||
M_idx = start_idx + i * BLOCK_M + M_block
|
||||
M_mask = M_idx < end_idx
|
||||
|
||||
if COMPUTE_DA:
|
||||
# Load X[M, K_block] (the "outer" tensor for dA)
|
||||
outer = tl.load(
|
||||
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
|
||||
mask=M_mask[:, None] & dim_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
|
||||
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
|
||||
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||
inner_range = tl.arange(0, BLOCK_INNER)
|
||||
for j in range(inner_iters):
|
||||
inn_off = j * BLOCK_INNER + inner_range
|
||||
inn_mask = inn_off < N
|
||||
|
||||
dy_tile = tl.load(
|
||||
DY_ptr
|
||||
+ M_idx[:, None] * stride_dym
|
||||
+ inn_off[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & inn_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
|
||||
lw_tile = tl.load(
|
||||
LW_ptr
|
||||
+ inn_off[:, None] * stride_lw0
|
||||
+ (lora_offset + R_block)[None, :] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
|
||||
|
||||
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
|
||||
acc += tl.dot(
|
||||
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
|
||||
)
|
||||
else:
|
||||
# Load DY[M, N_block] (the "outer" tensor for dB)
|
||||
outer = tl.load(
|
||||
DY_ptr
|
||||
+ M_idx[:, None] * stride_dym
|
||||
+ dim_block[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & dim_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
|
||||
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
|
||||
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||
inner_range = tl.arange(0, BLOCK_INNER)
|
||||
for j in range(inner_iters):
|
||||
inn_off = j * BLOCK_INNER + inner_range
|
||||
inn_mask = inn_off < K
|
||||
|
||||
x_tile = tl.load(
|
||||
X_ptr
|
||||
+ M_idx[:, None] * stride_xm
|
||||
+ inn_off[None, :] * stride_xk,
|
||||
mask=M_mask[:, None] & inn_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
|
||||
# We want A[e]^T: [K, R], so load as [K_inner, R]
|
||||
lw_tile = tl.load(
|
||||
LW_ptr
|
||||
+ (lora_offset + R_block)[None, :] * stride_lw0
|
||||
+ inn_off[:, None] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
|
||||
|
||||
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
|
||||
acc += tl.dot(
|
||||
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
|
||||
)
|
||||
|
||||
tl.store(
|
||||
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
|
||||
)
|
||||
else:
|
||||
# Zero out this expert's slice — needed because output uses empty_like
|
||||
if COMPUTE_DA:
|
||||
tl.store(
|
||||
out_blk_ptrs,
|
||||
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
|
||||
mask=out_mask,
|
||||
)
|
||||
else:
|
||||
tl.store(
|
||||
out_blk_ptrs,
|
||||
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
|
||||
mask=out_mask,
|
||||
)
|
||||
|
||||
|
||||
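As a cross-check of what the split kernel computes per expert, here is a plain PyTorch reference for the two gradients under the conventions in the docstring (dA is [R, K], dB is [N, R]). It is a verification sketch with random data, not the production path.

```python
import torch

M, K, N, R = 32, 64, 48, 16
scaling = 2.0
X = torch.randn(M, K)    # expert-grouped inputs for one expert
DY = torch.randn(M, N)   # upstream gradient for the same rows
A = torch.randn(R, K)    # LoRA A for this expert
B = torch.randn(N, R)    # LoRA B for this expert

dA = scaling * (DY @ B).T @ X    # [R, K], matches the COMPUTE_DA=True branch
dB = scaling * DY.T @ (X @ A.T)  # [N, R], matches the COMPUTE_DA=False branch

print(dA.shape, dB.shape)
```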
def group_bwd_lora(
|
||||
DY: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
|
||||
"""
|
||||
Compute LoRA gradients for A and B on expert-grouped data.
|
||||
|
||||
Uses split dA/dB kernels that eliminate atomic adds by giving each
|
||||
(expert, output_block) pair its own thread block.
|
||||
|
||||
Args:
|
||||
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
|
||||
X: Input [M_total, K] (grouped by expert)
|
||||
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
|
||||
K = X.size(1)
|
||||
N = DY.size(1)
|
||||
|
||||
# Zero-init for atomic accumulation
|
||||
dA = torch.zeros_like(lora_A)
|
||||
dB = torch.zeros_like(lora_B)
|
||||
# No zero-init needed: the split kernels write zeros for experts with
|
||||
# zero routed tokens directly in the kernel (else branch).
|
||||
dA = torch.empty_like(lora_A)
|
||||
dB = torch.empty_like(lora_B)
|
||||
|
||||
BLOCK_R = _block_r_for_rank(R)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
E * triton.cdiv(K, META["BLOCK_K"]),
|
||||
triton.cdiv(N, META["BLOCK_N"]),
|
||||
)
|
||||
def grid_dA(META):
|
||||
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora[grid](
|
||||
_group_bwd_lora_split[grid_dA](
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
dA,
|
||||
dA.stride(0),
|
||||
dA.stride(1),
|
||||
expert_offsets,
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=N,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=True,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
def grid_dB(META):
|
||||
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora_split[grid_dB](
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
dA,
|
||||
dA.stride(0),
|
||||
dA.stride(1),
|
||||
dB,
|
||||
dB.stride(0),
|
||||
dB.stride(1),
|
||||
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R, # True LoRA rank
|
||||
BLOCK_R=BLOCK_R, # Padded tile size
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=K,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=False,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
@@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module):
|
||||
return base_experts, gup_lora, down_lora
|
||||
|
||||
|
||||
# =============================================================================
# Routing helpers
# =============================================================================


def _softmax_topk_route(
    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
    """Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).

    Returns:
        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
    """
    router_logits = F.linear(hidden_states, gate_weight)
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)

    top_k = base_gate.top_k
    num_experts = base_gate.num_experts
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    if getattr(base_gate, "norm_topk_prob", True):
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    return routing_weights, selected_experts, top_k, num_experts


def _sigmoid_topk_route(
    moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
    """Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).

    Supports:
    - ``e_score_correction_bias`` on gate or moe_block
    - Group-based expert selection when ``n_group > 1``
    - ``routed_scaling_factor`` applied to final weights
    - Final weights gathered from original sigmoid probs (not bias-corrected)

    Returns:
        (routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
    """
    router_logits = F.linear(hidden_states.float(), gate_weight.float())
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = router_logits.sigmoid()  # [T, E]

    top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None))
    num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])

    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
    e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    if e_score_correction_bias is not None:
        scores_for_choice = router_probs + e_score_correction_bias
    else:
        scores_for_choice = router_probs

    # Group-based selection: pick top groups, mask the rest
    n_group = getattr(moe_block, "n_group", 1)
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, num_experts // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        topk_group = getattr(moe_block, "topk_group", n_group)
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, n_group, num_experts // n_group)
            .reshape(-1, num_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    # Final topk from (possibly masked) scores
    topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]

    # Gather weights from original sigmoid scores (not bias-corrected)
    topk_weights = router_probs.gather(1, topk_indices)

    # Optional renormalization + scaling
    if getattr(moe_block, "norm_topk_prob", True):
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    return topk_weights, topk_indices, top_k, num_experts


def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
    """Dispatch to the correct routing strategy based on block attributes.

    Detects sigmoid routing by the presence of ``e_score_correction_bias``
    on either the gate or the moe_block.
    """
    has_sigmoid = (
        getattr(base_gate, "e_score_correction_bias", None) is not None
        or getattr(moe_block, "e_score_correction_bias", None) is not None
    )
    if has_sigmoid:
        return _sigmoid_topk_route(
            moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
        )
    return _softmax_topk_route(
        moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
    )
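
For reference, the softmax→topk scheme that `_softmax_topk_route` applies can be reproduced in a few lines of plain PyTorch. This is a minimal sketch with an invented toy gate (`SimpleNamespace`) and made-up dimensions, not the kernel path itself:

```python
import torch
import torch.nn.functional as F
from types import SimpleNamespace

torch.manual_seed(0)
T, H, E, K = 4, 8, 6, 2                      # tokens, hidden, experts, top-k
gate = SimpleNamespace(top_k=K, num_experts=E, norm_topk_prob=True)
gate_weight = torch.randn(E, H)              # router weight [E, H]
x = torch.randn(T, H)

logits = F.linear(x, gate_weight)                          # [T, E]
probs = F.softmax(logits, dim=-1, dtype=torch.float32)     # [T, E]
weights, experts = torch.topk(probs, gate.top_k, dim=-1)   # [T, K] each
if gate.norm_topk_prob:
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over K

assert experts.shape == (T, K)
assert torch.allclose(weights.sum(dim=-1), torch.ones(T))  # rows sum to 1
```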


# =============================================================================
# Shared expert helpers
# =============================================================================


def _compute_shared_expert(moe_block, hidden_states_flat):
    """Compute shared expert output if the block has one.

    Handles singular (qwen2_moe: ``shared_expert``), plural
    (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
    (hunyuan_v1_moe: ``shared_mlp``) attribute names.

    peft wraps individual linear layers inside the shared expert with
    standard LoRA — calling forward() handles this transparently.
    """
    shared_expert = (
        getattr(moe_block, "shared_expert", None)
        or getattr(moe_block, "shared_experts", None)
        or getattr(moe_block, "shared_mlp", None)
    )
    if shared_expert is None:
        return None

    shared_expert_output = shared_expert(hidden_states_flat)

    # Optional sigmoid gate (Qwen2MoE pattern).
    # shared_expert_gate may also be peft-wrapped (standard LoRA
    # on nn.Linear), its forward() applies LoRA automatically.
    shared_expert_gate = getattr(moe_block, "shared_expert_gate", None)
    if shared_expert_gate is not None:
        shared_expert_output = (
            F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output
        )

    return shared_expert_output
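
The Qwen2MoE-style gating applied above is just `sigmoid(gate(x)) * shared(x)`. A plain-tensor illustration, where the `nn.Sequential` MLP and the shapes are stand-ins invented for the sketch (the real shared expert is a gated MLP with separate gate/up/down projections):

```python
import torch
import torch.nn as nn

T, H, I = 4, 8, 16                    # tokens, hidden, shared-expert intermediate
x = torch.randn(T, H)
shared_expert = nn.Sequential(nn.Linear(H, I), nn.SiLU(), nn.Linear(I, H))
shared_expert_gate = nn.Linear(H, 1)  # scalar gate per token, as in Qwen2MoE

out = torch.sigmoid(shared_expert_gate(x)) * shared_expert(x)  # [T, H]
print(out.shape)
```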


# =============================================================================
# Layer classes
# =============================================================================
@@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module):

class HFScatterMoEGatedMLP(nn.Module):
    """
    ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
    ScatterMoE-accelerated forward pass for HF MoEs.

    Used as a kernel layer via the HF ``kernels`` library. The ``forward``
    method replaces the original ``OlmoeSparseMoeBlock.forward``.
    method replaces the original SparseMoeBlock.forward.

    Supports both full-parameter training and LoRA fine-tuning:
    Supports:

    * **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
    * **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
      adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
    * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
    * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
    * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
    * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
      extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
    """

    @staticmethod
@@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module):
            self: The MoeSparseMoeBlock module containing:
                - self.gate: Router (or peft ParamWrapper wrapping it)
                - self.experts: Experts module (or peft ParamWrapper chain)
                - self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
                - self.shared_expert(s): Optional shared expert
                - self.shared_expert_gate: Optional shared expert gate
            layer_input: Input tensor [batch_size, seq_len, hidden_size]

@@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module):
        hidden_states_flat = layer_input.view(-1, hidden_dim)

        # ====================================================================
        # Shared Expert (if present, e.g. Qwen2MoE)
        # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
        # ====================================================================
        # peft wraps individual linear layers inside shared_expert with
        # standard LoRA — calling forward() handles this transparently.
        if hasattr(self, "shared_expert") and self.shared_expert is not None:
            shared_expert_output = self.shared_expert(hidden_states_flat)
            # shared_expert_gate may also be peft-wrapped (standard LoRA
            # on nn.Linear), its forward() applies LoRA automatically.
            shared_expert_gate_output = F.sigmoid(
                self.shared_expert_gate(hidden_states_flat)
            )
            shared_expert_output = shared_expert_output * shared_expert_gate_output
        else:
            shared_expert_output = None
        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)

        # ====================================================================
        # Router Computation (with optional gate LoRA)
        # ====================================================================
        base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
        router_logits = F.linear(hidden_states_flat, gate_weight)
        if gate_lora_delta is not None:
            router_logits = router_logits + F.linear(
                hidden_states_flat, gate_lora_delta
            )
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

        top_k = base_gate.top_k
        num_experts = base_gate.num_experts
        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

        if base_gate.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights, selected_experts, top_k, num_experts = _route(
            self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta
        )
        routing_weights = routing_weights.to(hidden_states_flat.dtype)

        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
@@ -356,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
        # ====================================================================
        experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)

        # ====================================================================
        # Selective expert weight dequantization
        # ====================================================================
        # When experts are BnB-quantized (quantize_moe_experts), dequantize
        # only the active experts instead of all E. This saves ~97% memory
        # for the transient dequant buffer when few experts are active.
        use_selective = (
            getattr(self, "_use_selective_dequant", False)
            and hasattr(experts, "parametrizations")
            and "gate_up_proj" in experts.parametrizations
        )

        if use_selective:
            from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
                get_active_experts,
                remap_expert_indices,
                selective_expert_weights,
                selective_lora_weights,
            )

            active_experts = get_active_experts(sorted_expert_idxs, num_experts)
            remapped_expert_idxs, compact_offsets = remap_expert_indices(
                sorted_expert_idxs,
                expert_offsets,
                active_experts,
                num_experts,
            )
            # Dequantize only active experts' weights
            gate_up_W = selective_expert_weights(
                experts,
                "gate_up_proj",
                active_experts,
            ).transpose(2, 1)  # [num_active, hidden, 2*inter]

            # Remap LoRA weights to match compact expert indices
            if gup_lora is not None:
                gup_A, gup_B, gup_scaling = gup_lora
                gup_A, gup_B = selective_lora_weights(
                    gup_A,
                    gup_B,
                    active_experts,
                    num_experts,
                )
                gup_lora = (gup_A, gup_B, gup_scaling)

            # Use remapped indices for ScatterMoE kernels
            sei_gup = remapped_expert_idxs
            eo_gup = compact_offsets
        else:
            gate_up_W = experts.gate_up_proj.transpose(2, 1)  # [E, hidden, 2*inter]
            sei_gup = sorted_expert_idxs
            eo_gup = expert_offsets

        # ====================================================================
        # Gate + Up projection
        # ====================================================================
        gate_up_W = experts.gate_up_proj.transpose(2, 1)  # [E, hidden, 2*inter]

        if gup_lora is not None:
            gup_A, gup_B, gup_scaling = gup_lora
            gup = parallel_linear_lora(
                hidden_states_flat,
                gate_up_W,
                top_k,
                sorted_expert_idxs,
                sei_gup,
                sorted_scattered_idxs,
                expert_offsets,
                eo_gup,
                lora_A=gup_A,
                lora_B=gup_B,
                scaling=gup_scaling,
@@ -383,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
                hidden_states_flat,
                gate_up_W,
                top_k,
                sorted_expert_idxs,
                sei_gup,
                sorted_scattered_idxs,
                expert_offsets,
                eo_gup,
                grouped_in=False,
                grouped_out=True,
            )
@@ -396,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
        # ====================================================================
        # Down projection
        # ====================================================================
        down_W = experts.down_proj.transpose(2, 1)  # [E, inter, hidden]
        if use_selective:
            down_W = selective_expert_weights(
                experts,
                "down_proj",
                active_experts,
            ).transpose(2, 1)  # [num_active, inter, hidden]

            if down_lora is not None:
                down_A, down_B, down_scaling = down_lora
                down_A, down_B = selective_lora_weights(
                    down_A,
                    down_B,
                    active_experts,
                    num_experts,
                )
                down_lora = (down_A, down_B, down_scaling)

            sei_down = remapped_expert_idxs
            eo_down = compact_offsets
        else:
            down_W = experts.down_proj.transpose(2, 1)  # [E, inter, hidden]
            sei_down = sorted_expert_idxs
            eo_down = expert_offsets

        if down_lora is not None:
            down_A, down_B, down_scaling = down_lora
@@ -404,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
                h,
                down_W,
                1,
                sorted_expert_idxs,
                sei_down,
                sorted_scattered_idxs,
                expert_offsets,
                eo_down,
                lora_A=down_A,
                lora_B=down_B,
                scaling=down_scaling,
@@ -421,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
                h,
                down_W,
                1,
                sorted_expert_idxs,
                sei_down,
                sorted_scattered_idxs,
                expert_offsets,
                eo_down,
                grouped_in=True,
                grouped_out=False,
                gates=routing_weights,
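
The "~97% memory" claim in the selective-dequant comments above is plain arithmetic over the buffer shapes quoted in the new module's docstring below. A quick back-of-the-envelope check (bf16 elements; the shapes are the docstring's own figures, not measured values):

```python
bytes_per_elem = 2                                # bf16
full = 256 * 2048 * 1024 * bytes_per_elem         # dequantize all experts
selective = 8 * 2048 * 1024 * bytes_per_elem      # only the 8 routed experts
print(full / 1e6, selective / 1e6)                # ~1073.7 MB vs ~33.6 MB
print(1 - selective / full)                       # 0.96875, i.e. ~97% saved
```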

@@ -0,0 +1,282 @@
"""
Selective Expert Dequantization
===============================

Instead of dequantizing all E expert weight matrices at once (which creates
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
are actually routed to by the current batch's top-k selection.

For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
- Savings: ~97% memory reduction per layer

This module provides format-agnostic selective weight extraction:
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
- bf16/fp32: direct indexing (no dequant needed)
- FP8: slice + cast

The ScatterMoE kernel itself doesn't change — we remap expert indices
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
weight tensor.
"""

import torch
import torch.nn as nn


def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
    """Get sorted unique expert indices from the routing output.

    Args:
        sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
        E: Total number of experts

    Returns:
        active: Sorted unique expert indices [num_active]
    """
    return torch.unique(sorted_expert_idxs)


def remap_expert_indices(
    sorted_expert_idxs: torch.Tensor,
    expert_offsets: torch.Tensor,
    active_experts: torch.Tensor,
    E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Remap global expert indices to compact indices.

    Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
    sort order. Also compacts expert_offsets to only active experts.

    Args:
        sorted_expert_idxs: [T*k] expert ids in sorted order
        expert_offsets: [E] cumulative token counts (original)
        active_experts: [num_active] sorted unique expert ids
        E: Total number of experts

    Returns:
        remapped_idxs: [T*k] expert ids in [0..num_active-1]
        compact_offsets: [num_active] cumulative token counts
    """
    # Build remap table: global_id -> compact_id
    remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
    remap[active_experts] = torch.arange(
        len(active_experts), device=sorted_expert_idxs.device
    )

    remapped_idxs = remap[sorted_expert_idxs]

    # Compact the expert_offsets: only keep active experts' cumulative counts
    compact_offsets = expert_offsets[active_experts]

    return remapped_idxs, compact_offsets
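
A concrete illustration of the remapping with an invented routing output (E=8, three active experts), using the same table-building trick as the function above:

```python
import torch

E = 8
sorted_expert_idxs = torch.tensor([2, 2, 5, 5, 5, 7])  # tokens sorted by expert id
active = torch.unique(sorted_expert_idxs)               # tensor([2, 5, 7])

remap = torch.empty(E, dtype=torch.long)
remap[active] = torch.arange(len(active))
print(remap[sorted_expert_idxs])                        # tensor([0, 0, 1, 1, 1, 2])
```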


def _selective_dequant_bnb4(
    raw_param: torch.Tensor,
    quant_state,
    active_experts: torch.Tensor,
    expert_shape: tuple[int, int],
) -> torch.Tensor:
    """Dequantize only selected experts from BnB 4-bit packed data.

    The raw parameter is a flattened 4-bit packed tensor. Each expert's
    data is contiguous (stored in expert-major order), so we can gather
    the packed data and absmax blocks for active experts, then dequantize
    as one contiguous block.

    Args:
        raw_param: Flattened uint8 tensor of packed 4-bit weights
        quant_state: BnB QuantState with absmax, blocksize, code, etc.
        active_experts: [num_active] expert indices to dequantize
        expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))

    Returns:
        Dequantized weights [num_active, dim1, dim2] in original dtype
    """
    import bitsandbytes.functional as F  # noqa: N812
    from bitsandbytes.functional import QuantState

    expert_numel = expert_shape[0] * expert_shape[1]
    packed_per_expert = expert_numel // 2  # 4-bit = 2 values per byte
    blocks_per_expert = expert_numel // quant_state.blocksize
    num_active = len(active_experts)

    if blocks_per_expert == 0:
        # Expert is smaller than one quantization block — blocks span across
        # expert boundaries, so per-expert slicing isn't possible.
        # Fallback: full dequantize + index.
        full = F.dequantize_4bit(raw_param, quant_state)
        E_total = full.numel() // expert_numel
        return full.reshape(E_total, *expert_shape)[active_experts]

    # Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
    if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
        from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
            selective_dequant_nf4_triton,
        )

        # Handle nested (double) quantization: dequantize absmax first
        # BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
        if quant_state.nested:
            absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
            absmax += quant_state.offset
            if absmax.dtype != torch.float32:
                absmax = absmax.float()
        else:
            absmax = quant_state.absmax

        return selective_dequant_nf4_triton(
            packed_data=raw_param,
            absmax=absmax,
            active_experts=active_experts,
            expert_shape=expert_shape,
            blocksize=quant_state.blocksize,
            dtype=quant_state.dtype,
            codebook=quant_state.code,
        )

    # Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
    raw_flat = raw_param.reshape(-1)

    offsets_qt = (
        active_experts.long()[:, None] * packed_per_expert
        + torch.arange(packed_per_expert, device=raw_param.device)[None, :]
    ).reshape(-1)
    qt_gathered = raw_flat[offsets_qt]

    offsets_abs = (
        active_experts.long()[:, None] * blocks_per_expert
        + torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
    ).reshape(-1)

    if quant_state.nested:
        full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
        full_absmax += quant_state.offset
        if full_absmax.dtype != torch.float32:
            full_absmax = full_absmax.float()
        absmax_gathered = full_absmax[offsets_abs]
    else:
        absmax_gathered = quant_state.absmax[offsets_abs]

    qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered

    gathered_qs = QuantState(
        absmax=absmax_gathered,
        shape=torch.Size([num_active * expert_numel]),
        blocksize=quant_state.blocksize,
        quant_type=quant_state.quant_type,
        code=quant_state.code,
        dtype=quant_state.dtype,
    )

    deq = F.dequantize_4bit(qt_gathered, gathered_qs)
    return deq.reshape(num_active, *expert_shape)


def _selective_index_dense(
    param: torch.Tensor,
    active_experts: torch.Tensor,
) -> torch.Tensor:
    """Select experts from a dense (bf16/fp32) weight tensor.

    Simple indexing — no dequantization needed.
    """
    return param[active_experts]


def selective_expert_weights(
    experts_module: nn.Module,
    param_name: str,
    active_experts: torch.Tensor,
) -> torch.Tensor:
    """Extract and dequantize only the active experts' weights.

    Format-agnostic: dispatches based on whether the parameter is
    BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.

    Args:
        experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
        param_name: "gate_up_proj" or "down_proj"
        active_experts: [num_active] sorted unique expert indices

    Returns:
        Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
    """
    # Check if the parameter is BnB-quantized via parametrize
    if (
        hasattr(experts_module, "parametrizations")
        and param_name in experts_module.parametrizations
    ):
        param_list = experts_module.parametrizations[param_name]
        parametrization = param_list[0]

        # BnB 4-bit parametrization
        if hasattr(parametrization, "quant_state"):
            # The raw quantized data is on the ParametrizationList, not the
            # individual Bnb4bitParametrization module
            raw_param = param_list.original
            qs = parametrization.quant_state
            # qs.shape is the original tensor shape before flattening.
            # For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
            orig_shape = qs.shape
            if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
                expert_shape = (orig_shape[1], orig_shape[2])
            elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
                # Flattened — need to infer from module attributes
                E_total = getattr(experts_module, "num_experts", None)
                if E_total is None:
                    E_total = int(active_experts.max().item()) + 1
                expert_numel = orig_shape[0] // E_total
                d2 = getattr(experts_module, "hidden_dim", None) or getattr(
                    experts_module, "intermediate_dim", None
                )
                if d2 and expert_numel % d2 == 0:
                    expert_shape = (expert_numel // d2, d2)
                else:
                    full = getattr(experts_module, param_name)
                    return full[active_experts]
            else:
                full = getattr(experts_module, param_name)
                return full[active_experts]

            return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)

    # Dense parameter (bf16/fp32) — direct indexing
    param = getattr(experts_module, param_name)
    if param.dim() == 3:
        return param[active_experts]

    # Fallback: full access
    return param


def selective_lora_weights(
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    active_experts: torch.Tensor,
    E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select LoRA A and B weights for only the active experts.

    LoRA layout (scattermoe format):
        A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
        B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]

    Returns compact:
        A: [r*num_active, K]
        B: [N, r*num_active]
    """
    R = lora_A.size(0) // E

    # Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
    row_idx = (
        active_experts.long()[:, None] * R
        + torch.arange(R, device=lora_A.device)[None, :]
    ).reshape(-1)

    compact_A = lora_A[row_idx]  # [r*num_active, K]
    compact_B = lora_B[:, row_idx]  # [N, r*num_active]

    return compact_A, compact_B
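
The row-index arithmetic for the [r*E, K] layout is easy to sanity-check on made-up sizes (r=2, E=4, experts 1 and 3 active):

```python
import torch

r, E, K, N = 2, 4, 6, 5
lora_A = torch.arange(r * E * K, dtype=torch.float32).reshape(r * E, K)
lora_B = torch.randn(N, r * E)
active = torch.tensor([1, 3])

row_idx = (active[:, None] * r + torch.arange(r)[None, :]).reshape(-1)
print(row_idx)                   # tensor([2, 3, 6, 7]) -> rows of experts 1 and 3
print(lora_A[row_idx].shape)     # torch.Size([4, 6])   == [r*num_active, K]
print(lora_B[:, row_idx].shape)  # torch.Size([5, 4])   == [N, r*num_active]
```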
@@ -0,0 +1,179 @@
"""
Triton kernel for fused selective expert gather + NF4 dequantization.

Instead of:
1. Gather packed uint8 data for active experts (memory copy)
2. Gather absmax for active experts (memory copy)
3. Call BnB dequantize_4bit CUDA kernel

This kernel does all three in one pass:
- Reads packed NF4 bytes from expert-strided positions
- Looks up the NF4 codebook
- Multiplies by the per-block absmax
- Writes bf16 output directly

This eliminates the intermediate gather buffer entirely.
"""

import torch
import triton
import triton.language as tl

# NF4 codebook (16 values, precomputed by BnB)
# These are the normalized float4 reconstruction values
NF4_CODEBOOK = [
    -1.0,
    -0.6961928009986877,
    -0.5250730514526367,
    -0.39491748809814453,
    -0.28444138169288635,
    -0.18477343022823334,
    -0.09105003625154495,
    0.0,
    0.07958029955625534,
    0.16093020141124725,
    0.24611230194568634,
    0.33791524171829224,
    0.44070982933044434,
    0.5626170039176941,
    0.7229568362236023,
    1.0,
]


@triton.jit
def _selective_dequant_nf4_kernel(
    # Input: packed NF4 data (flattened, expert-major order)
    packed_ptr,
    # Input: absmax values (flattened, expert-major order)
    absmax_ptr,
    # Input: active expert indices
    active_experts_ptr,
    # Input: NF4 codebook (16 float values)
    codebook_ptr,
    # Output: dequantized bf16 weights [num_active, expert_numel]
    out_ptr,
    stride_out_e,  # stride for expert dim in output
    # Dimensions
    num_active,
    packed_per_expert,  # expert_numel // 2
    blocks_per_expert,  # expert_numel // blocksize
    blocksize: tl.constexpr,
    # Tile size
    BLOCK_SIZE: tl.constexpr,  # elements per thread block (must be multiple of 2)
):
    """
    Each program processes BLOCK_SIZE elements from one expert.

    Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))

    For each output element:
    1. Compute which byte in packed data contains this element
    2. Extract the 4-bit nibble (high or low)
    3. Look up in NF4 codebook
    4. Scale by absmax for this block
    """
    expert_local_idx = tl.program_id(0)  # which active expert (0..num_active-1)
    block_id = tl.program_id(1)  # which element block

    # Load the global expert index
    expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)

    expert_numel = packed_per_expert * 2  # 2 elements per packed byte
    elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = elem_offset < expert_numel

    # Each element is packed as: byte[i//2], low nibble for even i, high for odd i
    byte_idx = elem_offset // 2
    is_high = (elem_offset % 2) == 1

    # Read packed bytes from the global expert's region
    packed_global_offset = expert_global * packed_per_expert + byte_idx
    packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
        tl.int32
    )

    # Extract 4-bit nibble
    # BnB packing: high nibble = even element, low nibble = odd element
    nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)

    # NF4 codebook lookup
    # Load all 16 codebook values (small, fits in registers)
    # Use gather from codebook pointer
    code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)

    # Load absmax for this element's quantization block
    block_idx = elem_offset // blocksize
    absmax_global_offset = expert_global * blocks_per_expert + block_idx
    absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)

    # Dequantize: value = codebook[nibble] * absmax
    result = code_val * absmax_val

    # Store to output
    out_offset = expert_local_idx * stride_out_e + elem_offset
    tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)


def selective_dequant_nf4_triton(
    packed_data: torch.Tensor,
    absmax: torch.Tensor,
    active_experts: torch.Tensor,
    expert_shape: tuple[int, int],
    blocksize: int,
    dtype: torch.dtype = torch.bfloat16,
    codebook: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused selective gather + NF4 dequantization via Triton kernel.

    Args:
        packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
        absmax: Per-block scaling factors [total_blocks]
        active_experts: Sorted indices of experts to dequantize [num_active]
        expert_shape: (dim1, dim2) per expert
        blocksize: Quantization block size
        dtype: Output dtype (default bf16)
        codebook: NF4 lookup table [16] (uses default NF4 codebook if None)

    Returns:
        Dequantized weights [num_active, dim1, dim2]
    """
    num_active = active_experts.shape[0]
    expert_numel = expert_shape[0] * expert_shape[1]
    packed_per_expert = expert_numel // 2
    blocks_per_expert = expert_numel // blocksize

    # Prepare codebook on device
    if codebook is None:
        codebook = torch.tensor(
            NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
        )
    else:
        codebook = codebook.to(device=packed_data.device, dtype=torch.float32)

    # Flatten inputs
    packed_flat = packed_data.reshape(-1)
    absmax_flat = absmax.reshape(-1).float()  # absmax is usually fp32

    # Output buffer
    out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)

    BLOCK_SIZE = 1024  # Process 1024 elements per thread block

    grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))

    _selective_dequant_nf4_kernel[grid](
        packed_flat,
        absmax_flat,
        active_experts,
        codebook,
        out,
        out.stride(0),
        num_active=num_active,
        packed_per_expert=packed_per_expert,
        blocks_per_expert=blocks_per_expert,
        blocksize=blocksize,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.reshape(num_active, *expert_shape)
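
For intuition, the per-element math the kernel performs (nibble extraction, codebook lookup, per-block absmax scaling) can be reproduced in plain PyTorch on a toy packed buffer. This reference sketch assumes the same even-element-in-high-nibble packing the kernel comment describes, uses truncated codebook values, and is not a drop-in replacement for the fused path:

```python
import torch

codebook = torch.tensor([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848,
                         -0.0911, 0.0, 0.0796, 0.1609, 0.2461, 0.3379,
                         0.4407, 0.5626, 0.7230, 1.0])
blocksize = 4
packed = torch.tensor([0x1F, 0x70, 0x83], dtype=torch.uint8)  # 6 packed 4-bit values
absmax = torch.tensor([2.0, 0.5])                             # one scale per block of 4

elem = torch.arange(packed.numel() * 2)
byte_val = packed[elem // 2].int()
nibble = torch.where(elem % 2 == 1, byte_val & 0xF, (byte_val >> 4) & 0xF).long()
values = codebook[nibble] * absmax[elem // blocksize]
print(values)
```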
@@ -61,9 +61,20 @@ class KernelsPlugin(BasePlugin):
        return "axolotl.integrations.kernels.KernelsArgs"

    def pre_model_load(self, cfg):
        from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK

        # Prefer text backbone type for VLMs, but fall back to base type
        # when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
        moe_model_type = cfg.model_config_type_text or cfg.model_config_type
        if (
            moe_model_type not in SPARSE_MOE_BLOCK
            and cfg.model_config_type in SPARSE_MOE_BLOCK
        ):
            moe_model_type = cfg.model_config_type

        if cfg.use_scattermoe:
            self._register_kernels()
            self._kernelize_model(cfg.model_config_type)
            self._kernelize_model(moe_model_type)
        elif cfg.use_sonicmoe:
            if not importlib.util.find_spec("sonicmoe"):
                raise RuntimeError(
@@ -75,11 +86,9 @@ class KernelsPlugin(BasePlugin):

            from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe

            LOG.info(
                f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
            )
            LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
            patch_sonicmoe(
                cfg.model_config_type,
                moe_model_type,
                torch_compile=bool(getattr(cfg, "torch_compile", False)),
            )

@@ -110,6 +119,16 @@ class KernelsPlugin(BasePlugin):
            }
        )

    def add_callbacks_pre_trainer(self, cfg, model):
        callbacks = []
        if cfg.use_scattermoe:
            from axolotl.integrations.kernels.autotune_callback import (
                AutotuneReportCallback,
            )

            callbacks.append(AutotuneReportCallback())
        return callbacks

    def _kernelize_model(self, model_type: str):
        from kernels import replace_kernel_forward_from_hub


@@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)

Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str):
        "minimax",
    ):
        return softmax_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in ("mistral4",):
        return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in (
        "glm_moe_dsa",
        "deepseek_v3",
@@ -126,6 +129,62 @@ def softmax_topk_routing(
    return flat_scores, flat_token_idx, flat_expert_idx, router_logits


def softmax_group_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    scores_for_choice = router_probs

    # Group selection: pick top groups, mask the rest
    if n_group > 1:
        group_scores = (
            scores_for_choice.view(-1, n_group, E // n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
    topk_weights = router_probs.gather(1, topk_indices)

    # Renormalization + scaling
    norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
    if norm_topk_prob:
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits


def sigmoid_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

@@ -25,7 +25,7 @@ def get_lora_parameters(
) -> tuple[
    torch.Tensor,
    torch.Tensor | None,
    QuantState | None,
    QuantState | torch.Tensor | None,
    torch.Tensor | None,
    torch.Tensor | None,
    float | None,
@@ -48,9 +48,13 @@ def get_lora_parameters(

    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
        quant_state = getattr(W, "quant_state", None)
        if quant_state is None and W.dtype == torch.float8_e4m3fn:
            quant_state = getattr(base_layer, "weight_scale_inv", None)
        return W, b, quant_state, None, None, None

    quant_state = getattr(W, "quant_state", None)
    if quant_state is None and W.dtype == torch.float8_e4m3fn:
        quant_state = getattr(base_layer, "weight_scale_inv", None)

    active_adapter = (
        proj.active_adapters[0]
@@ -81,7 +85,7 @@ def matmul_lora(
    X: torch.Tensor,
    W: torch.Tensor,
    b: torch.Tensor | None,
    W_quant: QuantState | None,
    W_quant: QuantState | torch.Tensor | None,
    A: torch.Tensor | None,
    B: torch.Tensor | None,
    s: float | None,
@@ -636,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function):
            del q_weight
            del q_weight_t
            if A_q is not None and B_q is not None:
                grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
                # Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
                # This is 65x fewer FLOPs than materializing B@A into [out, in]
                grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)

            # K path
            k_weight_t = dequantize(k_weight, k_quant)
@@ -644,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
            del k_weight
            del k_weight_t
            if A_k is not None and B_k is not None:
                grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
                grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)

            # V path
            v_weight_t = dequantize(v_weight, v_quant)
@@ -652,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
            del v_weight
            del v_weight_t
            if A_v is not None and B_v is not None:
                grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
                grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)

            # Transpose gradients if needed
            if d_A_q is not None:
@@ -815,7 +821,8 @@ class LoRA_O(torch.autograd.Function):
        del W

        A, B = A.to(dtype), B.to(dtype)
        dX += s * dY @ B @ A
        # Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
        dX.addmm_(torch.mm(dY, B), A, alpha=s)

        # W, b, W_quant, A, B, s
        return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
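
The "stay decomposed" comments above follow from simple FLOP counting. A rough check with plausible shapes (T tokens, LoRA rank R, square hidden size D); the exact 65x figure quoted in the diff depends on the real model dimensions, these numbers are only illustrative:

```python
T, R, D = 4096, 16, 4096                 # tokens, LoRA rank, hidden size

# Materialize B @ A into [D, D], then dY @ (BA): D*R*D + T*D*D multiply-adds
materialized = D * R * D + T * D * D

# Stay decomposed: (dY @ B) costs T*D*R, then [T, R] @ A costs T*R*D
decomposed = T * D * R + T * R * D

print(materialized / decomposed)         # roughly 128x fewer multiply-adds here
```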

@@ -1,4 +1,4 @@
"""Dequantization utilities for `bitsandbytes` integration."""
"""Dequantization utilities for `bitsandbytes` and FP8 integration."""

import ctypes

@@ -15,9 +15,50 @@ CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")


def dequantize_fp8(
    W: torch.Tensor,
    scale_inv: torch.Tensor,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.

    Args:
        W: FP8 weight tensor [out_features, in_features] in float8_e4m3fn.
        scale_inv: Per-block inverse scale [ceil(out/block), ceil(in/block)]
            or per-tensor scalar.
        dtype: Output dtype (default bf16).

    Returns:
        Dequantized tensor in the specified dtype.
    """
    W_float = W.to(dtype)
    if scale_inv.numel() == 1:
        return W_float * scale_inv.to(dtype)
    if scale_inv.dim() == 2 and W.dim() == 2:
        sr, sc = scale_inv.shape
        br = W.shape[0] // sr
        bc = W.shape[1] // sc
        # If dimensions are exactly divisible, use fast reshape path
        if sr * br == W.shape[0] and sc * bc == W.shape[1]:
            return (
                W_float.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(dtype)
            ).reshape(W.shape)
        # Tail-block handling: compute actual block size (ceil division),
        # tile scale_inv to cover full shape, then crop to W's dimensions
        br_ceil = -(-W.shape[0] // sr)  # ceil(rows / scale_rows) = block_size
        bc_ceil = -(-W.shape[1] // sc)
        scale_expanded = (
            scale_inv.to(dtype)
            .repeat_interleave(br_ceil, dim=0)
            .repeat_interleave(bc_ceil, dim=1)
        )[: W.shape[0], : W.shape[1]]
        return W_float * scale_expanded
    return W_float * scale_inv.to(dtype)
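
A small dense example of the block-wise scaling this implements (a 2x2 scale grid over a 4x4 weight; plain float tensors stand in for the FP8 data):

```python
import torch

W = torch.ones(4, 4)                      # stand-in for the fp8 weight
scale_inv = torch.tensor([[1.0, 2.0],
                          [3.0, 4.0]])    # one scale per 2x2 block
sr, sc = scale_inv.shape
br, bc = W.shape[0] // sr, W.shape[1] // sc

out = (W.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None]).reshape(W.shape)
print(out)   # top-left block = 1, top-right = 2, bottom-left = 3, bottom-right = 4
```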


def dequantize(
    W: torch.Tensor,
    quant_state: QuantState | list | None = None,
    quant_state: QuantState | list | torch.Tensor | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
@@ -49,6 +90,15 @@ def dequantize(
    if quant_state is None:
        return W

    # FP8 path: quant_state is actually scale_inv tensor
    if W.dtype == torch.float8_e4m3fn:
        scale_inv = quant_state
        # Caller may pass W.t() (non-contiguous) — dequantize in original
        # layout then transpose back so the result shape matches the input.
        if not W.is_contiguous() and W.dim() == 2:
            return dequantize_fp8(W.t(), scale_inv).t()
        return dequantize_fp8(W, scale_inv)

    # Get the target device from input tensor W
    target_device = W.device


@@ -160,6 +160,18 @@ def load_lora(
    else:
        model = get_peft_model(model, lora_config, **model_kwargs)

    # FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
    # requires a compute dtype (bf16/fp16). Cast trainable LoRA params.
    if cfg.torch_dtype:
        _fp8_cast_dtype = cfg.torch_dtype
    elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        _fp8_cast_dtype = torch.bfloat16
    else:
        _fp8_cast_dtype = torch.float16
    for _name, param in model.named_parameters():
        if param.requires_grad and param.dtype == torch.float8_e4m3fn:
            param.data = param.data.to(_fp8_cast_dtype)

    if rank == 0:
        try:
            model.print_trainable_parameters()

@@ -215,6 +215,8 @@ class ModelLoader:
            self.model_kwargs["revision"] = self.cfg.revision_of_model
        if self.cfg.use_kernels:
            self.model_kwargs["use_kernels"] = self.cfg.use_kernels
            if "allow_all_kernels" not in self.model_kwargs:
                self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels
        self._set_quantization_config()
        self._set_attention_config()
        self._check_model_requirements()
@@ -503,6 +505,20 @@ class ModelLoader:
        elif not is_ds_zero3:
            self.model_kwargs["device_map"] = device_map

        # quantize_moe_experts quantizes expert weights on-the-fly during loading,
        # so the actual VRAM usage is much less than bf16 estimates.
        # When device_map is "auto", accelerate's infer_auto_device_map computes
        # the device map at bf16 size (before quantization), causing it to offload
        # layers to CPU, which BnB then rejects. Force single-GPU placement to
        # prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
        if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
            "auto",
            None,
        ):
            self.model_kwargs["device_map"] = {
                "": int(os.environ.get("LOCAL_RANK", 0))
            }

        cur_device = get_device_type()
        if "mps" in str(cur_device):
            self.model_kwargs["device_map"] = "mps:0"
@@ -829,8 +845,9 @@ class ModelLoader:
    def _set_z3_leaf_modules(self):
        from deepspeed.utils import set_z3_leaf_modules

        if self.cfg.model_config_type in MOE_ARCH_BLOCK:
            moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
        moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type
        if moe_type in MOE_ARCH_BLOCK:
            moe_blocks = MOE_ARCH_BLOCK[moe_type]
            moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
            set_z3_leaf_modules(
                self.model,

@@ -93,11 +93,13 @@ class PatchManager:

    def apply_pre_model_load_patches(self):
        """Apply pre-model load patches based on config."""
        self._deactivate_hf_async_load()
        self._apply_transformers_patches()
        # self._apply_flex_attention_patches()
        self._apply_flash_attention_patches()
        self._apply_chunked_cross_entropy_patch()
        self._apply_sageattn_patches()
        self._apply_flash_attn_4_patches()
        self._apply_fsdp_patches()
        self._apply_adapter_patches()
        self._apply_model_specific_patches()
@@ -114,6 +116,8 @@ class PatchManager:
        self._apply_patch_deepspeed_zero3()
        self._apply_voxtral_patches()
        self._apply_apertus_patches()
        self._apply_trl_vllm_patches()
        self._apply_trl_trainer_utils_patches()

    def apply_post_plugin_pre_model_load_patches(self):
        """Apply post plugin-pre_model_load load patches based on config."""
@@ -129,13 +133,6 @@ class PatchManager:
            patch_evaluation_loop()
            patch_maybe_log_save_evaluate()

        if self.cfg.context_parallel_size > 1:
            from axolotl.monkeypatch.transformers.trainer_context_parallel import (
                patch_prepare_context_parallel_inputs,
            )

            patch_prepare_context_parallel_inputs()

    def apply_post_model_build_patches(self, model: PreTrainedModel):
        """Apply patches right after model build, before post-load setup."""
        self._finalize_moe_expert_quantization(model)
@@ -227,6 +224,15 @@ class PatchManager:

            patch_sageattn()

    def _apply_flash_attn_4_patches(self):
        """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
        if not self.cfg.flash_attention:
            return

        from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4

        patch_flash_attn_4(self.model_config)

    def _apply_model_specific_patches(self):
        """Apply patches specific to model architectures."""
        if (
@@ -409,17 +415,27 @@ class PatchManager:
        if self.cfg.load_in_8bit:
            apply_linear8bitlt_save_patch()

    def _deactivate_hf_async_load(self):
        """Load weights synchronously so they can be converted and not OOM."""
        if self.cfg.load_in_4bit or self.cfg.load_in_8bit:
            os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"

    def _apply_moe_expert_quantization_patch(self):
        """Patch transformers weight loading to quantize MoE expert params on-the-fly."""
        if not self.cfg.quantize_moe_experts:
        """Patch transformers weight loading and PEFT for MoE expert quantization."""
        has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))

        if not self.cfg.quantize_moe_experts and not has_target_params:
            return

        from axolotl.monkeypatch.moe_quant import (
            patch_moe_quantization_on_load,
            patch_peft_target_parameters_matching,
        )

        patch_moe_quantization_on_load(self.cfg)
        if self.cfg.quantize_moe_experts:
            from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load

            patch_moe_quantization_on_load(self.cfg)

        patch_peft_target_parameters_matching()

    def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
@@ -646,6 +662,50 @@ class PatchManager:

            patch_apertus_xielu_activation()

    def _apply_trl_vllm_patches(self):
        """Apply TRL vLLM patches for batched weight sync, NaN logprobs fix, and scalar handling."""
        if (
            self.cfg.rl
            and getattr(self.cfg, "trl", None)
            and getattr(self.cfg.trl, "use_vllm", False)
        ):
            from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm

            patch_trl_vllm()

    def _apply_trl_trainer_utils_patches(self):
        """Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels."""
        if not self.cfg.rl:
            return

        try:
            from axolotl.monkeypatch.trainer.utils import (
                entropy_from_logits,
                selective_log_softmax,
            )
        except (ImportError, ModuleNotFoundError):
            LOG.warning("Triton not available — skipping trl.trainer.utils patches")
            return

        import trl.trainer.utils

        # Guard against repeated calls: only stash the original if trl still
        # points at its own implementation (not our wrapper).
        if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
            from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils

            _axolotl_trainer_utils.selective_log_softmax_original = (
                trl.trainer.utils.selective_log_softmax
            )
            trl.trainer.utils.selective_log_softmax = selective_log_softmax

        if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
            trl.trainer.utils.entropy_from_logits = entropy_from_logits

        LOG.info(
            "Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits"
        )
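
The guard above is the usual idempotent-monkeypatch pattern: replace a module attribute only if it still points at the library's own function, so repeated calls never stash the wrapper as the "original". A generic hedged sketch of the same idea, with invented toy targets rather than the trl functions:

```python
import math

_original_sqrt = None

def fast_sqrt(x):
    return x ** 0.5

def apply_patch():
    global _original_sqrt
    # Only stash the original the first time; re-running is a no-op.
    if math.sqrt is not fast_sqrt:
        _original_sqrt = math.sqrt
        math.sqrt = fast_sqrt

apply_patch()
apply_patch()          # safe: the stashed original is still the real math.sqrt
print(_original_sqrt(4.0), math.sqrt(4.0))
```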

    def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
        """Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
        if self.cfg.scaling_softmax:

@@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
    )

    processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
    processor_kwargs["tokenizer"] = tokenizer

    processor = processor_cls.from_pretrained(
        cfg.processor_config,
        **processor_kwargs,
    )
    processor.tokenizer = tokenizer

    # Attempt to load image size from processor if available
    if (

@@ -78,30 +78,29 @@ def patch_parallelism_config():


def patch_prepare_cp():
    import functools
    import contextlib

    import torch
    from accelerate import Accelerator
    from transformers import Trainer

    def patched_prepare_cp(self, *args):
        if self.parallelism_config.cp_backend == "deepspeed":
            return args

        from accelerate.big_modeling import _attach_context_parallel_hooks
        from torch.distributed.tensor.experimental import context_parallel
        from torch.distributed.tensor.experimental._attention import set_rotate_method

        cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
        set_rotate_method(cp_comm_strategy)

        self._cp_context = functools.partial(
            context_parallel, mesh=self.torch_device_mesh["cp"]
        )

        for arg in args:
            if isinstance(arg, torch.nn.Module):
                _attach_context_parallel_hooks(arg)
        @contextlib.contextmanager
        def _noop_cp_context(
            buffers=None, buffer_seq_dims=None, no_restore_buffers=None
        ):
            yield

        self._cp_context = _noop_cp_context
        return args

    def _noop_prepare_context_parallel_inputs(self, model, inputs):
        return contextlib.nullcontext, inputs

    # prevent double CP partition
    Accelerator._prepare_cp = patched_prepare_cp

    # remove unneeded calculation upstream
    Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs

104 src/axolotl/monkeypatch/attention/flash_attn_4.py (new file)
@@ -0,0 +1,104 @@
"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""

import torch

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _get_head_dims(model_config):
    """Extract (head_dim, head_dim_v) from a model config.

    Handles composite models (e.g. Qwen3.5 VL) via text_config and
    MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
    """
    cfg = model_config
    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    # MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
    if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
        head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
        head_dim_v = getattr(cfg, "v_head_dim", head_dim)
        return head_dim, head_dim_v

    # Standard models
    if hasattr(cfg, "head_dim"):
        return cfg.head_dim, cfg.head_dim
    if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
        head_dim = cfg.hidden_size // cfg.num_attention_heads
        return head_dim, head_dim

    return None, None
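
For instance, with DeepSeek-style MLA settings the helper reports different Q/K and V head sizes. A quick check with a stand-in config object (the values are typical published DeepSeek V3 dimensions, not read from this repo):

```python
from types import SimpleNamespace

cfg = SimpleNamespace(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)

head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim  # 192 for Q/K
head_dim_v = getattr(cfg, "v_head_dim", head_dim)       # 128 for V
print(head_dim, head_dim_v)
```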
|
||||
|
||||
|
||||
def patch_flash_attn_4(model_config=None):
|
||||
"""Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
# Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
|
||||
if major not in (9, 10, 11):
|
||||
return
|
||||
|
||||
try:
|
||||
from flash_attn.cute import ( # noqa: F401
|
||||
flash_attn_func,
|
||||
flash_attn_varlen_func,
|
||||
)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"Flash Attention 4 is available for your GPU and offers faster training speeds. "
|
||||
"To enable: pip install flash-attn-4"
|
||||
)
|
||||
return
|
||||
|
||||
# Validate head dimensions against FA4's own constraints
|
||||
head_dim = None
|
||||
if model_config is not None:
|
||||
head_dim, head_dim_v = _get_head_dims(model_config)
|
||||
if head_dim is not None:
|
||||
try:
|
||||
from flash_attn.cute.interface import _validate_head_dims
|
||||
except ImportError:
|
||||
LOG.warning(
|
||||
"Could not import _validate_head_dims from flash_attn.cute.interface, "
|
||||
"unable to verify head dimension compatibility, falling back to FA2"
|
||||
)
|
||||
return
|
||||
|
||||
# alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8
|
||||
alignment = 8
|
||||
try:
|
||||
_validate_head_dims(head_dim, head_dim_v, major, alignment)
|
||||
except AssertionError as exc:
|
||||
LOG.warning(
|
||||
"Model head dimensions not supported by FA4, "
|
||||
"falling back to FA2: %s",
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
import transformers.modeling_flash_attention_utils as fa_utils
|
||||
|
||||
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
|
||||
return
|
||||
|
||||
def _patched_lazy_imports(
|
||||
implementation, attention_wrapper=None, allow_all_kernels=False
|
||||
):
|
||||
return (
|
||||
flash_attn_func,
|
||||
flash_attn_varlen_func,
|
||||
fa_utils._pad_input,
|
||||
fa_utils._unpad_input,
|
||||
)
|
||||
|
||||
_patched_lazy_imports._axolotl_patched = True
|
||||
fa_utils._lazy_imports = _patched_lazy_imports
|
||||
LOG.info(
|
||||
"Flash Attention 4 enabled (head_dim=%s)",
|
||||
head_dim if model_config else "unknown",
|
||||
)
|
||||
@@ -64,15 +64,12 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
||||
LOG.info(
|
||||
"Compiling flex attention with kwargs: %s. This may take a while...",
|
||||
flex_attn_compile_kwargs,
|
||||
main_process_only=True,
|
||||
)
|
||||
self._compiled_flex_attention = torch.compile(
|
||||
flex_attention,
|
||||
**flex_attn_compile_kwargs,
|
||||
)
|
||||
LOG.info(
|
||||
"Flex attention compiled successfully.", main_process_only=True
|
||||
)
|
||||
LOG.info("Flex attention compiled successfully.")
|
||||
|
||||
self._is_flex_compiled = True
|
||||
|
||||
|
||||
@@ -51,6 +51,29 @@ QKV_PATCHES = [
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
(
|
||||
"""
|
||||
query_states, gate = torch.chunk(
|
||||
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||
)
|
||||
gate = gate.reshape(*input_shape, -1)
|
||||
|
||||
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
"""
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states, gate = torch.chunk(
|
||||
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||
)
|
||||
gate = gate.reshape(*input_shape, -1)
|
||||
|
||||
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip("\n"),
|
||||
),
|
||||
]
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
@@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
|
||||
if hasattr(pretrained_model, "language_model"):
|
||||
return pretrained_model.language_model.layers
|
||||
if hasattr(pretrained_model, "model"):
|
||||
if hasattr(pretrained_model.model, "language_model"):
|
||||
return pretrained_model.model.language_model.layers
|
||||
return pretrained_model.model.layers
|
||||
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -1,11 +1,4 @@
|
||||
"""
|
||||
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
||||
|
||||
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
|
||||
skips (only targets nn.Linear). This module patches weight loading to quantize them
|
||||
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
|
||||
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
|
||||
"""
|
||||
"""Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors."""
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
@@ -15,18 +8,20 @@ from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
# Module-level state for the loading-time quantization patch.
|
||||
_moe_load_state = {
|
||||
"count": 0,
|
||||
"mode": "4bit",
|
||||
"quant_type": "nf4",
|
||||
"compress_statistics": True,
|
||||
"patched": False,
|
||||
# Module path → param names in definition order, captured before quantization.
|
||||
# Without this, alphabetical loading order would mismatch merge order.
|
||||
"expert_param_order": {},
|
||||
}
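# Note (added for clarity, not part of this patch): the order snapshot mirrors
# what the loading hook below does before quantizing a module's experts:
#   order = list(mod._parameters.keys())   # definition order, e.g. ["gate_up_proj", "down_proj"]
#   _moe_load_state["expert_param_order"][mod_path] = order
# Once parametrizations are registered, iterating the module's parameters would
# otherwise follow alphabetical order and mismatch the adapter merge order.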
|
||||
|
||||
|
||||
class Bnb8bitParametrization(torch.nn.Module):
|
||||
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
|
||||
"""Dequantizes int8 row-wise quantized data on access."""
|
||||
|
||||
def __init__(self, row_stats: torch.Tensor):
|
||||
super().__init__()
|
||||
@@ -34,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
|
||||
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
|
||||
"""Flatten 3D+ to 2D for BnB's dequant, then reshape back."""
|
||||
orig_shape = quantized_param.shape
|
||||
if quantized_param.ndim > 2:
|
||||
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
|
||||
@@ -74,14 +69,11 @@ def replace_parameter_8bit(module, param_name):
|
||||
|
||||
|
||||
def patch_moe_quantization_on_load(cfg):
|
||||
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
|
||||
|
||||
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
|
||||
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
|
||||
"""
|
||||
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly."""
|
||||
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
|
||||
_moe_load_state["mode"] = mode
|
||||
_moe_load_state["count"] = 0
|
||||
_moe_load_state["expert_param_order"] = {}
|
||||
|
||||
if _moe_load_state["patched"]:
|
||||
LOG.debug("MoE loading-time quantization patch already active")
|
||||
@@ -113,7 +105,6 @@ def patch_moe_quantization_on_load(cfg):
|
||||
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
||||
original_set_param(model, target_name, param_value, *args, **kwargs)
|
||||
|
||||
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
||||
if param_value.ndim >= 3 and param_value.is_cuda:
|
||||
mod_path, _, pname = target_name.rpartition(".")
|
||||
mod = model.get_submodule(mod_path) if mod_path else model
|
||||
@@ -126,6 +117,13 @@ def patch_moe_quantization_on_load(cfg):
|
||||
)
|
||||
return
|
||||
|
||||
# Record definition order before parametrizations override it
|
||||
# with alphabetical order.
|
||||
if mod_path not in _moe_load_state["expert_param_order"]:
|
||||
_moe_load_state["expert_param_order"][mod_path] = list(
|
||||
mod._parameters.keys()
|
||||
)
|
||||
|
||||
if _moe_load_state["mode"] == "4bit":
|
||||
replace_parameter_4bit(
|
||||
mod,
|
||||
@@ -151,20 +149,28 @@ def get_moe_quantized_count():
|
||||
|
||||
|
||||
def patch_peft_target_parameters_matching():
|
||||
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
|
||||
"""Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.
|
||||
|
||||
1. Expands short suffixes to full module paths for parametrized modules.
|
||||
2. Iterates params in definition order (not alphabetical order) so saved
|
||||
adapters are compatible with standard PEFT, vLLM, etc.
|
||||
"""
|
||||
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
|
||||
return
|
||||
from peft.tuners.tuners_utils import BaseTuner
|
||||
|
||||
original_inject = BaseTuner._inject_parameters
|
||||
from contextlib import nullcontext
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
|
||||
from peft.utils.integrations import init_empty_weights
|
||||
from peft.utils.other import _get_submodules
|
||||
|
||||
def _patched_inject_parameters(
|
||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||
):
|
||||
# Patch target_parameters to use full paths for parametrized modules
|
||||
original_targets = list(peft_config.target_parameters)
|
||||
expanded = set(original_targets)
|
||||
|
||||
# Expand short suffixes to full paths for parametrized modules.
|
||||
for module_name, module in model.named_modules():
|
||||
if not hasattr(module, "parametrizations"):
|
||||
continue
|
||||
@@ -175,14 +181,74 @@ def patch_peft_target_parameters_matching():
|
||||
) and hasattr(module, param_name):
|
||||
expanded.add(f"{module_name}.{param_name}")
|
||||
|
||||
peft_config.target_parameters = sorted(expanded)
|
||||
try:
|
||||
return original_inject(
|
||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||
)
|
||||
finally:
|
||||
peft_config.target_parameters = original_targets
|
||||
target_names_set = expanded
|
||||
|
||||
def strip_base_layer_from_name(module_name):
|
||||
name = ".base_layer"
|
||||
while name in module_name:
|
||||
prefix, _, suffix = module_name.rpartition(name)
|
||||
module_name = prefix + suffix
|
||||
return module_name
|
||||
|
||||
def create_and_replace_param(module_name, key, param_name):
|
||||
parent, target, target_name = _get_submodules(model, module_name)
|
||||
unwrapped_module_name = strip_base_layer_from_name(module_name)
|
||||
unwrapped_module = model.get_submodule(unwrapped_module_name)
|
||||
if (
|
||||
isinstance(unwrapped_module, BaseTunerLayer)
|
||||
and unwrapped_module.__class__.__name__ != "ParamWrapper"
|
||||
):
|
||||
raise ValueError(
|
||||
f"Trying to wrap an `nn.Parameter` of layer "
|
||||
f"'{unwrapped_module_name}' of type "
|
||||
f"{type(target).__name__}, which is not a valid target. "
|
||||
f"Make sure that this layer is not also targeted with "
|
||||
f"`target_modules`."
|
||||
)
|
||||
self._check_target_module_compatiblity(peft_config, model, target_name)
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
self._create_and_replace(
|
||||
peft_config,
|
||||
adapter_name,
|
||||
target,
|
||||
target_name,
|
||||
parent,
|
||||
current_key=key,
|
||||
parameter_name=param_name.rpartition(".")[-1],
|
||||
)
|
||||
|
||||
# Use definition order (not alphabetical order) for parametrized modules
|
||||
# so ParamWrapper nesting matches vanilla PEFT on a plain model.
|
||||
expert_param_order = _moe_load_state.get("expert_param_order", {})
|
||||
|
||||
for module_name, module in model.named_modules():
|
||||
if hasattr(module, "parametrizations"):
|
||||
stored_order = expert_param_order.get(module_name)
|
||||
if stored_order is not None:
|
||||
params_iter = [
|
||||
p for p in stored_order if p in module.parametrizations
|
||||
]
|
||||
else:
|
||||
# Fallback for paths that bypass model loading (e.g. unit tests).
|
||||
params_iter = list(module.parametrizations.keys())
|
||||
for param_name in params_iter:
|
||||
key = f"{module_name}.{param_name}"
|
||||
if (key in target_names_set) or any(
|
||||
key.endswith(f".{t}") for t in target_names_set
|
||||
):
|
||||
create_and_replace_param(module_name, key, param_name)
|
||||
self.targeted_parameter_names.append(key)
|
||||
else:
|
||||
unwrapped_module_name = strip_base_layer_from_name(module_name)
|
||||
for param_name, _ in module.named_parameters(recurse=False):
|
||||
key = f"{unwrapped_module_name}.{param_name}"
|
||||
if (key in target_names_set) or any(
|
||||
key.endswith(f".{t}") for t in target_names_set
|
||||
):
|
||||
create_and_replace_param(module_name, key, param_name)
|
||||
self.targeted_parameter_names.append(key)
|
||||
|
||||
BaseTuner._inject_parameters = _patched_inject_parameters
|
||||
patch_peft_target_parameters_matching._axolotl_patched = True
|
||||
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
||||
LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")
|
||||
|
||||
@@ -57,7 +57,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"olmo3",
|
||||
"ministral",
|
||||
"ministral3",
|
||||
"mistral4",
|
||||
"afmoe",
|
||||
"nemotron",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -154,7 +154,6 @@ def register_ring_attn_from_device_mesh(
|
||||
LOG.info(
|
||||
f"Enabling ring attention sequence parallelism using DeviceMesh "
|
||||
f"dimension '{context_parallel_dim}'",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
# Extract the sequence parallel submesh
|
||||
|
||||
@@ -85,7 +85,6 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
||||
mlp_cls._tiled_mlp_dist_impl = None
|
||||
LOG.info(
|
||||
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
|
||||
main_process_only=True,
|
||||
)
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .utils import entropy_from_logits, selective_log_softmax
|
||||
|
||||
__all__ = ["entropy_from_logits", "selective_log_softmax"]
|
||||
|
||||
245
src/axolotl/monkeypatch/trainer/trl_vllm.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Monkeypatches for TRL's vLLM integration and trainer utils.
|
||||
|
||||
Adds:
|
||||
- VLLMClient.batch_update_named_params: batched weight sync (fewer HTTP round-trips)
|
||||
- extract_logprobs: NaN→0.0 fix (prevents downstream NaN propagation)
|
||||
- VLLMGeneration: weight_sync_chunk_size + batched sync path for non-FSDP/non-ZeRO
|
||||
- split_tensor_dict / shuffle_sequence_dict: scalar type handling (int/float/bool passthrough)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _batch_update_named_params(
|
||||
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
|
||||
):
|
||||
"""Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
|
||||
from transformers import is_torch_xpu_available
|
||||
|
||||
if chunk_size is None:
|
||||
chunks = [params]
|
||||
else:
|
||||
chunks = []
|
||||
current_chunk: list[tuple[str, torch.Tensor]] = []
|
||||
current_elements = 0
|
||||
for name, weights in params:
|
||||
n_elem = weights.numel()
|
||||
if current_chunk and current_elements + n_elem > chunk_size:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = []
|
||||
current_elements = 0
|
||||
current_chunk.append((name, weights))
|
||||
current_elements += n_elem
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
for chunk in chunks:
|
||||
param_metadata = [
|
||||
{"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
|
||||
for name, weights in chunk
|
||||
]
|
||||
url = f"{self.base_url}/batch_update_named_params/"
|
||||
response = self.session.post(url, json={"params": param_metadata})
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
for _name, weights in chunk:
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.broadcast(weights, root=self.rank)
|
||||
else:
|
||||
self.communicator.broadcast(weights, src=self.rank)
|
||||
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.barrier()
|
||||
else:
|
||||
self.communicator.group.barrier()
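# Illustrative call (assumption, not from this patch): chunk_size is counted in
# tensor elements, so a cap of ~512M elements per chunk would look like
#   client.batch_update_named_params(
#       [(n, p.data) for n, p in model.named_parameters()],
#       chunk_size=512 * 1024 * 1024,
#   )
# i.e. one HTTP POST of metadata per chunk, then one NCCL broadcast per tensor.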
|
||||
|
||||
|
||||
def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
|
||||
"""Updates all model params using batch_update_named_params."""
|
||||
params = [(name, param.data) for name, param in model.named_parameters()]
|
||||
self.batch_update_named_params(params, chunk_size=chunk_size)
|
||||
|
||||
|
||||
def _patched_extract_logprobs(all_outputs):
|
||||
"""extract_logprobs with NaN→0.0 fix (stock TRL uses None which causes downstream errors)."""
|
||||
all_logprobs = []
|
||||
all_token_ids = []
|
||||
|
||||
for outputs in all_outputs:
|
||||
for output in outputs.outputs:
|
||||
if output.logprobs is None:
|
||||
return None, None
|
||||
seq_logprobs = []
|
||||
seq_token_ids = []
|
||||
for lp in output.logprobs:
|
||||
sorted_items = sorted(lp.items(), key=lambda x: x[1].rank)
|
||||
seq_token_ids.append([token_id for token_id, _ in sorted_items])
|
||||
seq_logprobs.append(
|
||||
[
|
||||
0.0 if math.isnan(item.logprob) else item.logprob
|
||||
for _, item in sorted_items
|
||||
]
|
||||
)
|
||||
all_logprobs.append(seq_logprobs)
|
||||
all_token_ids.append(seq_token_ids)
|
||||
|
||||
return all_logprobs, all_token_ids
|
||||
|
||||
|
||||
def _patched_split_tensor_dict(tensor_dict, num_chunks):
|
||||
"""split_tensor_dict that handles scalar types (int/float/bool) for num_items_in_batch."""
|
||||
first_tensor = next(
|
||||
tensor
|
||||
for tensor in tensor_dict.values()
|
||||
if tensor is not None and isinstance(tensor, torch.Tensor) and tensor.ndim > 0
|
||||
)
|
||||
chunk_size = first_tensor.shape[0] // num_chunks
|
||||
chunks = []
|
||||
for i in range(num_chunks):
|
||||
chunk_dict = {}
|
||||
for key, tensor in tensor_dict.items():
|
||||
if isinstance(tensor, (int, float, bool)):
|
||||
chunk_dict[key] = tensor
|
||||
elif tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0):
|
||||
chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]
|
||||
elif tensor is not None and tensor.ndim == 0:
|
||||
chunk_dict[key] = tensor
|
||||
else:
|
||||
chunk_dict[key] = None
|
||||
chunks.append(chunk_dict)
|
||||
return chunks
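# Illustrative behaviour (assumption): scalars ride along unchanged while
# batch-dim tensors are sliced, e.g.
#   d = {"input_ids": torch.zeros(8, 16), "num_items_in_batch": 7}
#   chunks = _patched_split_tensor_dict(d, num_chunks=2)
#   # chunks[0]["input_ids"].shape == (4, 16); chunks[0]["num_items_in_batch"] == 7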
|
||||
|
||||
|
||||
def _patched_shuffle_sequence_dict(seq_dict):
|
||||
"""shuffle_sequence_dict that handles scalar types (int/float/bool)."""
|
||||
first_seq = next(
|
||||
v
|
||||
for v in seq_dict.values()
|
||||
if v is not None and isinstance(v, (torch.Tensor, list)) and len(v) > 0
|
||||
)
|
||||
perm = torch.randperm(len(first_seq))
|
||||
|
||||
def permute(v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (int, float, bool)):
|
||||
return v
|
||||
if isinstance(v, torch.Tensor) and v.ndim == 0:
|
||||
return v
|
||||
if isinstance(v, torch.Tensor) and v.ndim >= 1:
|
||||
return v[perm]
|
||||
if isinstance(v, list):
|
||||
return [v[i] for i in perm.tolist()]
|
||||
return v
|
||||
|
||||
return {k: permute(v) for k, v in seq_dict.items()}
|
||||
|
||||
|
||||
def _patch_sync_weights_batched(original_init):
|
||||
"""Wrap VLLMGeneration.__init__ to accept weight_sync_chunk_size."""
|
||||
|
||||
@wraps(original_init)
|
||||
def patched_init(self, *args, weight_sync_chunk_size=None, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.weight_sync_chunk_size = weight_sync_chunk_size
|
||||
|
||||
return patched_init
|
||||
|
||||
|
||||
def _make_batched_sync_weights(original_sync_weights):
|
||||
"""Wrap sync_weights to use batched sync for non-FSDP/non-ZeRO paths."""
|
||||
|
||||
@wraps(original_sync_weights)
|
||||
def patched_sync_weights(self):
|
||||
from accelerate.utils import is_peft_model
|
||||
|
||||
# Check if we're in a non-PEFT, non-FSDP, non-ZeRO scenario where batching helps
|
||||
accelerator = self.accelerator
|
||||
model = self.model
|
||||
is_fsdp_enabled = self.is_fsdp_enabled
|
||||
|
||||
deepspeed_plugin = accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
|
||||
is_peft = is_peft_model(model)
|
||||
|
||||
# If PEFT, FSDP, or ZeRO-3, fall back to original (which handles those cases)
|
||||
if is_peft or is_fsdp_enabled or zero_stage_3:
|
||||
return original_sync_weights(self)
|
||||
|
||||
# Non-PEFT, non-FSDP, non-ZeRO: use batched sync
|
||||
if self.mode == "colocate" and getattr(self, "enable_sleep_mode", False):
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import (
|
||||
empty_cache,
|
||||
)
|
||||
|
||||
empty_cache()
|
||||
self.llm.wake_up(tags=["weights"])
|
||||
|
||||
if self.mode == "server" and accelerator.is_main_process:
|
||||
params = [
|
||||
(self._fix_param_name_to_vllm(name), param.data)
|
||||
for name, param in model.named_parameters()
|
||||
]
|
||||
self.vllm_client.batch_update_named_params(
|
||||
params, chunk_size=getattr(self, "weight_sync_chunk_size", None)
|
||||
)
|
||||
elif self.mode == "colocate":
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
weights = [
|
||||
(self._fix_param_name_to_vllm(name), param.data)
|
||||
for name, param in model.named_parameters()
|
||||
]
|
||||
llm_model.load_weights(weights=weights)
|
||||
|
||||
# Reset cache
|
||||
if self.mode == "server" and accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
elif self.mode == "colocate":
|
||||
self.llm.reset_prefix_cache()
|
||||
|
||||
return patched_sync_weights
|
||||
|
||||
|
||||
def patch_trl_vllm():
|
||||
"""Apply all TRL vLLM monkeypatches."""
|
||||
import trl.generation.vllm_client
|
||||
import trl.generation.vllm_generation
|
||||
import trl.trainer.utils
|
||||
|
||||
VLLMClient = trl.generation.vllm_client.VLLMClient
|
||||
VLLMGeneration = trl.generation.vllm_generation.VLLMGeneration
|
||||
|
||||
# 1. Add batch_update_named_params to VLLMClient
|
||||
if not hasattr(VLLMClient, "batch_update_named_params"):
|
||||
VLLMClient.batch_update_named_params = _batch_update_named_params
|
||||
VLLMClient.update_model_params = _update_model_params
|
||||
LOG.info("Patched VLLMClient with batch_update_named_params")
|
||||
|
||||
# 2. Patch extract_logprobs (NaN→0.0)
|
||||
trl.generation.vllm_generation.extract_logprobs = _patched_extract_logprobs
|
||||
LOG.info("Patched extract_logprobs with NaN→0.0 fix")
|
||||
|
||||
# 3. Patch VLLMGeneration.__init__ to accept weight_sync_chunk_size
|
||||
VLLMGeneration.__init__ = _patch_sync_weights_batched(VLLMGeneration.__init__)
|
||||
|
||||
# 4. Patch sync_weights for batched non-FSDP/non-ZeRO path
|
||||
VLLMGeneration.sync_weights = _make_batched_sync_weights(
|
||||
VLLMGeneration.sync_weights
|
||||
)
|
||||
LOG.info("Patched VLLMGeneration with batched sync_weights")
|
||||
|
||||
# 5. Patch split_tensor_dict and shuffle_sequence_dict
|
||||
trl.trainer.utils.split_tensor_dict = _patched_split_tensor_dict
|
||||
trl.trainer.utils.shuffle_sequence_dict = _patched_shuffle_sequence_dict
|
||||
LOG.info("Patched split_tensor_dict and shuffle_sequence_dict for scalar types")
|
||||
429
src/axolotl/monkeypatch/trainer/utils.py
Normal file
@@ -0,0 +1,429 @@
|
||||
# Copyright 2026 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _entropy_online_kernel(
|
||||
logits_ptr,
|
||||
output_ptr,
|
||||
stride_row,
|
||||
V: tl.constexpr,
|
||||
BLOCK_V: tl.constexpr,
|
||||
):
|
||||
"""Online entropy: single pass with running max correction."""
|
||||
row = tl.program_id(0)
|
||||
row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row
|
||||
|
||||
running_max = tl.full([], float("-inf"), dtype=tl.float32)
|
||||
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
|
||||
running_weighted = tl.full([], 0.0, dtype=tl.float32)
|
||||
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
offs = v_start + tl.arange(0, BLOCK_V)
|
||||
mask = offs < V
|
||||
x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)
|
||||
|
||||
block_max = tl.max(x, axis=0)
|
||||
new_max = tl.maximum(running_max, block_max)
|
||||
|
||||
correction = tl.exp(running_max - new_max)
|
||||
running_sum_exp = running_sum_exp * correction
|
||||
running_weighted = running_weighted * correction
|
||||
|
||||
exp_x = tl.exp(x - new_max)
|
||||
exp_x = tl.where(mask, exp_x, 0.0)
|
||||
x = tl.where(mask, x, 0.0)
|
||||
running_sum_exp += tl.sum(exp_x, axis=0)
|
||||
running_weighted += tl.sum(exp_x * x, axis=0)
|
||||
|
||||
running_max = new_max
|
||||
|
||||
entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
|
||||
tl.store(output_ptr + row, entropy)
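# Identity used above (note added for clarity): with m = running_max,
#   H = -sum_i p_i * log p_i = logsumexp(x) - sum_i softmax(x)_i * x_i
# where logsumexp(x) = log(running_sum_exp) + m and
#   sum_i softmax(x)_i * x_i = running_weighted / running_sum_exp.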
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _entropy_online_kernel_strided(
|
||||
logits_ptr,
|
||||
output_ptr,
|
||||
stride_outer,
|
||||
stride_inner,
|
||||
n_inner,
|
||||
row_offset,
|
||||
V: tl.constexpr,
|
||||
BLOCK_V: tl.constexpr,
|
||||
):
|
||||
"""Online entropy for non-contiguous 3D (B, L, V) tensors."""
|
||||
local_row = tl.program_id(0)
|
||||
row = local_row + row_offset
|
||||
outer_idx = row // n_inner
|
||||
inner_idx = row % n_inner
|
||||
off = outer_idx.to(tl.int64) * stride_outer + inner_idx.to(tl.int64) * stride_inner
|
||||
row_ptr = logits_ptr + off
|
||||
|
||||
running_max = tl.full([], float("-inf"), dtype=tl.float32)
|
||||
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
|
||||
running_weighted = tl.full([], 0.0, dtype=tl.float32)
|
||||
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
offs = v_start + tl.arange(0, BLOCK_V)
|
||||
mask = offs < V
|
||||
x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)
|
||||
|
||||
block_max = tl.max(x, axis=0)
|
||||
new_max = tl.maximum(running_max, block_max)
|
||||
|
||||
correction = tl.exp(running_max - new_max)
|
||||
running_sum_exp = running_sum_exp * correction
|
||||
running_weighted = running_weighted * correction
|
||||
|
||||
exp_x = tl.exp(x - new_max)
|
||||
exp_x = tl.where(mask, exp_x, 0.0)
|
||||
x = tl.where(mask, x, 0.0)
|
||||
running_sum_exp += tl.sum(exp_x, axis=0)
|
||||
running_weighted += tl.sum(exp_x * x, axis=0)
|
||||
|
||||
running_max = new_max
|
||||
|
||||
entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
|
||||
tl.store(output_ptr + local_row, entropy)
|
||||
|
||||
|
||||
def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
|
||||
"""Triton-fused entropy (online single-pass). Handles non-contiguous tensors without copying."""
|
||||
original_shape = logits.shape[:-1]
|
||||
V = logits.shape[-1]
|
||||
N = 1
|
||||
for s in original_shape:
|
||||
N *= s
|
||||
|
||||
if not logits.is_cuda:
|
||||
# CPU fallback: stable entropy via log_softmax
|
||||
logp = F.log_softmax(logits.float(), dim=-1)
|
||||
ent = -(logp.exp() * logp).sum(dim=-1)
|
||||
return ent.to(logits.dtype).reshape(original_shape)
|
||||
|
||||
output = torch.empty(N, device=logits.device, dtype=torch.float32)
|
||||
|
||||
BLOCK_V = 4096
|
||||
MAX_GRID_CONTIG = 8192
|
||||
MAX_GRID_STRIDED = 2048
|
||||
|
||||
# Vocab (last) dim must be contiguous for coalesced loads
|
||||
if logits.stride(-1) != 1:
|
||||
logits = logits.contiguous()
|
||||
|
||||
if logits.is_contiguous():
|
||||
flat_logits = logits.reshape(-1, V)
|
||||
stride = flat_logits.stride(0)
|
||||
for start in range(0, N, MAX_GRID_CONTIG):
|
||||
n_rows = min(MAX_GRID_CONTIG, N - start)
|
||||
_entropy_online_kernel[(n_rows,)](
|
||||
flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V
|
||||
)
|
||||
elif logits.ndim == 3:
|
||||
stride_outer = logits.stride(0)
|
||||
stride_inner = logits.stride(1)
|
||||
n_inner = logits.shape[1]
|
||||
for start in range(0, N, MAX_GRID_STRIDED):
|
||||
n_rows = min(MAX_GRID_STRIDED, N - start)
|
||||
_entropy_online_kernel_strided[(n_rows,)](
|
||||
logits,
|
||||
output[start],
|
||||
stride_outer,
|
||||
stride_inner,
|
||||
n_inner,
|
||||
start,
|
||||
V=V,
|
||||
BLOCK_V=BLOCK_V,
|
||||
)
|
||||
else:
|
||||
logits = logits.contiguous()
|
||||
flat_logits = logits.reshape(-1, V)
|
||||
stride = flat_logits.stride(0)
|
||||
for start in range(0, N, MAX_GRID_CONTIG):
|
||||
n_rows = min(MAX_GRID_CONTIG, N - start)
|
||||
_entropy_online_kernel[(n_rows,)](
|
||||
flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V
|
||||
)
|
||||
|
||||
return output.to(logits.dtype).reshape(original_shape)
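# Illustrative usage (assumption): per-token entropy over (B, L, V) logits.
#   logits = torch.randn(2, 128, 32000, device="cuda", dtype=torch.bfloat16)
#   ent = entropy_from_logits(logits)   # shape (2, 128), dtype matches logits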
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# selective_log_softmax — fused forward + backward Triton kernels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def selective_log_softmax_original(logits, index) -> torch.Tensor:
|
||||
"""Original selective_log_softmax (reference/fallback)."""
|
||||
squeeze = index.ndim == logits.ndim - 1
|
||||
if squeeze:
|
||||
index = index.unsqueeze(-1)
|
||||
|
||||
if logits.dtype in [torch.float32, torch.float64]:
|
||||
selected_logits = torch.gather(logits, dim=-1, index=index)
|
||||
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||
per_token_logps = selected_logits - logsumexp_values.unsqueeze(-1)
|
||||
else:
|
||||
per_token_logps = []
|
||||
for row_logits, row_labels in zip(logits, index, strict=True):
|
||||
row_logps = F.log_softmax(row_logits, dim=-1)
|
||||
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels)
|
||||
per_token_logps.append(row_per_token_logps)
|
||||
per_token_logps = torch.stack(per_token_logps)
|
||||
|
||||
if squeeze:
|
||||
per_token_logps = per_token_logps.squeeze(-1)
|
||||
|
||||
return per_token_logps
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _selective_logsoftmax_fwd_kernel(
|
||||
logits_ptr,
|
||||
index_ptr,
|
||||
output_ptr,
|
||||
logsumexp_ptr,
|
||||
stride_logits_row,
|
||||
stride_index_row,
|
||||
stride_output_row,
|
||||
actual_K,
|
||||
K_BLOCK: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BLOCK_V: tl.constexpr,
|
||||
):
|
||||
"""Forward: online logsumexp + gather. Saves logsumexp for backward."""
|
||||
row = tl.program_id(0)
|
||||
logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row
|
||||
|
||||
# Online logsumexp
|
||||
running_max = tl.full([], float("-inf"), dtype=tl.float32)
|
||||
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
|
||||
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
offs = v_start + tl.arange(0, BLOCK_V)
|
||||
mask = offs < V
|
||||
x = tl.load(logits_row_ptr + offs, mask=mask, other=float("-inf")).to(
|
||||
tl.float32
|
||||
)
|
||||
|
||||
block_max = tl.max(x, axis=0)
|
||||
new_max = tl.maximum(running_max, block_max)
|
||||
running_sum_exp = running_sum_exp * tl.exp(running_max - new_max)
|
||||
|
||||
exp_x = tl.exp(x - new_max)
|
||||
exp_x = tl.where(mask, exp_x, 0.0)
|
||||
running_sum_exp += tl.sum(exp_x, axis=0)
|
||||
running_max = new_max
|
||||
|
||||
lse = tl.log(running_sum_exp) + running_max
|
||||
tl.store(logsumexp_ptr + row, lse)
|
||||
|
||||
# Gather and subtract
|
||||
index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row
|
||||
output_row_ptr = output_ptr + tl.cast(row, tl.int64) * stride_output_row
|
||||
|
||||
k_offs = tl.arange(0, K_BLOCK)
|
||||
k_mask = k_offs < actual_K
|
||||
indices = tl.load(index_row_ptr + k_offs, mask=k_mask, other=0).to(tl.int64)
|
||||
valid_mask = k_mask & (indices >= 0) & (indices < V)
|
||||
safe_indices = tl.where(valid_mask, indices, 0)
|
||||
selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _selective_logsoftmax_bwd_kernel(
|
||||
grad_output_ptr,
|
||||
logits_ptr,
|
||||
index_ptr,
|
||||
logsumexp_ptr,
|
||||
grad_logits_ptr,
|
||||
stride_grad_out_row,
|
||||
stride_logits_row,
|
||||
stride_index_row,
|
||||
stride_grad_logits_row,
|
||||
actual_K,
|
||||
K_BLOCK: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BLOCK_V: tl.constexpr,
|
||||
):
|
||||
"""Backward: d_logits[j] = -softmax(x)[j] * sum(grad_out) + (grad_out[k] if j == index[k]).
|
||||
|
||||
Single fused pass over V. For each tile, computes the base gradient and adds
|
||||
scatter contributions inline by checking which indices fall in the current tile.
|
||||
No separate scatter pass — no read-after-write issues.
|
||||
"""
|
||||
row = tl.program_id(0)
|
||||
logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row
|
||||
grad_logits_row_ptr = (
|
||||
grad_logits_ptr + tl.cast(row, tl.int64) * stride_grad_logits_row
|
||||
)
|
||||
grad_out_row_ptr = grad_output_ptr + tl.cast(row, tl.int64) * stride_grad_out_row
|
||||
index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row
|
||||
|
||||
lse = tl.load(logsumexp_ptr + row).to(tl.float32)
|
||||
|
||||
# Load grad_output and indices (K_BLOCK elements, masked)
|
||||
k_offs = tl.arange(0, K_BLOCK)
|
||||
k_mask = k_offs < actual_K
|
||||
grad_out = tl.load(grad_out_row_ptr + k_offs, mask=k_mask, other=0.0).to(tl.float32)
|
||||
indices = tl.load(
|
||||
index_row_ptr + k_offs, mask=k_mask, other=-1
|
||||
) # -1 = never matches
|
||||
valid_mask = k_mask & (indices >= 0) & (indices < V)
|
||||
grad_out = tl.where(valid_mask, grad_out, 0.0)
|
||||
indices = tl.where(valid_mask, indices, -1)
|
||||
grad_sum = tl.sum(grad_out, axis=0)
|
||||
|
||||
# Fused pass: for each tile, compute -softmax * grad_sum + scatter
|
||||
for v_start in range(0, V, BLOCK_V):
|
||||
offs = v_start + tl.arange(0, BLOCK_V) # [BLOCK_V]
|
||||
mask = offs < V
|
||||
x = tl.load(logits_row_ptr + offs, mask=mask, other=0.0).to(tl.float32)
|
||||
softmax_j = tl.exp(x - lse)
|
||||
softmax_j = tl.where(mask, softmax_j, 0.0)
|
||||
grad_j = -softmax_j * grad_sum
|
||||
|
||||
# Scatter: check which selected indices fall in this tile
|
||||
# offs: [BLOCK_V], indices: [K_BLOCK]
|
||||
# Broadcast: offs[:, None] == indices[None, :] → [BLOCK_V, K_BLOCK]
|
||||
match = offs[:, None] == indices[None, :] # [BLOCK_V, K_BLOCK]
|
||||
# Sum grad_out contributions: for each position j, sum grad_out[k] where index[k]==j
|
||||
scatter_contrib = tl.sum(
|
||||
tl.where(match, grad_out[None, :], 0.0), axis=1
|
||||
) # [BLOCK_V]
|
||||
grad_j += scatter_contrib
|
||||
|
||||
tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask)
|
||||
|
||||
|
||||
class _SelectiveLogSoftmaxTriton(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID):
|
||||
N = flat_logits.shape[0]
|
||||
output = torch.empty(N, K_BLOCK, device=flat_logits.device, dtype=torch.float32)
|
||||
logsumexp = torch.empty(N, device=flat_logits.device, dtype=torch.float32)
|
||||
|
||||
for start in range(0, N, MAX_GRID):
|
||||
n_rows = min(MAX_GRID, N - start)
|
||||
_selective_logsoftmax_fwd_kernel[(n_rows,)](
|
||||
flat_logits[start],
|
||||
flat_index[start],
|
||||
output[start],
|
||||
logsumexp[start],
|
||||
flat_logits.stride(0),
|
||||
flat_index.stride(0),
|
||||
output.stride(0),
|
||||
K,
|
||||
K_BLOCK=K_BLOCK,
|
||||
V=V,
|
||||
BLOCK_V=BLOCK_V,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(flat_logits, flat_index, logsumexp)
|
||||
ctx.K = K
|
||||
ctx.K_BLOCK = K_BLOCK
|
||||
ctx.V = V
|
||||
ctx.BLOCK_V = BLOCK_V
|
||||
ctx.MAX_GRID = MAX_GRID
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
flat_logits, flat_index, logsumexp = ctx.saved_tensors
|
||||
K, K_BLOCK, V, BLOCK_V, MAX_GRID = (
|
||||
ctx.K,
|
||||
ctx.K_BLOCK,
|
||||
ctx.V,
|
||||
ctx.BLOCK_V,
|
||||
ctx.MAX_GRID,
|
||||
)
|
||||
N = flat_logits.shape[0]
|
||||
|
||||
grad_logits = torch.empty_like(flat_logits)
|
||||
|
||||
# grad_output may have K_BLOCK cols; backward kernel reads actual_K
|
||||
grad_output_contig = grad_output.contiguous()
|
||||
|
||||
for start in range(0, N, MAX_GRID):
|
||||
n_rows = min(MAX_GRID, N - start)
|
||||
_selective_logsoftmax_bwd_kernel[(n_rows,)](
|
||||
grad_output_contig[start],
|
||||
flat_logits[start],
|
||||
flat_index[start],
|
||||
logsumexp[start],
|
||||
grad_logits[start],
|
||||
grad_output_contig.stride(0),
|
||||
flat_logits.stride(0),
|
||||
flat_index.stride(0),
|
||||
grad_logits.stride(0),
|
||||
K,
|
||||
K_BLOCK=K_BLOCK,
|
||||
V=V,
|
||||
BLOCK_V=BLOCK_V,
|
||||
)
|
||||
|
||||
# Return grads for: flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
|
||||
return grad_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
def selective_log_softmax(logits, index) -> torch.Tensor:
|
||||
"""
|
||||
Fused selective_log_softmax with Triton forward+backward kernels.
|
||||
|
||||
Equivalent to: torch.gather(logits.log_softmax(-1), dim=-1, index=index)
|
||||
"""
|
||||
squeeze = index.ndim == logits.ndim - 1
|
||||
if squeeze:
|
||||
index = index.unsqueeze(-1)
|
||||
|
||||
if not logits.is_cuda or logits.dtype == torch.float64:
|
||||
# Triton kernel computes in float32; fall back for float64 and CPU
|
||||
return selective_log_softmax_original(
|
||||
logits, index.squeeze(-1) if squeeze else index
|
||||
)
|
||||
|
||||
V = logits.shape[-1]
|
||||
K = index.shape[-1]
|
||||
original_index_shape = index.shape
|
||||
|
||||
flat_logits = logits.reshape(-1, V).contiguous()
|
||||
flat_index = index.reshape(-1, K).contiguous()
|
||||
|
||||
BLOCK_V = 4096
|
||||
MAX_GRID = 8192
|
||||
K_BLOCK = max(1, triton.next_power_of_2(K))
|
||||
|
||||
output = _SelectiveLogSoftmaxTriton.apply(
|
||||
flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
|
||||
)
|
||||
|
||||
if K_BLOCK != K:
|
||||
output = output[:, :K]
|
||||
|
||||
per_token_logps = output.to(logits.dtype).reshape(original_index_shape)
|
||||
|
||||
if squeeze:
|
||||
per_token_logps = per_token_logps.squeeze(-1)
|
||||
|
||||
return per_token_logps
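# Illustrative usage (assumption): log-probs of chosen token ids per position.
#   logits = torch.randn(4, 512, 32000, device="cuda", dtype=torch.bfloat16)
#   ids = torch.randint(0, 32000, (4, 512), device="cuda")
#   logps = selective_log_softmax(logits, ids)   # shape (4, 512)
#   # ~matches torch.gather(logits.float().log_softmax(-1), -1, ids.unsqueeze(-1)).squeeze(-1)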
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
||||
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
|
||||
|
||||
|
||||
def patch_prepare_context_parallel_inputs() -> None:
|
||||
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
|
||||
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
|
||||
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
|
||||
return
|
||||
|
||||
try:
|
||||
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
|
||||
except OSError as exc: # pragma: no cover - occurs when source is unavailable
|
||||
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
|
||||
return
|
||||
|
||||
if GUARD_PATTERN not in original_source:
|
||||
LOG.warning(
|
||||
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
|
||||
"skipping FlashAttention context parallelism patch"
|
||||
)
|
||||
return
|
||||
|
||||
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
|
||||
patched_source, _ = detab_code(patched_source)
|
||||
patched_source = patched_source.replace(
|
||||
"def _prepare_context_parallel_inputs(",
|
||||
"def axolotl_prepare_context_parallel_inputs(",
|
||||
1,
|
||||
)
|
||||
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# import symbols referenced in the method so exec can succeed
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Use a separate namespace to capture the exec'd function
|
||||
namespace = {}
|
||||
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
|
||||
exec(patched_source, namespace)
|
||||
|
||||
# Explicitly get the function from the namespace
|
||||
axolotl_prepare_context_parallel_inputs = namespace[
|
||||
"axolotl_prepare_context_parallel_inputs"
|
||||
]
|
||||
Trainer._original_prepare_context_parallel_inputs = (
|
||||
Trainer._prepare_context_parallel_inputs
|
||||
)
|
||||
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
|
||||
LOG.debug(
|
||||
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
|
||||
)
|
||||
0
src/axolotl/scripts/__init__.py
Normal file
503
src/axolotl/scripts/vllm_serve_lora.py
Normal file
@@ -0,0 +1,503 @@
|
||||
"""vLLM serve script with native LoRA adapter support.
|
||||
|
||||
Extends TRL's vllm_serve to enable direct LoRA adapter loading in vLLM,
|
||||
instead of merging adapter weights into the base model before syncing.
|
||||
|
||||
Usage:
|
||||
Set ``vllm.serve_module: axolotl.scripts.vllm_serve_lora`` in your config,
|
||||
or ``trl.vllm_lora_sync: true`` to auto-select.
|
||||
|
||||
Benefits over merge-sync:
|
||||
- Syncs only LoRA adapter weights via filesystem instead of full merged model via NCCL
|
||||
- vLLM handles LoRA application natively (Punica kernels)
|
||||
- No NCCL communicator needed for weight sync
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from multiprocessing import Pipe, Process
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Any
|
||||
|
||||
from trl.scripts.vllm_serve import (
|
||||
ScriptArguments,
|
||||
chunk_list,
|
||||
extract_logprobs,
|
||||
get_open_port,
|
||||
)
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAScriptArguments(ScriptArguments):
|
||||
"""Extended script arguments with LoRA support."""
|
||||
|
||||
enable_lora: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Enable LoRA adapter support in vLLM."},
|
||||
)
|
||||
max_lora_rank: int = field(
|
||||
default=64,
|
||||
metadata={"help": "Maximum LoRA rank supported."},
|
||||
)
|
||||
max_loras: int = field(
|
||||
default=2,
|
||||
metadata={"help": "Maximum number of LoRA adapters loaded simultaneously."},
|
||||
)
|
||||
lora_dtype: str = field(
|
||||
default="bfloat16",
|
||||
metadata={"help": "Data type for LoRA weights."},
|
||||
)
|
||||
|
||||
|
||||
def llm_worker(
|
||||
script_args: LoRAScriptArguments,
|
||||
data_parallel_rank: int,
|
||||
master_port: int,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
"""Worker process that creates a vLLM LLM with LoRA enabled."""
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
llm = LLM(
|
||||
model=script_args.model,
|
||||
revision=script_args.revision,
|
||||
tensor_parallel_size=script_args.tensor_parallel_size,
|
||||
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
||||
enforce_eager=script_args.enforce_eager,
|
||||
dtype=script_args.dtype,
|
||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||
max_model_len=script_args.max_model_len,
|
||||
# Use batch-capable worker extension (adds batch_update_named_params + auto-close)
|
||||
worker_extension_cls="axolotl.scripts.vllm_worker_ext.BatchWeightSyncWorkerExtension",
|
||||
trust_remote_code=script_args.trust_remote_code,
|
||||
model_impl=script_args.vllm_model_impl,
|
||||
logprobs_mode="processed_logprobs",
|
||||
# LoRA
|
||||
enable_lora=script_args.enable_lora,
|
||||
max_lora_rank=script_args.max_lora_rank,
|
||||
max_loras=script_args.max_loras,
|
||||
lora_dtype=script_args.lora_dtype,
|
||||
)
|
||||
|
||||
connection.send({"status": "ready"})
|
||||
|
||||
while True:
|
||||
try:
|
||||
command = connection.recv()
|
||||
except KeyboardInterrupt:
|
||||
llm.collective_rpc(method="close_communicator")
|
||||
break
|
||||
|
||||
if command["type"] in ["call", "fire_and_forget"]:
|
||||
method_name = command["method"]
|
||||
args = command.get("args", ())
|
||||
kwargs = command.get("kwargs", {})
|
||||
|
||||
# Reconstruct LoRARequest from serialized dict (can't pickle across pipe)
|
||||
if "lora_request" in kwargs and kwargs["lora_request"] is not None:
|
||||
lr = kwargs["lora_request"]
|
||||
kwargs["lora_request"] = LoRARequest(
|
||||
lora_name=lr["lora_name"],
|
||||
lora_int_id=lr["lora_int_id"],
|
||||
lora_path=lr["lora_path"],
|
||||
load_inplace=lr.get("load_inplace", False),
|
||||
)
|
||||
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
if command["type"] == "call":
|
||||
connection.send(result)
|
||||
elif command["type"] == "shutdown":
|
||||
break
|
||||
|
||||
|
||||
def main(script_args: ScriptArguments):
|
||||
"""Start vLLM workers with LoRA support and the HTTP server."""
|
||||
import asyncio
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
|
||||
# Request/Response models (defined locally like TRL's vllm_serve.main)
|
||||
class GenerateRequest(BaseModel):
|
||||
prompts: list[str]
|
||||
images: list[str] | None = None
|
||||
n: int = 1
|
||||
repetition_penalty: float = 1.0
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
max_tokens: int = 16
|
||||
logprobs: int | None = 0
|
||||
truncate_prompt_tokens: int | None = None
|
||||
structured_outputs_regex: str | None = None
|
||||
generation_kwargs: dict = PydanticField(default_factory=dict)
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
prompt_ids: list[list[int]]
|
||||
completion_ids: list[list[int]]
|
||||
logprobs: list[list[list[float]]]
|
||||
logprob_token_ids: list[list[list[int]]]
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: list[list[dict]]
|
||||
n: int = 1
|
||||
repetition_penalty: float = 1.0
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
max_tokens: int = 16
|
||||
logprobs: int | None = 0
|
||||
truncate_prompt_tokens: int | None = None
|
||||
structured_outputs_regex: str | None = None
|
||||
generation_kwargs: dict = PydanticField(default_factory=dict)
|
||||
chat_template_kwargs: dict = PydanticField(default_factory=dict)
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
prompt_ids: list[list[int]]
|
||||
completion_ids: list[list[int]]
|
||||
logprobs: list[list[list[float]]]
|
||||
logprob_token_ids: list[list[list[int]]]
|
||||
|
||||
class InitCommunicatorRequest(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
world_size: int
|
||||
client_device_uuid: str
|
||||
|
||||
# Wrap plain ScriptArguments with LoRA defaults
|
||||
if not isinstance(script_args, LoRAScriptArguments):
|
||||
lora_args = LoRAScriptArguments.__new__(LoRAScriptArguments)
|
||||
for f in ScriptArguments.__dataclass_fields__:
|
||||
setattr(lora_args, f, getattr(script_args, f))
|
||||
# Apply LoRA defaults
|
||||
for f in LoRAScriptArguments.__dataclass_fields__:
|
||||
if f not in ScriptArguments.__dataclass_fields__:
|
||||
setattr(
|
||||
lora_args, f, LoRAScriptArguments.__dataclass_fields__[f].default
|
||||
)
|
||||
script_args = lora_args
|
||||
|
||||
# Spawn workers
|
||||
master_port = get_open_port()
|
||||
connections: list[Connection] = []
|
||||
processes: list[Process] = []
|
||||
for dp_rank in range(script_args.data_parallel_size):
|
||||
parent_conn, child_conn = Pipe()
|
||||
process = Process(
|
||||
target=llm_worker,
|
||||
args=(script_args, dp_rank, master_port, child_conn),
|
||||
)
|
||||
process.start()
|
||||
connections.append(parent_conn)
|
||||
processes.append(process)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
import time
|
||||
|
||||
startup_timeout = 300 # 5 minutes
|
||||
start_time = time.monotonic()
|
||||
ready: set[int] = set()
|
||||
while len(ready) < script_args.data_parallel_size:
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed > startup_timeout:
|
||||
raise RuntimeError(
|
||||
f"vLLM workers failed to start within {startup_timeout}s "
|
||||
f"({len(ready)}/{script_args.data_parallel_size} ready)"
|
||||
)
|
||||
for i, (conn, proc) in enumerate(zip(connections, processes, strict=True)):
|
||||
if id(conn) in ready:
|
||||
continue
|
||||
if not proc.is_alive():
|
||||
raise RuntimeError(
|
||||
f"vLLM worker {i} exited unexpectedly during startup"
|
||||
)
|
||||
if conn.poll():
|
||||
msg = conn.recv()
|
||||
if isinstance(msg, dict) and msg.get("status") == "ready":
|
||||
ready.add(id(conn))
|
||||
await asyncio.sleep(0.1)
|
||||
yield
|
||||
for p in processes:
|
||||
p.join(timeout=10)
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# --- Active LoRA state (shared across endpoints via closure) ---
|
||||
active_lora: dict = {"request": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LoRA-specific endpoints
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class SetLoRARequest(BaseModel):
|
||||
lora_name: str
|
||||
lora_int_id: int
|
||||
lora_path: str
|
||||
load_inplace: bool = False
|
||||
|
||||
@app.post("/set_lora_adapter/")
|
||||
async def set_lora_adapter(request: SetLoRARequest):
|
||||
"""Register a LoRA adapter for all subsequent generate/chat calls."""
|
||||
active_lora["request"] = {
|
||||
"lora_name": request.lora_name,
|
||||
"lora_int_id": request.lora_int_id,
|
||||
"lora_path": request.lora_path,
|
||||
"load_inplace": request.load_inplace,
|
||||
}
|
||||
logger.info(
|
||||
"Set active LoRA: %s (id=%d, path=%s)",
|
||||
request.lora_name,
|
||||
request.lora_int_id,
|
||||
request.lora_path,
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.post("/clear_lora_adapter/")
|
||||
async def clear_lora_adapter():
|
||||
"""Clear active LoRA adapter (revert to base model)."""
|
||||
active_lora["request"] = None
|
||||
return {"status": "ok"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Standard endpoints (mirrors TRL's vllm_serve)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@app.get("/health/")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/get_world_size/")
|
||||
async def get_world_size():
|
||||
return {
|
||||
"world_size": script_args.tensor_parallel_size
|
||||
* script_args.data_parallel_size
|
||||
}
|
||||
|
||||
@app.post("/generate/", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
"""Generate completions with optional LoRA adapter."""
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import vllm
|
||||
from packaging.version import Version
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item]
|
||||
prompts: list[dict[str, Any]] = []
|
||||
for prompt, image in zip(request.prompts, images, strict=True):
|
||||
row: dict[str, Any] = {"prompt": prompt}
|
||||
if image is not None:
|
||||
from PIL import Image
|
||||
|
||||
row["multi_modal_data"] = {
|
||||
"image": Image.open(BytesIO(base64.b64decode(image)))
|
||||
}
|
||||
prompts.append(row)
|
||||
|
||||
generation_kwargs = {
|
||||
"n": request.n,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"top_k": request.top_k,
|
||||
"min_p": request.min_p,
|
||||
"max_tokens": request.max_tokens,
|
||||
"logprobs": request.logprobs,
|
||||
}
|
||||
generation_kwargs.update(request.generation_kwargs)
|
||||
|
||||
if Version(vllm.__version__) <= Version("0.10.2"):
|
||||
key = "guided_decoding"
|
||||
if request.structured_outputs_regex is not None:
|
||||
generation_kwargs[key] = GuidedDecodingParams(
|
||||
regex=request.structured_outputs_regex
|
||||
)
|
||||
else:
|
||||
generation_kwargs.setdefault(key, None)
|
||||
else:
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
|
||||
key = "structured_outputs"
|
||||
if request.structured_outputs_regex is not None:
|
||||
generation_kwargs[key] = StructuredOutputsParams(
|
||||
regex=request.structured_outputs_regex
|
||||
)
|
||||
elif isinstance(generation_kwargs.get(key), dict):
|
||||
generation_kwargs[key] = StructuredOutputsParams(
|
||||
**generation_kwargs[key]
|
||||
)
|
||||
else:
|
||||
generation_kwargs.setdefault(key, None)
|
||||
|
||||
sampling_params = SamplingParams(**generation_kwargs)
|
||||
chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)
|
||||
|
||||
for conn, chunk in zip(connections, chunked_prompts, strict=True):
|
||||
if not chunk:
|
||||
chunk = [{"prompt": "<placeholder>"}]
|
||||
kwargs = {
|
||||
"prompts": chunk,
|
||||
"sampling_params": sampling_params,
|
||||
"lora_request": active_lora["request"],
|
||||
}
|
||||
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
|
||||
|
||||
all_outputs = [conn.recv() for conn in connections]
|
||||
all_outputs = [
|
||||
o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
|
||||
]
|
||||
all_outputs = list(chain.from_iterable(all_outputs))
|
||||
|
||||
return {
|
||||
"prompt_ids": [o.prompt_token_ids for o in all_outputs],
|
||||
"completion_ids": [
|
||||
list(out.token_ids) for o in all_outputs for out in o.outputs
|
||||
],
|
||||
"logprobs": extract_logprobs(all_outputs)[0],
|
||||
"logprob_token_ids": extract_logprobs(all_outputs)[1],
|
||||
}
|
||||
|
||||
@app.post("/chat/", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
"""Chat endpoint with optional LoRA adapter."""
|
||||
generation_kwargs = {
|
||||
"n": request.n,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"top_k": request.top_k,
|
||||
"min_p": request.min_p,
|
||||
"max_tokens": request.max_tokens,
|
||||
"logprobs": request.logprobs,
|
||||
}
|
||||
generation_kwargs.update(request.generation_kwargs)
|
||||
sampling_params = SamplingParams(**generation_kwargs)
|
||||
chunked = chunk_list(request.messages, script_args.data_parallel_size)
|
||||
for conn, chunk in zip(connections, chunked, strict=True):
|
||||
if not chunk:
|
||||
chunk = [[{"role": "user", "content": "<placeholder>"}]]
|
||||
kwargs = {
|
||||
"messages": chunk,
|
||||
"sampling_params": sampling_params,
|
||||
"use_tqdm": False,
|
||||
"lora_request": active_lora["request"],
|
||||
}
|
||||
conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
|
||||
|
||||
all_outputs = [conn.recv() for conn in connections]
|
||||
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
||||
all_outputs = list(chain.from_iterable(all_outputs))
|
||||
|
||||
return {
|
||||
"prompt_ids": [o.prompt_token_ids for o in all_outputs],
|
||||
"completion_ids": [
|
||||
list(out.token_ids) for o in all_outputs for out in o.outputs
|
||||
],
|
||||
"logprobs": extract_logprobs(all_outputs)[0],
|
||||
"logprob_token_ids": extract_logprobs(all_outputs)[1],
|
||||
}
|
||||
|
||||
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
||||
|
||||
@app.post("/init_communicator/")
|
||||
async def init_communicator(request: InitCommunicatorRequest):
|
||||
world_size = (
|
||||
script_args.tensor_parallel_size * script_args.data_parallel_size + 1
|
||||
)
|
||||
kwargs = {
|
||||
"method": "init_communicator",
|
||||
"args": (
|
||||
request.host,
|
||||
request.port,
|
||||
world_size,
|
||||
request.client_device_uuid,
|
||||
),
|
||||
}
|
||||
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
|
||||
loop = asyncio.get_running_loop()
|
||||
await asyncio.gather(
|
||||
*(loop.run_in_executor(None, c.send, msg) for c in connections)
|
||||
)
|
||||
return {"message": "Initializing communicator"}
|
||||
|
||||
class UpdateWeightsRequest(BaseModel):
|
||||
name: str
|
||||
dtype: str
|
||||
shape: list[int]
|
||||
|
||||
@app.post("/update_named_param/")
|
||||
async def update_named_param(request: UpdateWeightsRequest):
|
||||
kwargs = {
|
||||
"method": "update_named_param",
|
||||
"args": (request.name, request.dtype, tuple(request.shape)),
|
||||
}
|
||||
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
|
||||
loop = asyncio.get_running_loop()
|
||||
await asyncio.gather(
|
||||
*(loop.run_in_executor(None, c.send, msg) for c in connections)
|
||||
)
|
||||
return {"message": "Updating parameter"}
|
||||
|
||||
class BatchUpdateWeightsRequest(BaseModel):
|
||||
params: list[dict]
|
||||
|
||||
@app.post("/batch_update_named_params/")
|
||||
async def batch_update_named_params(request: BatchUpdateWeightsRequest):
|
||||
params_list = [
|
||||
(p["name"], p["dtype"], tuple(p["shape"])) for p in request.params
|
||||
]
|
||||
kwargs = {"method": "batch_update_named_params", "args": (params_list,)}
|
||||
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
|
||||
loop = asyncio.get_running_loop()
|
||||
await asyncio.gather(
|
||||
*(loop.run_in_executor(None, c.send, msg) for c in connections)
|
||||
)
|
||||
return {"message": f"Batch update for {len(params_list)} params"}
|
||||
|
||||
@app.post("/reset_prefix_cache/")
|
||||
async def reset_prefix_cache():
|
||||
for conn in connections:
|
||||
conn.send({"type": "call", "method": "reset_prefix_cache"})
|
||||
results = [conn.recv() for conn in connections]
|
||||
return {"message": f"Reset prefix cache: {all(results)}"}
|
||||
|
||||
@app.post("/close_communicator/")
|
||||
async def close_communicator():
|
||||
kwargs = {"method": "close_communicator"}
|
||||
for conn in connections:
|
||||
conn.send(
|
||||
{
|
||||
"type": "fire_and_forget",
|
||||
"method": "collective_rpc",
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
return {"message": "Closing communicator"}
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=script_args.host,
|
||||
port=script_args.port,
|
||||
log_level=script_args.log_level,
|
||||
access_log=True,
|
||||
)
|
||||
158
src/axolotl/scripts/vllm_worker_ext.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Extended vLLM worker extension with batch weight sync support.
|
||||
|
||||
Subclasses TRL's WeightSyncWorkerExtension to add:
|
||||
- batch_update_named_params: receives multiple params in one call
|
||||
- Auto-close stale communicator on re-init
|
||||
- _direct_set_weight: proper handling for stacked (qkv_proj, gate_up_proj) params,
|
||||
including LoRA-wrapped models where vLLM inserts base_layer into the hierarchy
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from transformers import is_torch_xpu_available
|
||||
except ImportError:
|
||||
is_torch_xpu_available = lambda: False # noqa: E731
|
||||
|
||||
from trl.scripts.vllm_serve import WeightSyncWorkerExtension
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Stacked param name mapping: shard_name -> (packed_name, shard_order)
|
||||
_STACKED_PARAMS = {
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
|
||||
class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
|
||||
"""Worker extension that adds batch weight update and direct weight setting."""
|
||||
|
||||
def init_communicator(self, host, port, world_size, client_device_uuid):
|
||||
"""Auto-close stale communicator before re-initializing."""
|
||||
if self.communicator is not None:
|
||||
self.close_communicator()
|
||||
super().init_communicator(host, port, world_size, client_device_uuid)
|
||||
|
||||
def _direct_set_weight(self, name: str, weight: torch.Tensor) -> None:
|
||||
"""Directly copy weight data into the model, handling stacked params.
|
||||
|
||||
Bypasses model.load_weights() which may fail on vLLM 0.17's new
|
||||
module-tree weight loader for stacked params (qkv_proj, gate_up_proj).
|
||||
|
||||
Handles LoRA-wrapped params where vLLM inserts ``base_layer`` into the
|
||||
parameter hierarchy (e.g. ``qkv_proj.base_layer.weight``).
|
||||
"""
|
||||
model = self.model_runner.model
|
||||
params_dict = dict(model.named_parameters())
|
||||
|
||||
# Check if this is a simple direct param (exists as-is)
|
||||
if name in params_dict:
|
||||
params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
|
||||
return
|
||||
|
||||
# Also check with base_layer inserted: x.y.weight -> x.y.base_layer.weight
|
||||
parts_bl = name.rsplit(".", 1)
|
||||
if len(parts_bl) == 2:
|
||||
base_layer_name = f"{parts_bl[0]}.base_layer.{parts_bl[1]}"
|
||||
if base_layer_name in params_dict:
|
||||
params_dict[base_layer_name].data.copy_(
|
||||
weight.to(params_dict[base_layer_name].dtype)
|
||||
)
|
||||
return
|
||||
|
||||
# Handle stacked params: e.g. "model.layers.0.self_attn.q_proj.weight"
|
||||
# -> "model.layers.0.self_attn.qkv_proj.weight" with shard offset
|
||||
parts = name.rsplit(".", 2) # [prefix, layer_name, suffix]
|
||||
if len(parts) == 3:
|
||||
prefix, layer_name, suffix = parts
|
||||
if layer_name in _STACKED_PARAMS:
|
||||
packed_name, shard_idx = _STACKED_PARAMS[layer_name]
|
||||
for packed_full in [
|
||||
f"{prefix}.{packed_name}.{suffix}",
|
||||
f"{prefix}.{packed_name}.base_layer.{suffix}",
|
||||
]:
|
||||
if packed_full not in params_dict:
|
||||
continue
|
||||
param = params_dict[packed_full]
|
||||
# Navigate to the packed module to find shard sizes
|
||||
module_path = packed_full.rsplit(".", 1)[0] # strip .weight/.bias
|
||||
if ".base_layer" in module_path:
|
||||
module_path = module_path.replace(".base_layer", "")
|
||||
module = model
|
||||
for attr in module_path.split("."):
|
||||
module = getattr(module, attr, None)
|
||||
if module is None:
|
||||
break
|
||||
# LoRA wrappers don't have output_sizes directly;
|
||||
# check base_layer for the underlying parallel linear
|
||||
if module is not None and not hasattr(module, "output_sizes"):
|
||||
base = getattr(module, "base_layer", None)
|
||||
if base is not None and hasattr(base, "output_sizes"):
|
||||
module = base
|
||||
if module is not None and hasattr(module, "output_sizes"):
|
||||
tp_size = getattr(module, "tp_size", 1)
|
||||
sizes = [s // tp_size for s in module.output_sizes]
|
||||
offset = sum(sizes[:shard_idx])
|
||||
shard_size = sizes[shard_idx]
|
||||
param.data[offset : offset + shard_size].copy_(
|
||||
weight.to(param.dtype)
|
||||
)
|
||||
return
|
||||
|
||||
# Fallback: try load_weights (may work for non-stacked params)
|
||||
logger.warning("Falling back to load_weights for param: %s", name)
|
||||
model.load_weights(weights=[(name, weight)])
|
||||
|
||||
def update_named_param(self, name, dtype, shape):
|
||||
"""Override to use _direct_set_weight instead of load_weights."""
|
||||
if self.communicator is None:
|
||||
raise RuntimeError("Communicator not initialized.")
|
||||
|
||||
dtype = getattr(torch, dtype.split(".")[-1])
|
||||
weight = torch.empty(shape, dtype=dtype, device=self.device)
|
||||
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.broadcast(weight, root=self.client_rank)
|
||||
self.communicator.barrier()
|
||||
else:
|
||||
self.communicator.broadcast(weight, src=self.client_rank)
|
||||
self.communicator.group.barrier()
|
||||
|
||||
self._direct_set_weight(name, weight)
|
||||
|
||||
def batch_update_named_params(self, params_list: list[tuple[str, str, tuple]]):
|
||||
"""Receive and apply multiple weight tensors in sequence.
|
||||
|
||||
Args:
|
||||
params_list: List of (name, dtype_str, shape) tuples.
|
||||
"""
|
||||
if self.communicator is None:
|
||||
raise RuntimeError("Communicator not initialized.")
|
||||
|
||||
weights_to_load = []
|
||||
for name, dtype_str, shape in params_list:
|
||||
dtype = getattr(torch, dtype_str.split(".")[-1])
|
||||
weight = torch.empty(shape, dtype=dtype, device=self.device)
|
||||
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.broadcast(weight, root=self.client_rank)
|
||||
else:
|
||||
self.communicator.broadcast(weight, src=self.client_rank)
|
||||
|
||||
weights_to_load.append((name, weight))
|
||||
|
||||
# Single barrier after all broadcasts
|
||||
if is_torch_xpu_available():
|
||||
self.communicator.barrier()
|
||||
else:
|
||||
self.communicator.group.barrier()
|
||||
|
||||
# Load weights using direct set (handles stacked params)
|
||||
for name, weight in weights_to_load:
|
||||
self._direct_set_weight(name, weight)
|
||||
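To make the stacked-param path above concrete, here is a worked example of the offset arithmetic `_direct_set_weight` performs. The parameter name and the `output_sizes`/`tp_size` values are hypothetical; they only illustrate how a `k_proj` shard lands inside vLLM's packed `qkv_proj` weight.

```python
# Hypothetical sizes for illustration; real values come from the vLLM module.
name = "model.layers.0.self_attn.k_proj.weight"
packed_name, shard_idx = ("qkv_proj", 1)          # from _STACKED_PARAMS["k_proj"]
output_sizes, tp_size = (4096, 1024, 1024), 2     # q, k, v output dims; 2-way TP

sizes = [s // tp_size for s in output_sizes]      # per-rank shard sizes: [2048, 512, 512]
offset = sum(sizes[:shard_idx])                   # rows before the k shard: 2048
shard_size = sizes[shard_idx]                     # rows to overwrite: 512
# The worker then copies the incoming tensor into
# qkv_proj.weight[offset : offset + shard_size] (or .base_layer.weight for LoRA).
```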
@@ -69,7 +69,6 @@ def setup_model_and_tokenizer(
|
||||
# Load tokenizer
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
|
||||
@@ -59,7 +59,6 @@ class DynamicCheckpointCallback(TrainerCallback):
|
||||
f"Dynamic checkpoint enabled. To trigger checkpoint save:\n"
|
||||
f" • File: touch {cfg.output_dir}/{self.trigger_filename}\n"
|
||||
f" • Check interval: every {self.check_interval} steps",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
def on_step_end(
|
||||
@@ -89,12 +88,10 @@ class DynamicCheckpointCallback(TrainerCallback):
|
||||
LOG.info(
|
||||
f"Dynamic checkpoint triggered via file '{self.trigger_filename}' "
|
||||
f"at step {state.global_step}",
|
||||
main_process_only=True,
|
||||
)
|
||||
except OSError as exc:
|
||||
LOG.warning(
|
||||
f"Failed to delete trigger file: {exc}",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
if self.should_save_checkpoint:
|
||||
@@ -127,6 +124,5 @@ class DynamicCheckpointCallback(TrainerCallback):
|
||||
control.should_save = True
|
||||
LOG.info(
|
||||
f"Saving dynamic checkpoint at step {state.global_step}",
|
||||
main_process_only=True,
|
||||
)
|
||||
return control
|
||||
|
||||
@@ -17,6 +17,8 @@ from transformers import (
|
||||
class PytorchProfilerCallback(TrainerCallback):
|
||||
"""
|
||||
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
||||
|
||||
Also runs torch.profiler to produce a Chrome trace for timing analysis.
|
||||
"""
|
||||
|
||||
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
||||
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
if profiler_steps_start == 0:
|
||||
# start recording memory allocations before everything is allocated, because if we start
|
||||
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||
torch.cuda.memory._record_memory_history(enabled="all")
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
profiler_steps_start = -1
|
||||
self.profiler_steps_start = profiler_steps_start
|
||||
self._profiler = None
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
**kwargs,
|
||||
):
|
||||
if state.global_step == self.profiler_steps_start:
|
||||
torch.cuda.memory._record_memory_history(enabled="all")
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
|
||||
# Start torch.profiler on the first profiled step
|
||||
if state.global_step == max(self.profiler_steps_start, 0):
|
||||
profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
profiler.__enter__()
|
||||
self._profiler = profiler
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
# tell CUDA to stop recording memory allocations now
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
|
||||
# Stop and export torch.profiler trace
|
||||
if self._profiler is not None:
|
||||
self._profiler.__exit__(None, None, None)
|
||||
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||
self._profiler.export_chrome_trace(str(trace_path))
|
||||
self._profiler = None
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
|
||||
# tell CUDA to stop recording memory allocations now
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
|
||||
if self._profiler is not None:
|
||||
self._profiler.__exit__(None, None, None)
|
||||
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||
self._profiler.export_chrome_trace(str(trace_path))
|
||||
self._profiler = None
|
||||
|
||||
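Outside of the Trainer callback, the same profiling pattern can be reproduced in a few lines. This is a standalone sketch using the public `torch.profiler` API plus the private `torch.cuda.memory._record_memory_history` hook the callback relies on; the output paths and the dummy workload are arbitrary.

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Record allocation history with stack traces (private API, same call as the callback).
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    (x @ x).sum().backward()  # stand-in for a training step

prof.export_chrome_trace("profiler_trace.json")             # open in chrome://tracing or Perfetto
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")  # view at pytorch.org/memory_viz
torch.cuda.memory._record_memory_history(enabled=None)      # stop recording
```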
@@ -84,7 +84,7 @@ def resolve_dtype(cfg):
|
||||
cfg.fp16 = True
|
||||
cfg.bf16 = False
|
||||
else:
|
||||
if cfg.tf32:
|
||||
if cfg.tf32 is True:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
if is_torch_greater_or_equal("2.9.0"):
|
||||
torch.backends.fp32_precision = "tf32"
|
||||
@@ -195,6 +195,15 @@ def normalize_config(cfg):
|
||||
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
# Resolve inner text backbone type for VLM wrappers (e.g. mistral3 -> mistral4)
|
||||
if callable(getattr(model_config, "get_text_config", None)):
|
||||
text_config = model_config.get_text_config()
|
||||
if (
|
||||
hasattr(text_config, "model_type")
|
||||
and text_config.model_type != model_config.model_type
|
||||
):
|
||||
cfg.model_config_type_text = text_config.model_type
|
||||
|
||||
# figure out if the model is llama
|
||||
cfg.is_llama_derived_model = (
|
||||
(
|
||||
|
||||
@@ -348,7 +348,9 @@ def _load_raw_datasets(
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||
else:
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||
if cfg.sample_packing:
|
||||
if (split == "train" and cfg.sample_packing) or (
|
||||
split == "test" and cfg.eval_sample_packing
|
||||
):
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
# Deduplicate before saving so the saved dataset is already de-duplicated
|
||||
|
||||
@@ -474,13 +474,11 @@ def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset |
|
||||
):
|
||||
LOG.info(
|
||||
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
||||
main_process_only=True,
|
||||
)
|
||||
return load_from_disk(str(prepared_ds_path))
|
||||
|
||||
LOG.info(
|
||||
f"Unable to find prepared dataset in {prepared_ds_path}",
|
||||
main_process_only=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||||
from torchao.prototype.mx_formats import NVFP4InferenceConfig
|
||||
|
||||
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
|
||||
except:
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
# int4 weight config imports will fail on machines with fbgemm-gpu installed
|
||||
@@ -38,7 +38,7 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||||
from torchao.quantization.quant_api import Int4WeightOnlyConfig
|
||||
|
||||
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
|
||||
except:
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
try:
|
||||
|
||||
@@ -407,9 +407,11 @@ class AxolotlInputConfig(
|
||||
default=None,
|
||||
json_schema_extra={"description": "No AMP (automatic mixed precision)"},
|
||||
) # for non-AMP cases
|
||||
tf32: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Use CUDA tf32 - require >=ampere"},
|
||||
tf32: Literal["auto"] | bool | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={
|
||||
"description": "bool to use CUDA tf32 or 'auto' for automatic detection - require >=ampere"
|
||||
},
|
||||
)
|
||||
float32: bool | None = None
|
||||
|
||||
@@ -1218,6 +1220,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tf32(self):
|
||||
if self.tf32 == "auto":
|
||||
self.tf32 = self.capabilities.tf32
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fp8(self):
|
||||
if self.fp8 and not self.capabilities.fp8:
|
||||
|
||||
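The `tf32: auto` default resolves against the `tf32` flag in `GPUCapabilities` (extended in the next hunk). A rough sketch of how such a capability flag can be detected is below; the helper name is made up, but the underlying rule is standard CUDA behavior: TF32 tensor-core math needs compute capability 8.0 (Ampere) or newer.

```python
import torch


def detect_tf32_support() -> bool:
    """Hypothetical helper: True when the current GPU can use TF32 matmuls."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8  # Ampere (8.x), Hopper (9.x), and newer
```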
@@ -10,6 +10,7 @@ class GPUCapabilities(BaseModel):
|
||||
|
||||
bf16: bool = Field(default=False)
|
||||
fp8: bool = Field(default=False)
|
||||
tf32: bool = Field(default=False)
|
||||
n_gpu: int = Field(default=1)
|
||||
n_node: int = Field(default=1)
|
||||
compute_capability: Optional[str] = Field(default=None)
|
||||
|
||||
@@ -189,3 +189,125 @@ class TRLConfig(BaseModel):
|
||||
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
|
||||
},
|
||||
)
|
||||
|
||||
# Async GRPO fields
|
||||
use_data_producer: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Use the GRPODataProducer protocol for online data generation."
|
||||
},
|
||||
)
|
||||
async_prefetch: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Generate rollouts in a background thread while training on the previous rollout."
|
||||
},
|
||||
)
|
||||
prefetch_depth: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of rollouts to prefetch ahead of training."
|
||||
},
|
||||
)
|
||||
vllm_sync_interval: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
|
||||
},
|
||||
)
|
||||
streaming_partial_batch: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Score prompt groups incrementally instead of the full batch at once."
|
||||
},
|
||||
)
|
||||
streaming_min_groups: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Minimum prompt groups to score per streaming chunk."
|
||||
},
|
||||
)
|
||||
vllm_importance_sampling_correction: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Apply IS correction for distribution mismatch between vLLM and training model."
|
||||
},
|
||||
)
|
||||
vllm_importance_sampling_mode: (
|
||||
Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"]
|
||||
| None
|
||||
) = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
|
||||
},
|
||||
)
|
||||
vllm_importance_sampling_cap: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Cap C for IS ratio clipping/masking."},
|
||||
)
|
||||
off_policy_mask_threshold: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
|
||||
},
|
||||
)
|
||||
use_bias_correction_kl: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Apply IS correction to KL divergence term."},
|
||||
)
|
||||
|
||||
reward_num_workers: int = Field(
|
||||
default=1,
|
||||
json_schema_extra={
|
||||
"description": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
|
||||
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
|
||||
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
|
||||
},
|
||||
)
|
||||
replay_buffer_size: int = Field(
|
||||
default=0,
|
||||
json_schema_extra={
|
||||
"description": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
|
||||
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
|
||||
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
|
||||
},
|
||||
)
|
||||
replay_recompute_logps: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "When True (default), recompute old_per_token_logps for replayed groups using the current "
|
||||
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
|
||||
"Only relevant when replay_buffer_size > 0."
|
||||
},
|
||||
)
|
||||
reroll_start_fraction: float = Field(
|
||||
default=1.0,
|
||||
json_schema_extra={
|
||||
"description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
|
||||
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
|
||||
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
|
||||
},
|
||||
)
|
||||
reroll_max_groups: int = Field(
|
||||
default=1,
|
||||
json_schema_extra={
|
||||
"description": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
|
||||
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
|
||||
},
|
||||
)
|
||||
skip_zero_advantage_batches: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
|
||||
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
|
||||
"logged with skipped_zero_adv_batches=1 for monitoring."
|
||||
},
|
||||
)
|
||||
vllm_lora_sync: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Sync LoRA adapter to vLLM via filesystem instead of merging + NCCL broadcast. "
|
||||
"Auto-selects vllm_serve_lora serve module. Syncs only LoRA adapter weights vs full merged model."
|
||||
},
|
||||
)
|
||||
|
||||
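Taken together, these fields map onto the `trl` block of an axolotl config. The dict below mirrors what such a block could look like for async GRPO; the field names come from `TRLConfig` above, while the particular values are only an example, not recommended defaults.

```python
# Example `trl` section (expressed as a Python dict; in practice this is YAML).
trl_section = {
    "use_data_producer": True,              # use the GRPODataProducer protocol
    "async_prefetch": True,                 # generate rollouts in a background thread
    "prefetch_depth": 1,                    # rollouts to stay ahead of training
    "vllm_sync_interval": 4,                # push weights to vLLM every 4 optimizer steps
    "vllm_importance_sampling_correction": True,
    "vllm_importance_sampling_mode": "token_mask",
    "reward_num_workers": 4,                # parallel reward-function workers
    "replay_buffer_size": 0,                # experimental replay buffer disabled
}
```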
@@ -128,7 +128,6 @@ class DatasetValidationMixin:
|
||||
):
|
||||
LOG.info(
|
||||
"explicitly setting `eval_sample_packing` to match `sample_packing`",
|
||||
main_process_only=True,
|
||||
)
|
||||
data["eval_sample_packing"] = True
|
||||
|
||||
@@ -254,6 +253,23 @@ class TrainingValidationMixin:
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_reward_model_defaults(cls, data):
|
||||
if data.get("reward_model"):
|
||||
if data.get("num_labels") is None:
|
||||
data["num_labels"] = 1
|
||||
if not (data.get("type_of_model") or data.get("model_type")):
|
||||
data["model_type"] = "AutoModelForSequenceClassification"
|
||||
|
||||
if data.get("process_reward_model"):
|
||||
if data.get("num_labels") is None:
|
||||
data["num_labels"] = 2
|
||||
if not (data.get("type_of_model") or data.get("model_type")):
|
||||
data["model_type"] = "AutoModelForTokenClassification"
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gas_bsz(cls, data):
|
||||
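For reference, the effect of the `set_reward_model_defaults` validator in the hunk above on a minimal config mapping (shown as a before/after sketch rather than an actual call, since the validator runs inside pydantic):

```python
# Before validation:
cfg_in = {"reward_model": True}
# After set_reward_model_defaults (values per the validator above):
cfg_out = {
    "reward_model": True,
    "num_labels": 1,
    "model_type": "AutoModelForSequenceClassification",
}
# A process_reward_model config instead gets num_labels=2 and
# model_type="AutoModelForTokenClassification".
```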
@@ -676,20 +692,6 @@ class LoRAValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernels_rl(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
) and data.get("rl"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||
"compatible with RL at the moment."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernels_trust_remote_code(cls, data):
|
||||
|
||||
@@ -57,3 +57,10 @@ class VllmConfig(BaseModel):
|
||||
default=None,
|
||||
json_schema_extra={"description": "Reasoning parser for VLLM"},
|
||||
)
|
||||
serve_module: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Python module for vLLM serve script. Set to 'axolotl.scripts.vllm_serve_lora' "
|
||||
"for native LoRA support, or leave None for default TRL serve."
|
||||
},
|
||||
)
|
||||
|
||||
220
tests/core/test_async_grpo.py
Normal file
220
tests/core/test_async_grpo.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Unit tests for async GRPO"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TestReplayBuffer(unittest.TestCase):
|
||||
"""Tests for ReplayBuffer edge cases."""
|
||||
|
||||
def test_add_noop_when_max_size_zero(self):
|
||||
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
|
||||
|
||||
buf = ReplayBuffer(max_size=0)
|
||||
buf.add(1.0, {"data": "test"})
|
||||
self.assertEqual(len(buf), 0)
|
||||
|
||||
def test_add_noop_when_max_size_negative(self):
|
||||
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
|
||||
|
||||
buf = ReplayBuffer(max_size=-1)
|
||||
buf.add(1.0, {"data": "test"})
|
||||
self.assertEqual(len(buf), 0)
|
||||
|
||||
def test_sample_returns_none_when_max_size_zero(self):
|
||||
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
|
||||
|
||||
buf = ReplayBuffer(max_size=0)
|
||||
self.assertIsNone(buf.sample(1))
|
||||
|
||||
def test_sample_returns_none_when_empty(self):
|
||||
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
|
||||
|
||||
buf = ReplayBuffer(max_size=5)
|
||||
self.assertIsNone(buf.sample(1))
|
||||
|
||||
def test_normal_add_and_sample(self):
|
||||
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
|
||||
|
||||
buf = ReplayBuffer(max_size=3)
|
||||
buf.add(1.0, {"a": 1})
|
||||
buf.add(2.0, {"a": 2})
|
||||
buf.add(3.0, {"a": 3})
|
||||
self.assertEqual(len(buf), 3)
|
||||
result = buf.sample(1)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_replaces_lowest_when_full(self):
|
||||
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
|
||||
|
||||
buf = ReplayBuffer(max_size=2)
|
||||
buf.add(1.0, {"a": 1})
|
||||
buf.add(2.0, {"a": 2})
|
||||
buf.add(3.0, {"a": 3}) # should replace score=1.0
|
||||
self.assertEqual(len(buf), 2)
|
||||
scores = sorted(item[0] for item in buf._heap)
|
||||
self.assertEqual(scores, [2.0, 3.0])
|
||||
|
||||
|
||||
class TestGRPOStrategyConflict(unittest.TestCase):
|
||||
"""Tests for sequence_parallel + async_grpo conflict detection."""
|
||||
|
||||
def test_raises_on_both_enabled(self):
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=True)
|
||||
self.assertIn("sequence_parallel", str(ctx.exception))
|
||||
self.assertIn("async_grpo", str(ctx.exception))
|
||||
|
||||
def test_sequence_parallel_only(self):
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlGRPOSequenceParallelTrainer,
|
||||
)
|
||||
|
||||
cls = GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=False)
|
||||
self.assertIs(cls, AxolotlGRPOSequenceParallelTrainer)
|
||||
|
||||
def test_async_only(self):
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlAsyncGRPOTrainer
|
||||
|
||||
cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=True)
|
||||
self.assertIs(cls, AxolotlAsyncGRPOTrainer)
|
||||
|
||||
def test_neither(self):
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
|
||||
cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=False)
|
||||
self.assertIs(cls, AxolotlGRPOTrainer)
|
||||
|
||||
|
||||
class TestDequantizeFP8TailBlocks(unittest.TestCase):
|
||||
"""Tests for FP8 dequantization with non-divisible dimensions."""
|
||||
|
||||
def test_exact_divisible_shape(self):
|
||||
from axolotl.kernels.quantize import dequantize_fp8
|
||||
|
||||
W = torch.randn(256, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
|
||||
scale_inv = torch.ones(2, 1, dtype=torch.bfloat16)
|
||||
result = dequantize_fp8(W, scale_inv)
|
||||
self.assertEqual(result.shape, (256, 128))
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
def test_non_divisible_rows(self):
|
||||
from axolotl.kernels.quantize import dequantize_fp8
|
||||
|
||||
# 131 rows with 2 scale blocks is not evenly divisible by the block size,
# so the second scale block covers only a short tail of rows; this
# exercises the tail-block handling path in dequantize_fp8.
|
||||
W = torch.ones(131, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
|
||||
scale_inv = torch.tensor([[2.0], [3.0]], dtype=torch.bfloat16)
|
||||
result = dequantize_fp8(W, scale_inv)
|
||||
self.assertEqual(result.shape, (131, 128))
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
def test_non_divisible_cols(self):
|
||||
from axolotl.kernels.quantize import dequantize_fp8
|
||||
|
||||
W = torch.ones(128, 200, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
|
||||
scale_inv = torch.ones(1, 2, dtype=torch.bfloat16)
|
||||
result = dequantize_fp8(W, scale_inv)
|
||||
self.assertEqual(result.shape, (128, 200))
|
||||
|
||||
def test_scalar_scale(self):
|
||||
from axolotl.kernels.quantize import dequantize_fp8
|
||||
|
||||
W = torch.ones(64, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
|
||||
scale_inv = torch.tensor(2.0, dtype=torch.bfloat16)
|
||||
result = dequantize_fp8(W, scale_inv)
|
||||
self.assertEqual(result.shape, (64, 64))
|
||||
|
||||
|
||||
class TestLoraFP8Guard(unittest.TestCase):
|
||||
"""Tests that get_lora_parameters only uses weight_scale_inv for FP8 weights."""
|
||||
|
||||
def test_non_fp8_weight_skips_scale_inv(self):
|
||||
"""Non-FP8 weight should NOT pick up weight_scale_inv as quant_state."""
|
||||
from axolotl.kernels.lora import get_lora_parameters
|
||||
|
||||
proj = MagicMock()
|
||||
proj.disable_adapters = True
|
||||
base_layer = MagicMock(spec=[]) # empty spec to control attrs precisely
|
||||
|
||||
# Use a real tensor for weight (bf16, no quant_state attr)
|
||||
base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16)
|
||||
base_layer.bias = None
|
||||
base_layer.weight_scale_inv = torch.ones(1) # should NOT be used for bf16
|
||||
|
||||
proj.base_layer = base_layer
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||
# quant_state should be None since weight is bf16, not FP8
|
||||
self.assertIsNone(quant_state)
|
||||
|
||||
def test_fp8_weight_uses_scale_inv(self):
|
||||
"""FP8 weight should pick up weight_scale_inv as quant_state."""
|
||||
from axolotl.kernels.lora import get_lora_parameters
|
||||
|
||||
proj = MagicMock()
|
||||
proj.disable_adapters = True
|
||||
base_layer = MagicMock()
|
||||
proj.base_layer = base_layer
|
||||
|
||||
# FP8 weight
|
||||
base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16).to(
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
base_layer.bias = None
|
||||
scale_inv = torch.ones(1)
|
||||
base_layer.weight_scale_inv = scale_inv
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||
self.assertIs(quant_state, scale_inv)
|
||||
|
||||
|
||||
class TestValidateQuantPatchRestore(unittest.TestCase):
|
||||
"""Test that validate_quantization_for_training is restored after trainer creation."""
|
||||
|
||||
def test_patch_restored_on_success(self):
|
||||
"""Monkeypatch should be restored even after successful trainer creation."""
|
||||
import transformers.trainer as _trainer_module
|
||||
|
||||
original = _trainer_module.validate_quantization_for_training
|
||||
|
||||
# After the build() method runs, original should be restored.
|
||||
# We can't easily test the full build(), but we can test the pattern.
|
||||
_orig = _trainer_module.validate_quantization_for_training
|
||||
_trainer_module.validate_quantization_for_training = lambda model: None
|
||||
try:
|
||||
pass # simulate trainer_cls() succeeding
|
||||
finally:
|
||||
_trainer_module.validate_quantization_for_training = _orig
|
||||
|
||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||
|
||||
def test_patch_restored_on_error(self):
|
||||
"""Monkeypatch should be restored even if trainer creation raises."""
|
||||
import transformers.trainer as _trainer_module
|
||||
|
||||
original = _trainer_module.validate_quantization_for_training
|
||||
|
||||
_orig = _trainer_module.validate_quantization_for_training
|
||||
_trainer_module.validate_quantization_for_training = lambda model: None
|
||||
try:
|
||||
raise ValueError("test error")
|
||||
except ValueError:
|
||||
pass
|
||||
finally:
|
||||
_trainer_module.validate_quantization_for_training = _orig
|
||||
|
||||
self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
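The `ReplayBuffer` these tests import is not shown in this excerpt. For orientation, a minimal sketch that satisfies the behavior asserted above (no-op when `max_size <= 0`, `sample` returning `None` when empty, lowest-score eviction when full, and a `_heap` of score-first tuples) could look like the following; the real implementation in `axolotl.core.trainers.grpo.replay_buffer` may differ.

```python
import heapq
import random


class ReplayBuffer:
    """Minimal sketch: keep the max_size highest-scoring rollout groups."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap: list[tuple[float, int, dict]] = []  # (score, tie-breaker, group data)
        self._counter = 0  # monotonic tie-breaker so dicts are never compared

    def __len__(self) -> int:
        return len(self._heap)

    def add(self, score: float, data: dict) -> None:
        if self.max_size <= 0:
            return  # buffer disabled
        self._counter += 1
        entry = (score, self._counter, data)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif score > self._heap[0][0]:
            heapq.heapreplace(self._heap, entry)  # evict the lowest-scoring group

    def sample(self, k: int) -> list[dict] | None:
        if not self._heap:
            return None
        picked = random.sample(self._heap, min(k, len(self._heap)))
        return [data for _, _, data in picked]
```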
||||
Some files were not shown because too many files have changed in this diff.