Compare commits
148 Commits
coderabbit
...
fix/cp-was
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
255c5b90ca | ||
|
|
038ffe3f26 | ||
|
|
c13cb7c853 | ||
|
|
b3823cc6b0 | ||
|
|
113d275bd9 | ||
|
|
7920fe74ec | ||
|
|
1fc86d5295 | ||
|
|
bb483ad4c4 | ||
|
|
163bd4dd5a | ||
|
|
f291ac029c | ||
|
|
5ef3f28340 | ||
|
|
999b3fec2e | ||
|
|
8f3fb517b3 | ||
|
|
830e9f7eaf | ||
|
|
d230cbbde3 | ||
|
|
a098df527b | ||
|
|
7da5f94379 | ||
|
|
4a5876df7a | ||
|
|
defee62d99 | ||
|
|
f56efdb4ab | ||
|
|
d8a646c80d | ||
|
|
a806704e94 | ||
|
|
d8a05744d7 | ||
|
|
ff77fa2488 | ||
|
|
e1ff756245 | ||
|
|
083c5a0421 | ||
|
|
79908b3c6e | ||
|
|
819b157c7b | ||
|
|
fccc712dae | ||
|
|
23ad40bdd5 | ||
|
|
cf4d550c88 | ||
|
|
43b1c80aa6 | ||
|
|
a36aaa70ce | ||
|
|
80f7088ad1 | ||
|
|
46b9f40f2a | ||
|
|
8f19169eb0 | ||
|
|
876941ffd0 | ||
|
|
d65e1b960c | ||
|
|
0a23ae08f7 | ||
|
|
fc2d63ee5f | ||
|
|
c119382337 | ||
|
|
6c8c73e5a4 | ||
|
|
a260d330ed | ||
|
|
da17c7c0d9 | ||
|
|
cada93cee5 | ||
|
|
56162f71db | ||
|
|
6c44afaea1 | ||
|
|
234931d512 | ||
|
|
6a8baf8fa7 | ||
|
|
1eaf4d7418 | ||
|
|
4b8bc52424 | ||
|
|
28cc085283 | ||
|
|
8e2a102cca | ||
|
|
753906cfc7 | ||
|
|
b6b8db805a | ||
|
|
653f90be25 | ||
|
|
945c8aeb10 | ||
|
|
e672d37f33 | ||
|
|
77828d3559 | ||
|
|
4272817109 | ||
|
|
474208b794 | ||
|
|
444020b332 | ||
|
|
aa88c2e30b | ||
|
|
f447bce1db | ||
|
|
7f23b302d1 | ||
|
|
18f26c19ef | ||
|
|
2b6f4a6c9b | ||
|
|
8f54b4eb25 | ||
|
|
a131e4d0e5 | ||
|
|
1791d87b6f | ||
|
|
b40803da51 | ||
|
|
68f1b7004c | ||
|
|
08441fed17 | ||
|
|
86ca1e27c0 | ||
|
|
5ed455715e | ||
|
|
3f30572d4a | ||
|
|
43d60c7439 | ||
|
|
0ea252d392 | ||
|
|
29722dec60 | ||
|
|
7fbedbd300 | ||
|
|
145ffc9be1 | ||
|
|
4f1b5ad29f | ||
|
|
d6a2532dd7 | ||
|
|
5eb265513c | ||
|
|
06ac407b92 | ||
|
|
4e22cf0651 | ||
|
|
a4ee56c315 | ||
|
|
c67cbcb0f5 | ||
|
|
a2da852576 | ||
|
|
37e9da7a53 | ||
|
|
ed7105dba7 | ||
|
|
b6d3653f74 | ||
|
|
fcc4cfdb63 | ||
|
|
97a4f28511 | ||
|
|
86a5803212 | ||
|
|
530a0c0bf0 | ||
|
|
0343a72cc9 | ||
|
|
236dad3bb7 | ||
|
|
be00978bc2 | ||
|
|
3738978394 | ||
|
|
6132a30cda | ||
|
|
3dd86d35b8 | ||
|
|
dd9ebaeba1 | ||
|
|
fc4e37920b | ||
|
|
a531e9d946 | ||
|
|
04328aeb97 | ||
|
|
d0d26d5064 | ||
|
|
8623dd8a72 | ||
|
|
8cd75cff9f | ||
|
|
8ab9d9ea88 | ||
|
|
6e42def14b | ||
|
|
c413480b35 | ||
|
|
8f25124269 | ||
|
|
790df757cb | ||
|
|
d282f32481 | ||
|
|
6331e4a130 | ||
|
|
1410e4474e | ||
|
|
dc77b5bf42 | ||
|
|
359b7ad85e | ||
|
|
258ce8d4fa | ||
|
|
3e0bbd33ec | ||
|
|
4ae6f766ad | ||
|
|
e7f0d4ba5b | ||
|
|
7bf6f70e96 | ||
|
|
8aab807e67 | ||
|
|
ee59e4de97 | ||
|
|
4e61b8aa23 | ||
|
|
b26ba3a5cb | ||
|
|
afe18ace35 | ||
|
|
2b199f9915 | ||
|
|
e73dab6df9 | ||
|
|
f45a97a9ff | ||
|
|
11c0b5b256 | ||
|
|
66a3de3629 | ||
|
|
a6080df73c | ||
|
|
4f5e8a328a | ||
|
|
418933f0d1 | ||
|
|
372f664c63 | ||
|
|
97f1b1758d | ||
|
|
f2155eaf79 | ||
|
|
92ee4256f7 | ||
|
|
efeb5a4e41 | ||
|
|
faaff6c792 | ||
|
|
43cef27458 | ||
|
|
07c41a6c2a | ||
|
|
bbd3486f57 | ||
|
|
3750d7dd64 | ||
|
|
2197b0bf89 |
9
.github/CONTRIBUTING.md
vendored
9
.github/CONTRIBUTING.md
vendored
@@ -68,7 +68,12 @@ You can skip certain CI checks by including specific keywords in your commit mes
|
||||
|
||||
### Code Style
|
||||
|
||||
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
|
||||
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
|
||||
|
||||
Use the pre-commit linter to ensure that your code is formatted consistently.
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Commit Messages
|
||||
|
||||
@@ -78,6 +83,6 @@ Write clear and concise commit messages that briefly describe the changes made i
|
||||
|
||||
- [GitHub Help](https://help.github.com/)
|
||||
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
|
||||
- [{codestyle}]({URLofCodestyle})
|
||||
- [Ruff](https://docs.astral.sh/ruff/)
|
||||
|
||||
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!
|
||||
|
||||
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -15,6 +15,11 @@
|
||||
<!--- Include details of your testing environment, tests ran to see how -->
|
||||
<!--- your change affects other areas of the code, etc. -->
|
||||
|
||||
## AI Usage Disclaimer
|
||||
|
||||
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
|
||||
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
|
||||
|
||||
## Screenshots (if appropriate)
|
||||
|
||||
## Types of changes
|
||||
|
||||
164
.github/workflows/base.yml
vendored
164
.github/workflows/base.yml
vendored
@@ -15,37 +15,21 @@ on:
|
||||
- '.github/workflows/base.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-base:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -53,6 +37,15 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -60,6 +53,31 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -67,6 +85,23 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
@@ -92,17 +127,19 @@ jobs:
|
||||
images: |
|
||||
axolotlai/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v4
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
@@ -117,24 +154,12 @@ jobs:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -142,6 +167,7 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -149,6 +175,47 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-uv-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -156,6 +223,23 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -166,17 +250,19 @@ jobs:
|
||||
images: |
|
||||
axolotlai/axolotl-base-uv
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v4
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
3
.github/workflows/docs.yml
vendored
3
.github/workflows/docs.yml
vendored
@@ -12,6 +12,9 @@ jobs:
|
||||
build-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Quarto
|
||||
|
||||
3
.github/workflows/lint.yml
vendored
3
.github/workflows/lint.yml
vendored
@@ -13,6 +13,9 @@ on:
|
||||
- ".pre-commit-config.yaml"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
|
||||
299
.github/workflows/main.yml
vendored
299
.github/workflows/main.yml
vendored
@@ -8,6 +8,9 @@ on:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
@@ -15,37 +18,49 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -71,6 +86,7 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
@@ -85,50 +101,134 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
build-axolotl-uv:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||
- name: Build and export to Docker
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
|
||||
file: ./docker/Dockerfile-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -153,6 +253,7 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
@@ -163,29 +264,98 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-uv:
|
||||
needs: build-axolotl-uv
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-cloud-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-no-tmux:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
runs-on: axolotl-gpu-runner
|
||||
@@ -212,6 +382,7 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
|
||||
44
.github/workflows/multi-gpu-e2e.yml
vendored
44
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,6 +8,7 @@ on:
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'scripts/cutcrossentropy_install.py'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
workflow_dispatch:
|
||||
@@ -19,6 +20,12 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
|
||||
|
||||
jobs:
|
||||
test-axolotl-multigpu:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
@@ -26,27 +33,33 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras: "fbgemm-gpu"
|
||||
# num_gpus: 2
|
||||
# dockerfile: "Dockerfile-uv.jinja"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras: fbgemm-gpu
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
@@ -59,7 +72,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -68,8 +81,9 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.multigpu
|
||||
modal run -m cicd.multigpu
|
||||
|
||||
23
.github/workflows/nightlies.yml
vendored
23
.github/workflows/nightlies.yml
vendored
@@ -5,6 +5,9 @@ on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
@@ -12,16 +15,16 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -64,16 +67,16 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
2
.github/workflows/precommit-autoupdate.yml
vendored
2
.github/workflows/precommit-autoupdate.yml
vendored
@@ -5,6 +5,8 @@ on:
|
||||
- cron: '0 0 1 * *' # Run monthly
|
||||
workflow_dispatch: # Manual kickoff
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
auto-update:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
13
.github/workflows/preview-docs.yml
vendored
13
.github/workflows/preview-docs.yml
vendored
@@ -11,22 +11,21 @@ on:
|
||||
- '_quarto.yml'
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- src/axolotl/utils/schemas/**.py
|
||||
- .github/workflows/preview-docs.yml
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
contents: write
|
||||
deployments: write
|
||||
issues: write
|
||||
discussions: write
|
||||
pages: write
|
||||
contents: read
|
||||
pull-requests: write
|
||||
statuses: write
|
||||
|
||||
jobs:
|
||||
preview:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
15
.github/workflows/pypi.yml
vendored
15
.github/workflows/pypi.yml
vendored
@@ -3,9 +3,11 @@ name: publish pypi
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
@@ -28,7 +30,8 @@ jobs:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/axolotl
|
||||
permissions:
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
contents: read
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
@@ -40,17 +43,17 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging==23.2
|
||||
pip3 install wheel packaging==26.0
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||
run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Update version in setup.py
|
||||
- name: Update version in VERSION file
|
||||
run: |
|
||||
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
|
||||
|
||||
- name: Build a source dist
|
||||
run: |
|
||||
|
||||
60
.github/workflows/tests-nightly.yml
vendored
60
.github/workflows/tests-nightly.yml
vendored
@@ -3,6 +3,13 @@ on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '.github/workflows/tests-nightly.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
@@ -18,15 +25,26 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
prime-cdn-s3-cache:
|
||||
name: Prefetch S3 once to prime the CDN cache
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.7.1", "2.8.0"]
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -37,7 +55,7 @@ jobs:
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -48,7 +66,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -99,19 +117,26 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.10.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
nightly_build: "true"
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -123,7 +148,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -132,9 +157,11 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
docker-e2e-multigpu-tests:
|
||||
@@ -148,10 +175,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 2
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
@@ -165,7 +192,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -175,7 +202,8 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.multigpu
|
||||
|
||||
141
.github/workflows/tests.yml
vendored
141
.github/workflows/tests.yml
vendored
@@ -28,6 +28,9 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
TRANSFORMERS_IS_CI: "yes"
|
||||
|
||||
@@ -46,16 +49,32 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
prime-cdn-s3-cache:
|
||||
name: Prefetch S3 once to prime the CDN cache
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
# needs: [preload-cache]
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
# exclude:
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.8.0"
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.9.1"
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -66,12 +85,13 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# - name: Restore Cache from S3
|
||||
# id: hf-cache-restore-s3
|
||||
# run: |
|
||||
# mkdir -p ~/.cache/huggingface/hub
|
||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
#
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -81,7 +101,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -109,12 +129,15 @@ jobs:
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache ls
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
df -h
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
@@ -122,6 +145,9 @@ jobs:
|
||||
df -h
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache ls
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
@@ -134,12 +160,18 @@ jobs:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
timeout-minutes: 20
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
# exclude:
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.8.0"
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.9.1"
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- name: cleanup node
|
||||
@@ -149,12 +181,13 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# - name: Restore Cache from S3
|
||||
# id: hf-cache-restore-s3
|
||||
# run: |
|
||||
# mkdir -p ~/.cache/huggingface/hub
|
||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
#
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -164,7 +197,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -192,16 +225,19 @@ jobs:
|
||||
axolotl --help
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
run: hf cache ls
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache ls
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
needs: [pre-commit]
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
skip: ${{ steps.compute.outputs.skip }}
|
||||
@@ -237,16 +273,16 @@ jobs:
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
@@ -260,7 +296,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -270,9 +306,10 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
@@ -292,18 +329,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
# - cuda: 128
|
||||
# cuda_version: 12.8.1
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.7.1
|
||||
# num_gpus: 1
|
||||
# axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -314,7 +339,19 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -327,7 +364,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -338,9 +375,10 @@ jobs:
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
@@ -354,10 +392,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -370,7 +408,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -380,7 +418,6 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.cleanup
|
||||
|
||||
@@ -11,13 +11,13 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.7
|
||||
rev: v0.15.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.19.0
|
||||
rev: v1.19.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
@@ -26,7 +26,7 @@ repos:
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.9.2
|
||||
rev: 1.9.4
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
|
||||
@@ -123,7 +123,7 @@ datasets:
|
||||
| --------------------------------- | -------------------------- | ----------------------------------- |
|
||||
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
|
||||
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
|
||||
| `dataset_processes` | `4` | Number of preprocessing processes |
|
||||
| `dataset_num_proc` | `4` | Number of preprocessing processes |
|
||||
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
|
||||
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
|
||||
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
|
||||
|
||||
@@ -39,7 +39,6 @@
|
||||
# type: # linear | dynamic
|
||||
# factor: # float
|
||||
|
||||
|
||||
# # Whether you are training a 4-bit GPTQ quantized model
|
||||
# gptq: true
|
||||
# gptq_groupsize: 128 # group size
|
||||
@@ -107,7 +106,7 @@
|
||||
# push_dataset_to_hub: # repo path
|
||||
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# # if not set.
|
||||
# dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# dataset_num_proc: # defaults to os.cpu_count() if not set
|
||||
# # push checkpoints to hub
|
||||
# hub_model_id: # repo path to push finetuned model
|
||||
# # how to push checkpoints to hub
|
||||
@@ -224,9 +223,6 @@
|
||||
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
# # Save model as safetensors (require safetensors package)
|
||||
# save_safetensors:
|
||||
|
||||
# # Whether to mask out or include the human's prompt from the training labels
|
||||
# train_on_inputs: false
|
||||
# # Group similarly sized data to minimize padding.
|
||||
@@ -352,8 +348,6 @@
|
||||
# # Allow overwrite yml config using from cli
|
||||
# strict:
|
||||
|
||||
|
||||
|
||||
base_model: ${BASE_MODEL}
|
||||
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
|
||||
base_model_config: ${BASE_MODEL_CONFIG}
|
||||
@@ -412,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
|
||||
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
|
||||
dataset_prepared_path: ${DATASET_PREPARED_PATH}
|
||||
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
|
||||
dataset_processes: ${DATASET_PROCESSES}
|
||||
dataset_num_proc: ${DATASET_NUM_PROC}
|
||||
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
|
||||
hub_model_id: ${HUB_MODEL_ID}
|
||||
hub_strategy: ${HUB_STRATEGY}
|
||||
@@ -512,7 +506,6 @@ profiler_steps: ${PROFILER_STEPS}
|
||||
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
|
||||
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
|
||||
|
||||
save_safetensors: ${SAVE_SAFETENSORS}
|
||||
train_on_inputs: ${TRAIN_ON_INPUTS}
|
||||
group_by_length: ${GROUP_BY_LENGTH}
|
||||
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
|
||||
|
||||
46
README.md
46
README.md
@@ -29,25 +29,35 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
|
||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2026/03:
|
||||
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
||||
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
||||
- 2026/02:
|
||||
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
|
||||
- Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
|
||||
- 2026/01:
|
||||
- New integration for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), weights loss by entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), improves long context in attention.
|
||||
- 2025/12:
|
||||
- Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
|
||||
- [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
@@ -62,10 +72,10 @@ Axolotl is a free and open-source tool designed to streamline post-training and
|
||||
Features:
|
||||
|
||||
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
|
||||
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||
|
||||
@@ -77,7 +87,7 @@ Features:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.7.1
|
||||
- PyTorch ≥2.8.0
|
||||
|
||||
### Google Colab
|
||||
|
||||
@@ -88,7 +98,7 @@ Features:
|
||||
#### Using pip
|
||||
|
||||
```bash
|
||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
|
||||
46
_quarto.yml
46
_quarto.yml
@@ -1,6 +1,8 @@
|
||||
project:
|
||||
type: website
|
||||
pre-render: docs/scripts/generate_config_docs.py
|
||||
pre-render:
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- docs/scripts/generate_examples_docs.py
|
||||
|
||||
quartodoc:
|
||||
dir: docs/api
|
||||
@@ -240,6 +242,46 @@ website:
|
||||
- docs/getting-started.qmd
|
||||
- docs/installation.qmd
|
||||
- docs/inference.qmd
|
||||
- section: "Model Guides"
|
||||
contents:
|
||||
- docs/models/kimi-linear.qmd
|
||||
- docs/models/plano.qmd
|
||||
- docs/models/mimo.qmd
|
||||
- docs/models/internvl3_5.qmd
|
||||
- docs/models/olmo3.qmd
|
||||
- docs/models/trinity.qmd
|
||||
- docs/models/arcee.qmd
|
||||
- section: "Ministral3"
|
||||
contents:
|
||||
- docs/models/ministral3.qmd
|
||||
- docs/models/ministral3/think.qmd
|
||||
- docs/models/ministral3/vision.qmd
|
||||
- section: "Magistral"
|
||||
contents:
|
||||
- docs/models/magistral.qmd
|
||||
- docs/models/magistral/think.qmd
|
||||
- docs/models/magistral/vision.qmd
|
||||
- docs/models/ministral.qmd
|
||||
- docs/models/mistral-small.qmd
|
||||
- docs/models/voxtral.qmd
|
||||
- docs/models/devstral.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- docs/models/llama-4.qmd
|
||||
- docs/models/llama-2.qmd
|
||||
- docs/models/qwen3-next.qmd
|
||||
- docs/models/qwen3.qmd
|
||||
- docs/models/gemma3n.qmd
|
||||
- docs/models/apertus.qmd
|
||||
- docs/models/gpt-oss.qmd
|
||||
- docs/models/seed-oss.qmd
|
||||
- docs/models/phi.qmd
|
||||
- docs/models/smolvlm2.qmd
|
||||
- docs/models/granite4.qmd
|
||||
- docs/models/LiquidAI.qmd
|
||||
- docs/models/hunyuan.qmd
|
||||
- docs/models/jamba.qmd
|
||||
- docs/models/orpheus.qmd
|
||||
|
||||
- docs/cli.qmd
|
||||
- docs/telemetry.qmd
|
||||
- docs/config-reference.qmd
|
||||
@@ -278,6 +320,7 @@ website:
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/optimizers.qmd
|
||||
- docs/attention.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
@@ -288,6 +331,7 @@ website:
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/nd_parallelism.qmd
|
||||
- docs/expert_quantization.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
208
benchmarks/bench_entropy.py
Normal file
208
benchmarks/bench_entropy.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
|
||||
"""Original chunked implementation (reference)."""
|
||||
original_shape = logits.shape[:-1]
|
||||
num_classes = logits.shape[-1]
|
||||
flat_logits = logits.reshape(-1, num_classes)
|
||||
entropies = []
|
||||
for chunk in flat_logits.split(chunk_size, dim=0):
|
||||
logps = F.log_softmax(chunk, dim=-1)
|
||||
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
|
||||
entropies.append(chunk_entropy)
|
||||
return torch.cat(entropies, dim=0).reshape(original_shape)
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, logits, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
out = fn(logits, chunk_size=128)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
del out
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, logits, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(logits, chunk_size=128)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_contiguous():
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(1, 16384),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
t_orig = profile_time(entropy_from_logits_original, logits)
|
||||
t_triton = profile_time(entropy_from_logits, logits)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(entropy_from_logits_original, logits)
|
||||
m_triton = profile_memory(entropy_from_logits, logits)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_noncontiguous():
|
||||
print("\n" + "=" * 60)
|
||||
print(
|
||||
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(4, 2048, "transpose"),
|
||||
(4, 8192, "transpose"),
|
||||
(8, 2048, "transpose"),
|
||||
(4, 4096, "slice_batch"),
|
||||
]
|
||||
|
||||
for B, L, method in configs:
|
||||
torch.manual_seed(42)
|
||||
|
||||
if method == "transpose":
|
||||
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw.transpose(0, 1)
|
||||
raw_gb = L * B * V * 2 / 1e9
|
||||
elif method == "slice_batch":
|
||||
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw[::2]
|
||||
raw_gb = B * 2 * L * V * 2 / 1e9
|
||||
else:
|
||||
continue
|
||||
|
||||
if raw_gb > 28:
|
||||
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
|
||||
del raw, logits_nc
|
||||
torch.cuda.empty_cache()
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
def original_with_copy(logits, chunk_size=128):
|
||||
return entropy_from_logits_original(
|
||||
logits.contiguous(), chunk_size=chunk_size
|
||||
)
|
||||
|
||||
t_orig = profile_time(original_with_copy, logits_nc)
|
||||
t_triton = profile_time(entropy_from_logits, logits_nc)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" orig+copy: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton-strided:{fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(original_with_copy, logits_nc)
|
||||
m_triton = profile_memory(entropy_from_logits, logits_nc)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" orig+copy: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton-strided:{fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del raw, logits_nc
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_contiguous()
|
||||
benchmark_noncontiguous()
|
||||
284
benchmarks/bench_scattermoe_lora.py
Normal file
284
benchmarks/bench_scattermoe_lora.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||
and full fwd+bwd autograd throughput.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
WARMUP = 5
|
||||
ITERS = 20
|
||||
|
||||
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||
|
||||
BUILTIN_CONFIGS = {
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_config(spec):
|
||||
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||
key = spec.lower().replace("/", "-")
|
||||
for name, cfg in BUILTIN_CONFIGS.items():
|
||||
if key in name.lower() or name.lower() in key:
|
||||
return name, cfg
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||
tc = hf_cfg.get_text_config()
|
||||
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||
hf_cfg = tc
|
||||
hidden = hf_cfg.hidden_size
|
||||
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
experts = (
|
||||
getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None)
|
||||
)
|
||||
top_k = (
|
||||
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||
or 2
|
||||
)
|
||||
name = spec.split("/")[-1]
|
||||
return name, (experts, hidden, inter, top_k)
|
||||
|
||||
|
||||
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _clean():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - t0) * 1000)
|
||||
times.sort()
|
||||
return times[len(times) // 2]
|
||||
|
||||
|
||||
def _setup(num_experts, K, N, T, top_k, R):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||
|
||||
|
||||
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||
|
||||
|
||||
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _call_base(x, W, sei, ssi, top_k):
|
||||
return base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
|
||||
|
||||
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||
return lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=num_experts,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
"-m",
|
||||
nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)",
|
||||
)
|
||||
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||
args = parser.parse_args()
|
||||
|
||||
T = args.seq_len
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"T={T}, ranks={args.ranks}\n")
|
||||
|
||||
if args.models:
|
||||
configs = [_resolve_config(m) for m in args.models]
|
||||
else:
|
||||
configs = list(BUILTIN_CONFIGS.items())
|
||||
|
||||
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||
print(f"{'=' * 70}")
|
||||
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
for R in args.ranks:
|
||||
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||
num_experts, K, N, T, top_k, R
|
||||
)
|
||||
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = (
|
||||
"split"
|
||||
if (
|
||||
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
)
|
||||
else "fused"
|
||||
)
|
||||
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||
|
||||
total = t_fwd + t_dx + t_bwd
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(
|
||||
f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead * 100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms"
|
||||
)
|
||||
|
||||
# Full autograd fwd+bwd with memory measurement
|
||||
x_ag = x.clone().requires_grad_(True)
|
||||
lA_ag = lA.clone().requires_grad_(True)
|
||||
lB_ag = lB.clone().requires_grad_(True)
|
||||
|
||||
def _run_autograd(
|
||||
_x=x_ag,
|
||||
_W=W,
|
||||
_k=top_k,
|
||||
_sei=sei,
|
||||
_ssi=ssi,
|
||||
_eo=eo,
|
||||
_lA=lA_ag,
|
||||
_lB=lB_ag,
|
||||
):
|
||||
out = ScatterMoELoRA.apply(
|
||||
_x,
|
||||
_W,
|
||||
_k,
|
||||
_sei,
|
||||
_ssi,
|
||||
_eo,
|
||||
_lA,
|
||||
_lB,
|
||||
2.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
_x.grad = None
|
||||
_lA.grad = None
|
||||
_lB.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(
|
||||
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
191
benchmarks/bench_selective_logsoftmax.py
Normal file
191
benchmarks/bench_selective_logsoftmax.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import (
|
||||
selective_log_softmax,
|
||||
selective_log_softmax_original,
|
||||
)
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, args, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
fn(*args)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, args, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(*args)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_forward():
|
||||
print("=" * 60)
|
||||
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(selective_log_softmax_original, (logits, index))
|
||||
t_triton = profile_time(selective_log_softmax, (logits, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
|
||||
m_triton = profile_memory(selective_log_softmax, (logits, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_backward():
|
||||
print("\n" + "=" * 60)
|
||||
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
def fwd_bwd_original(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax_original(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
def fwd_bwd_triton(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 20:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits_orig = torch.randn(
|
||||
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
|
||||
)
|
||||
logits_tri = logits_orig.detach().clone().requires_grad_(True)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
|
||||
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" FWD+BWD TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
|
||||
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" FWD+BWD MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits_orig, logits_tri, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_forward()
|
||||
benchmark_backward()
|
||||
@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -31,8 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
||||
RUN uv pip install torchvision
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -32,7 +32,8 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
|
||||
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -3,6 +3,13 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
# hf download "microsoft/Phi-4-reasoning"
|
||||
# hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
# hf download "microsoft/Phi-3-medium-128k-instruct"
|
||||
|
||||
# Run unit tests with initial coverage report
|
||||
pytest -v --durations=10 -n8 \
|
||||
--ignore=tests/e2e/ \
|
||||
|
||||
@@ -17,7 +17,8 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
||||
df_template = template_env.get_template(dockerfile)
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
@@ -27,8 +28,11 @@ df_args = {
|
||||
"CUDA": os.environ.get("CUDA", "126"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
||||
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
set -e
|
||||
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v --durations=10 -n2 \
|
||||
pytest -v --durations=10 -n2 --maxfail=3 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
|
||||
@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
try:
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
print(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
return exit_code
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Command '{cmd}' failed with exception {e}")
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
|
||||
@@ -6,6 +6,7 @@ ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
@@ -20,13 +21,18 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py | sh && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \ python scripts/unsloth_install.py | sh && \
|
||||
python scripts/cutcrossentropy_install.py | sh && \
|
||||
pip install pytest && \
|
||||
pip cache purge
|
||||
|
||||
@@ -2,14 +2,16 @@ ARG CUDA_VERSION="11.8.0"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG CUDA="118"
|
||||
ARG CUDA="128"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
@@ -22,11 +24,17 @@ RUN apt-get update \
|
||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||
&& rm -rf /var/cache/apt/archives \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
MINICONDA_ARCH="x86_64"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
MINICONDA_ARCH="aarch64"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi \
|
||||
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
@@ -35,7 +43,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip cache purge
|
||||
|
||||
@@ -51,8 +59,18 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="x86_64" ;; \
|
||||
arm64) ARCH_TAG="aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
|
||||
@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
|
||||
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||
|
||||
30
docker/Dockerfile-cloud-uv
Normal file
30
docker/Dockerfile-cloud-uv
Normal file
@@ -0,0 +1,30 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl-uv:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
EXPOSE 8888
|
||||
EXPOSE 22
|
||||
|
||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/motd /etc/motd
|
||||
|
||||
RUN uv pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||
chmod +x /root/cloud-entrypoint.sh && \
|
||||
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
|
||||
|
||||
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
||||
CMD ["sleep", "infinity"]
|
||||
48
docker/Dockerfile-uv
Normal file
48
docker/Dockerfile-uv
Normal file
@@ -0,0 +1,48 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base-uv:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
fi && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py --uv | sh && \
|
||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||
uv pip install pytest && \
|
||||
uv cache clean
|
||||
|
||||
# fix so that git fetch/pull from remote works with shallow clone
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch && \
|
||||
git config --global credential.helper store
|
||||
|
||||
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
|
||||
RUN chmod +x /root/.axolotl-complete.bash && \
|
||||
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
|
||||
@@ -2,9 +2,11 @@ ARG CUDA_VERSION="12.6.3"
|
||||
ARG CUDNN_VERSION=""
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.6.0"
|
||||
ARG CUDA="126"
|
||||
@@ -31,12 +33,25 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
|
||||
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||
fi
|
||||
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="x86_64" ;; \
|
||||
arm64) ARCH_TAG="aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
|
||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -3,3 +3,5 @@ _site/
|
||||
/api/*.qmd
|
||||
/api/*.html
|
||||
config-reference.qmd
|
||||
models/**/*.qmd
|
||||
models/**/*.html
|
||||
|
||||
@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
|
||||
Download a base model using the Hugging Face CLI:
|
||||
|
||||
```bash
|
||||
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
|
||||
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
|
||||
```
|
||||
|
||||
### 10. Create Axolotl Configuration
|
||||
|
||||
178
docs/attention.qmd
Normal file
178
docs/attention.qmd
Normal file
@@ -0,0 +1,178 @@
|
||||
---
|
||||
title: Attention
|
||||
description: Supported attention modules in Axolotl
|
||||
---
|
||||
|
||||
## SDP Attention
|
||||
|
||||
This is the default built-in attention in PyTorch.
|
||||
|
||||
```yaml
|
||||
sdp_attention: true
|
||||
```
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
## Flash Attention
|
||||
|
||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
||||
based on your installed packages and GPU.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
```
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
|
||||
### Flash Attention 2
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
|
||||
Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/hopper
|
||||
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### Flash Attention 4
|
||||
|
||||
Requirements: Hopper or Blackwell GPUs
|
||||
|
||||
```bash
|
||||
pip install flash-attn-4
|
||||
```
|
||||
|
||||
Or from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/flash_attn/cute
|
||||
|
||||
pip install -e .
|
||||
|
||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||
# Remove it so Python can find the real FA4 module:
|
||||
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
|
||||
for training on Hopper, install from source using the instructions above.
|
||||
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
|
||||
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
|
||||
and falls back to FA2/3.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
|
||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
## Flex Attention
|
||||
|
||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
||||
|
||||
```yaml
|
||||
flex_attention: true
|
||||
|
||||
# recommended
|
||||
torch_compile: true
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We recommend using latest stable version of PyTorch for best performance.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
||||
|
||||
## SageAttention
|
||||
|
||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
||||
|
||||
```yaml
|
||||
sage_attention: true
|
||||
```
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
```bash
|
||||
pip install sageattention==2.2.0 --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
||||
|
||||
:::
|
||||
|
||||
|
||||
## xFormers
|
||||
|
||||
```yaml
|
||||
xformers_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
We recommend using with Turing GPUs or below (such as on Colab).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
||||
|
||||
## Shifted Sparse Attention
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
||||
|
||||
:::
|
||||
|
||||
Requirements: LLaMA model architecture
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
s2_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
No sample packing support!
|
||||
|
||||
:::
|
||||
86
docs/checkpoint_saving.qmd
Normal file
86
docs/checkpoint_saving.qmd
Normal file
@@ -0,0 +1,86 @@
|
||||
---
|
||||
title: "Checkpoint Saving"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 2
|
||||
number-sections: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
|
||||
|
||||
## File-Based Checkpoint Trigger
|
||||
|
||||
### Configuration
|
||||
|
||||
Enable in your config:
|
||||
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 100 # Optional: check every N steps (default: 100)
|
||||
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `enabled`: `true` to enable (required)
|
||||
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
|
||||
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
|
||||
2. When detected, file is deleted and checkpoint is saved
|
||||
3. In distributed training, rank 0 broadcasts to synchronize all ranks
|
||||
|
||||
### Usage
|
||||
|
||||
**Command line:**
|
||||
```bash
|
||||
touch /path/to/output_dir/axolotl_checkpoint.save
|
||||
```
|
||||
|
||||
**Programmatic:**
|
||||
```python
|
||||
from pathlib import Path
|
||||
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
|
||||
```
|
||||
|
||||
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
|
||||
|
||||
**Custom filename:**
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
trigger_file_path: "my_trigger.save"
|
||||
```
|
||||
```bash
|
||||
touch /path/to/output_dir/my_trigger.save
|
||||
```
|
||||
|
||||
## Control+C (SIGINT) Checkpoint
|
||||
|
||||
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
|
||||
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
|
||||
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
output_dir: ./outputs/lora-out
|
||||
save_steps: 500 # Scheduled checkpoints
|
||||
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 50
|
||||
```
|
||||
|
||||
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).
|
||||
@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
lm_eval_model: # model to evaluate (local or hf path)
|
||||
|
||||
# List of tasks to evaluate
|
||||
lm_eval_tasks:
|
||||
- arc_challenge
|
||||
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
|
||||
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
|
||||
|
||||
### delinearize-llama4
|
||||
|
||||
|
||||
@@ -32,11 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.0`
|
||||
- `main-base-py3.11-cu126-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu128-2.8.0`
|
||||
- `main-base-py3.11-cu128-2.9.1`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -74,15 +71,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.0`
|
||||
- `main-py3.11-cu126-2.6.0`
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-py3.11-cu128-2.8.0`
|
||||
- `main-py3.11-cu128-2.9.1`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `0.10.1`
|
||||
- `0.12.0`
|
||||
|
||||
## Cloud
|
||||
|
||||
|
||||
67
docs/expert_quantization.qmd
Normal file
67
docs/expert_quantization.qmd
Normal file
@@ -0,0 +1,67 @@
|
||||
---
|
||||
title: "MoE Expert Quantization"
|
||||
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
|
||||
---
|
||||
|
||||
Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
|
||||
This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
|
||||
weights being loaded in full bf16 precision and causing massive VRAM usage.
|
||||
|
||||
`quantize_moe_experts` solves this by quantizing expert weights during model loading.
|
||||
It intercepts the weight loading process, quantizes each expert tensor on the fly, and
|
||||
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
|
||||
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
|
||||
|
||||
## Usage
|
||||
|
||||
Enable expert quantization in your Axolotl config:
|
||||
|
||||
```yaml
|
||||
quantize_moe_experts: true
|
||||
```
|
||||
|
||||
This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
|
||||
|
||||
### Expert LoRA targeting
|
||||
|
||||
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:
|
||||
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- mlp.experts.gate_up_proj
|
||||
- mlp.experts.down_proj
|
||||
# - mlp.gate.weight # router
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
`lora_dropout` must be `0` when using `lora_target_parameters`.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
|
||||
- CUDA GPUs only (not tested with ROCm or other backends)
|
||||
- FSDP2 compatible for distributed training
|
||||
|
||||
## Limitations
|
||||
|
||||
- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
|
||||
- `cpu_ram_efficient_loading` hangs / takes long time with FSDP2 + QLoRA.
|
||||
- Total model parameter count may display incorrectly (trainable param count is correct).
|
||||
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
|
||||
- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.
|
||||
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
|
||||
- DeepSpeed has not been tested.
|
||||
|
||||
## Implementation details
|
||||
|
||||
The quantization is applied by patching transformers to intercept weight loading.
|
||||
When a 3D+ CUDA tensor with "expert" in its name is detected:
|
||||
|
||||
- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
|
||||
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
|
||||
|
||||
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
|
||||
transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
|
||||
|
||||
For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).
|
||||
@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
```
|
||||
4. (Optional) Login to Hugging Face:
|
||||
```{.bash}
|
||||
huggingface-cli login
|
||||
hf auth login
|
||||
```
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
@@ -89,6 +89,10 @@ lora_o_kernel: true
|
||||
Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
LoRA kernels do not support remote modeling code.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
|
||||
@@ -13,14 +13,18 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Mistral-Small-4](#sec-mistral-small-4)
|
||||
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [Qwen3.5](#sec-qwen3-5)
|
||||
- [GLM-4.6V](#sec-glm-4-6v)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
- [Intern-VL](#sec-intern-vl)
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -106,6 +110,12 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
```
|
||||
|
||||
### Mistral-Small-4 {#sec-mistral-small-4}
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
```
|
||||
|
||||
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||
|
||||
::: {.callout-tip}
|
||||
@@ -182,6 +192,26 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### Qwen3.5 {#sec-qwen3-5}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
|
||||
chat_template: qwen3_5
|
||||
```
|
||||
|
||||
### GLM-4.6V {#sec-glm-4-6v}
|
||||
|
||||
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
|
||||
|
||||
```yaml
|
||||
# GLM-4.6V (106B MoE version)
|
||||
base_model: zai-org/GLM-4.6V
|
||||
|
||||
# OR GLM-4.6V-Flash (9B version)
|
||||
base_model: zai-org/GLM-4.6V-Flash
|
||||
```
|
||||
|
||||
### SmolVLM2 {#sec-smolvlm2}
|
||||
|
||||
::: {.callout-tip}
|
||||
@@ -202,6 +232,16 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
|
||||
base_model: LiquidAI/LFM2-VL-450M
|
||||
```
|
||||
|
||||
### Intern-VL {#sec-intern-vl}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install `timm` via `pip3 install timm==1.0.19`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: OpenGVLab/InternVL3_5-8B
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||
|
||||
@@ -66,6 +66,15 @@ Provides efficient Triton kernels to improve training speed and reduce memory us
|
||||
|
||||
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
|
||||
|
||||
### Expert Kernels
|
||||
|
||||
Optimized kernel implementations for Mixture of Experts (MoE) model training.
|
||||
|
||||
- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
|
||||
- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
|
||||
|
||||
- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)
|
||||
|
||||
## Long Context Models
|
||||
|
||||
Techniques to train models on sequences longer than their original context window.
|
||||
@@ -131,3 +140,10 @@ Simulates quantization effects during training, helping the model adapt and pote
|
||||
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
|
||||
|
||||
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
|
||||
|
||||
### MoE Expert Quantization
|
||||
|
||||
Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.
|
||||
|
||||
- **Config:** `quantize_moe_experts: true`
|
||||
- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)
|
||||
|
||||
304
docs/rlhf.qmd
304
docs/rlhf.qmd
@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
||||
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
|
||||
|
||||
|
||||
## RLHF using Axolotl
|
||||
@@ -720,6 +721,309 @@ trl:
|
||||
|
||||
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
|
||||
|
||||
#### Async GRPO
|
||||
|
||||
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
use_data_producer: true # Enable data producer protocol
|
||||
use_vllm: true
|
||||
async_prefetch: true # Generate rollouts in background thread
|
||||
prefetch_depth: 1 # Number of rollouts to prefetch
|
||||
vllm_sync_interval: 2 # Sync weights to vLLM every N steps
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
|
||||
:::
|
||||
|
||||
##### vLLM LoRA Sync
|
||||
|
||||
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
|
||||
|
||||
```yaml
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
|
||||
trl:
|
||||
vllm_lora_sync: true # Enable native LoRA sync
|
||||
```
|
||||
|
||||
When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
```
|
||||
|
||||
Then start training on a separate GPU:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
|
||||
:::
|
||||
|
||||
##### Streaming Partial Batch
|
||||
|
||||
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
streaming_partial_batch: true
|
||||
```
|
||||
|
||||
##### Importance Sampling Correction
|
||||
|
||||
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
vllm_importance_sampling_correction: true # Enable IS correction
|
||||
importance_sampling_level: token # 'token' or 'sequence'
|
||||
off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
|
||||
```
|
||||
|
||||
- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
|
||||
- `importance_sampling_level: sequence` applies per-sequence IS ratios
|
||||
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
|
||||
|
||||
##### Replay Buffer
|
||||
|
||||
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
replay_buffer_size: 100 # Max cached groups (0 = disabled)
|
||||
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
|
||||
:::
|
||||
|
||||
##### Deferred Re-rolling
|
||||
|
||||
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
|
||||
reroll_max_groups: 1 # Max groups to replace per batch
|
||||
```
|
||||
|
||||
##### Zero-Advantage Batch Skipping
|
||||
|
||||
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
skip_zero_advantage_batches: true # default
|
||||
```
|
||||
|
||||
##### Parallel Reward Workers
|
||||
|
||||
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
|
||||
```
|
||||
|
||||
##### Full Async GRPO Example
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
gpu_memory_utilization: 0.35
|
||||
dtype: auto
|
||||
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
|
||||
rl: grpo
|
||||
trl:
|
||||
use_data_producer: true
|
||||
use_vllm: true
|
||||
async_prefetch: true
|
||||
prefetch_depth: 1
|
||||
vllm_sync_interval: 2
|
||||
vllm_lora_sync: true
|
||||
streaming_partial_batch: true
|
||||
vllm_importance_sampling_correction: true
|
||||
off_policy_mask_threshold: 0.5
|
||||
importance_sampling_level: token
|
||||
num_generations: 8
|
||||
max_completion_length: 512
|
||||
reward_funcs:
|
||||
- rewards.accuracy_reward
|
||||
reroll_start_fraction: 0.5
|
||||
replay_buffer_size: 100
|
||||
reward_num_workers: 4
|
||||
skip_zero_advantage_batches: true
|
||||
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-TIR
|
||||
type: rewards.prompt_transform
|
||||
split: train
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
max_steps: 500
|
||||
learning_rate: 1e-5
|
||||
bf16: true
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM on GPU 0
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
|
||||
# Terminal 2: Train on GPU 1
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
|
||||
```
|
||||
|
||||
##### Multi-GPU Async GRPO
|
||||
|
||||
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
|
||||
|
||||
**FSDP:**
|
||||
|
||||
```yaml
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
```
|
||||
|
||||
**DeepSpeed ZeRO-3:**
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required for ZeRO-3
|
||||
```
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM on GPU 0
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
|
||||
# Terminal 2: Train on GPUs 0,1
|
||||
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
|
||||
:::
|
||||
|
||||
### GDPO
|
||||
|
||||
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
|
||||
|
||||
::: {.callout-tip}
|
||||
Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results.
|
||||
:::
|
||||
|
||||
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
|
||||
|
||||
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
tensor_parallel_size: 2
|
||||
gpu_memory_utilization: 0.85
|
||||
|
||||
rl: gdpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 256
|
||||
use_vllm: true
|
||||
num_generations: 4
|
||||
reward_funcs:
|
||||
- rewards.format_reward
|
||||
- rewards.correctness_reward
|
||||
reward_weights: [1.0, 2.0]
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
type: rewards.oai_gsm8k_transform
|
||||
```
|
||||
|
||||
You can also use GRPO with explicit aggregation control:
|
||||
|
||||
```yaml
|
||||
rl: grpo
|
||||
trl:
|
||||
multi_objective_aggregation: normalize_then_sum # GDPO behavior
|
||||
# or: sum_then_normalize # Default GRPO behavior
|
||||
```
|
||||
|
||||
#### GDPO vs GRPO
|
||||
|
||||
| Aspect | GRPO | GDPO |
|
||||
|--------|------|------|
|
||||
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
|
||||
| **Multi-reward** | May collapse advantages | Preserves reward signals |
|
||||
| **Single reward** | Standard behavior | Equivalent to GRPO |
|
||||
|
||||
#### Why GDPO?
|
||||
|
||||
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
|
||||
|
||||
```
|
||||
# Example: format + correctness rewards
|
||||
[format=0, correct=3] → sum=3
|
||||
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
|
||||
[format=2, correct=1] → sum=3
|
||||
[format=3, correct=0] → sum=3
|
||||
```
|
||||
|
||||
GDPO normalizes each reward independently, preserving their relative differences.
|
||||
|
||||
#### Reward Functions
|
||||
|
||||
GDPO uses the same reward function format as GRPO:
|
||||
|
||||
```python
|
||||
# rewards.py
|
||||
def format_reward(completions, **kwargs) -> list[float]:
|
||||
return [1.0 if len(c) > 10 else 0.0 for c in completions]
|
||||
|
||||
def correctness_reward(completions, answers, **kwargs) -> list[float]:
|
||||
rewards = []
|
||||
for completion, answer in zip(completions, answers):
|
||||
# Your scoring logic here
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
#### Sequence Parallelism
|
||||
|
||||
GDPO supports sequence parallelism for long-context training:
|
||||
|
||||
```yaml
|
||||
rl: gdpo
|
||||
context_parallel_size: 2
|
||||
```
|
||||
|
||||
### SimPO
|
||||
|
||||
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
||||
|
||||
90
docs/scripts/examples-allowlist.yml
Normal file
90
docs/scripts/examples-allowlist.yml
Normal file
@@ -0,0 +1,90 @@
|
||||
examples:
|
||||
# December 2025
|
||||
- name: kimi-linear
|
||||
title: Kimi Linear
|
||||
- name: plano
|
||||
title: Plano Orchestrator
|
||||
- name: mimo
|
||||
title: MiMo
|
||||
- name: internvl3_5
|
||||
title: InternVL 3.5
|
||||
|
||||
# AllenAI
|
||||
- name: olmo3
|
||||
title: OLMo 3
|
||||
|
||||
# ArceeAI
|
||||
- name: trinity
|
||||
title: Trinity
|
||||
- name: arcee
|
||||
title: Arcee AFM
|
||||
|
||||
# MistralAI
|
||||
- name: ministral3/think
|
||||
title: Ministral 3 Thinking
|
||||
- name: ministral3/vision
|
||||
title: Ministral 3 Vision
|
||||
- name: magistral/think
|
||||
title: Magistral Thinking
|
||||
- name: magistral/vision
|
||||
title: Magistral Vision
|
||||
- name: ministral
|
||||
title: Ministral
|
||||
- name: mistral-small
|
||||
title: Mistral Small 3.1/3.2
|
||||
- name: voxtral
|
||||
title: Voxtral
|
||||
- name: devstral
|
||||
title: Devstral
|
||||
- name: mistral
|
||||
title: Mistral 7B
|
||||
|
||||
# Meta
|
||||
- name: llama-4
|
||||
title: Llama 4
|
||||
- name: llama-2
|
||||
title: Llama 2
|
||||
|
||||
# Alibaba
|
||||
- name: qwen3-next
|
||||
title: Qwen 3 Next
|
||||
- name: qwen3
|
||||
title: Qwen 3
|
||||
|
||||
# Google
|
||||
- name: gemma3n
|
||||
title: Gemma 3n
|
||||
|
||||
# Swiss AI
|
||||
- name: apertus
|
||||
title: Apertus
|
||||
|
||||
# GPT-OSS
|
||||
- name: gpt-oss
|
||||
title: GPT-OSS
|
||||
- name: seed-oss
|
||||
title: Seed-OSS
|
||||
|
||||
# Microsoft
|
||||
- name: phi
|
||||
title: Phi
|
||||
|
||||
# SmolVLM
|
||||
- name: smolvlm2
|
||||
title: SmolVLM 2
|
||||
|
||||
# IBM
|
||||
- name: granite4
|
||||
title: Granite 4
|
||||
|
||||
# LiquidAI
|
||||
- name: LiquidAI
|
||||
title: Liquid Foundation Models 2
|
||||
|
||||
# Other
|
||||
- name: hunyuan
|
||||
title: Hunyuan
|
||||
- name: jamba
|
||||
title: Jamba
|
||||
- name: orpheus
|
||||
title: Orpheus
|
||||
424
docs/scripts/generate_examples_docs.py
Executable file
424
docs/scripts/generate_examples_docs.py
Executable file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
auto generate example docs from allowlist
|
||||
"""
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
# Paths
|
||||
THIS = Path(__file__).resolve()
|
||||
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
|
||||
EXAMPLES_DIR = ROOT / "examples"
|
||||
OUTPUT_DIR = ROOT / "docs" / "models"
|
||||
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
|
||||
|
||||
|
||||
def slugify(name: str) -> str:
|
||||
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
|
||||
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
|
||||
s = re.sub(r"\s+", "-", s).strip("-").lower()
|
||||
return s or "example"
|
||||
|
||||
|
||||
def read_allowlist():
|
||||
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
items = data.get("examples", [])
|
||||
if not isinstance(items, list):
|
||||
raise ValueError("`examples` must be a list in examples-allowlist.yml")
|
||||
return items
|
||||
|
||||
|
||||
def find_readme(folder: Path) -> Path | None:
|
||||
for name in ("README.md", "Readme.md", "readme.md"):
|
||||
p = folder / name
|
||||
if p.exists():
|
||||
return p
|
||||
return None
|
||||
|
||||
|
||||
def remove_first_h1(md: str) -> tuple[str, str | None]:
|
||||
"""
|
||||
Remove the first H1 from markdown and return (modified_md, h1_title).
|
||||
The H1 is removed since we use the frontmatter title instead.
|
||||
"""
|
||||
lines = md.splitlines()
|
||||
result = []
|
||||
h1_title = None
|
||||
skipped_first = False
|
||||
|
||||
for line in lines:
|
||||
if not skipped_first and line.startswith("# "):
|
||||
h1_title = line[2:].strip()
|
||||
skipped_first = True
|
||||
continue
|
||||
result.append(line)
|
||||
|
||||
return "\n".join(result), h1_title
|
||||
|
||||
|
||||
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
|
||||
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
|
||||
|
||||
|
||||
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
|
||||
"""
|
||||
Copy local image assets referenced in markdown to
|
||||
docs/examples/assets/... and rewrite the links.
|
||||
"""
|
||||
dest_assets = dest_assets_root / "assets"
|
||||
|
||||
def repl(m):
|
||||
url = m.group(1).strip()
|
||||
if re.match(r"^(https?:)?//", url):
|
||||
return m.group(0) # leave remote URLs
|
||||
src_path = (src_dir / url).resolve()
|
||||
if not src_path.exists():
|
||||
return m.group(0) # leave as-is if not found
|
||||
rel = src_path.relative_to(src_dir)
|
||||
# Create a unique asset path based on source directory name
|
||||
asset_name = src_dir.name.replace("/", "-")
|
||||
dest_path = dest_assets / asset_name / rel
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
|
||||
return m.group(0).replace(url, new_rel)
|
||||
|
||||
return IMG_RE.sub(repl, md)
|
||||
|
||||
|
||||
def rewrite_readme_links(
|
||||
md: str,
|
||||
src_dir: Path,
|
||||
examples_dir: Path,
|
||||
parent_index_only: set,
|
||||
current_src_path: str,
|
||||
allowlist_entries: set,
|
||||
current_output_path: str,
|
||||
) -> str:
|
||||
"""
|
||||
Rewrite links between README.md files to point to the correct .qmd files.
|
||||
"""
|
||||
|
||||
def repl(m):
|
||||
text = m.group(1)
|
||||
url = m.group(2).strip()
|
||||
|
||||
# Skip remote URLs and anchor links
|
||||
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
|
||||
return m.group(0)
|
||||
|
||||
# Skip non-markdown files
|
||||
if not url.lower().endswith(".md"):
|
||||
return m.group(0)
|
||||
|
||||
# Resolve the target path
|
||||
try:
|
||||
target_path = (src_dir / url).resolve()
|
||||
|
||||
# Check if target is outside examples_dir
|
||||
try:
|
||||
rel_path = target_path.relative_to(examples_dir)
|
||||
except ValueError:
|
||||
# Target is outside examples_dir, leave as-is
|
||||
return m.group(0)
|
||||
|
||||
parts = list(rel_path.parts)
|
||||
|
||||
# Determine the output path for the target
|
||||
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
|
||||
# This is a README link
|
||||
if len(parts) == 1:
|
||||
# Link to root README -> index.qmd
|
||||
target_output = "index.qmd"
|
||||
elif len(parts) == 2:
|
||||
if parts[0] == ".":
|
||||
# Current directory README
|
||||
target_output = "index.qmd"
|
||||
else:
|
||||
# subdir/README.md
|
||||
parent_dir = parts[0]
|
||||
if parent_dir in parent_index_only:
|
||||
target_output = f"{parent_dir}/index.qmd"
|
||||
else:
|
||||
target_output = f"{parent_dir}.qmd"
|
||||
else:
|
||||
# Deeper nesting: parent/subdir/README.md
|
||||
# Build the full path like "parent/subdir"
|
||||
full_path = "/".join(parts[:-1]) # Remove README.md
|
||||
# Check if this exact path is in allowlist
|
||||
if full_path in allowlist_entries:
|
||||
# This is a sub-entry with its own entry -> use .qmd
|
||||
target_output = f"{full_path}.qmd"
|
||||
elif parts[0] == ".":
|
||||
# ./subdir/README.md -> check if subdir has own entry
|
||||
subdir = parts[1]
|
||||
if subdir in parent_index_only:
|
||||
target_output = f"{subdir}/index.qmd"
|
||||
else:
|
||||
target_output = f"{subdir}.qmd"
|
||||
else:
|
||||
# parent/subdir where parent doesn't have own entry
|
||||
target_output = f"{full_path}/index.qmd"
|
||||
else:
|
||||
# Regular .md file -> convert to .qmd, keep path structure
|
||||
target_output = "/".join(parts)[:-2] + "qmd"
|
||||
|
||||
# Compute relative path from current output file to target
|
||||
current_parts = current_output_path.split("/")
|
||||
target_parts = target_output.split("/")
|
||||
|
||||
# Special case: if current is a subdir file and target is a single-component file at root
|
||||
# Example: current="magistral/vision", target="magistral.qmd"
|
||||
if len(current_parts) > 1 and len(target_parts) == 1:
|
||||
# Current is in subdir, target is at root level
|
||||
# Go up to root: ../ for each level
|
||||
up_count = len(current_parts) - 1
|
||||
rel_parts = [".."] * up_count + [target_parts[0]]
|
||||
new_url = "/".join(rel_parts)
|
||||
else:
|
||||
# Find common prefix
|
||||
i = 0
|
||||
while (
|
||||
i < min(len(current_parts) - 1, len(target_parts))
|
||||
and current_parts[i] == target_parts[i]
|
||||
):
|
||||
i += 1
|
||||
|
||||
# Build relative path: go up (../) then down to target
|
||||
up_count = len(current_parts) - 1 - i
|
||||
rel_parts = [".."] * up_count + target_parts[i:]
|
||||
|
||||
if not rel_parts or rel_parts == [".."]:
|
||||
# Points to same directory or parent
|
||||
new_url = "/".join(rel_parts) if rel_parts else "."
|
||||
else:
|
||||
new_url = "/".join(rel_parts)
|
||||
|
||||
return f"[{text}]({new_url})"
|
||||
except (ValueError, IndexError):
|
||||
return m.group(0)
|
||||
|
||||
return LINK_RE.sub(repl, md)
|
||||
|
||||
|
||||
def write_qmd(out_path: Path, title: str, body_md: str):
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||
out_path.write_text(fm + body_md, encoding="utf-8")
|
||||
|
||||
|
||||
def update_quarto_yml(generated: list[tuple[str, str, str]]):
|
||||
"""
|
||||
Update _quarto.yml with the generated example files in the correct order.
|
||||
This keeps the sidebar in sync with the allowlist.
|
||||
|
||||
Model Guides is now nested under "Getting Started" section.
|
||||
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
|
||||
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
|
||||
"""
|
||||
quarto_yml = ROOT / "_quarto.yml"
|
||||
if not quarto_yml.exists():
|
||||
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
|
||||
return
|
||||
|
||||
content = quarto_yml.read_text(encoding="utf-8")
|
||||
|
||||
# First pass: find all parents that have sub-entries
|
||||
parents_with_subs = set()
|
||||
for path, _name, _title in generated:
|
||||
if "/" in path:
|
||||
parent = path.split("/")[0]
|
||||
parents_with_subs.add(parent)
|
||||
|
||||
# Build the YAML contents while preserving allowlist order
|
||||
lines = []
|
||||
processed_sections = set()
|
||||
|
||||
for path, _name, title in generated:
|
||||
# Check if this is a parent page that has sub-pages
|
||||
if path in parents_with_subs:
|
||||
# This is a parent page with sub-pages - create a nested section
|
||||
if path not in processed_sections:
|
||||
processed_sections.add(path)
|
||||
section_title = (
|
||||
title or path.replace("-", " ").replace("_", " ").title()
|
||||
)
|
||||
lines.append(f' - section: "{section_title}"')
|
||||
lines.append(" contents:")
|
||||
# Add the parent page first
|
||||
lines.append(f" - docs/models/{path}.qmd")
|
||||
# Then add all sub-pages
|
||||
for sub_path, _sub_name, _sub_title in generated:
|
||||
if "/" in sub_path and sub_path.split("/")[0] == path:
|
||||
lines.append(
|
||||
f" - docs/models/{sub_path}.qmd"
|
||||
)
|
||||
elif "/" not in path:
|
||||
# This is a flat item with no sub-pages
|
||||
# Skip if it was already included as part of a parent section
|
||||
if path not in processed_sections:
|
||||
lines.append(f" - docs/models/{path}.qmd")
|
||||
|
||||
yaml_content = "\n".join(lines) + "\n"
|
||||
|
||||
# Pattern to match only the Model Guides contents, stopping at the next item
|
||||
# in Getting Started (lines starting with 12 spaces: same level as the section)
|
||||
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
|
||||
|
||||
def replacement(match):
|
||||
prefix = match.group(1)
|
||||
return prefix + "\n" + yaml_content
|
||||
|
||||
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
||||
|
||||
if new_content != content:
|
||||
quarto_yml.write_text(new_content, encoding="utf-8")
|
||||
print(f"Updated {quarto_yml}")
|
||||
else:
|
||||
print(f"No changes needed for {quarto_yml}")
|
||||
|
||||
|
||||
def main():
|
||||
allow = read_allowlist()
|
||||
if not EXAMPLES_DIR.exists():
|
||||
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
|
||||
return
|
||||
|
||||
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# First pass: identify which parents have their own entry vs only sub-entries
|
||||
parent_entries = set() # Parents that have their own entry
|
||||
parent_with_subs = set() # Parents that have sub-entries
|
||||
allowlist_entries = set() # All entries in allowlist
|
||||
|
||||
for item in allow:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
else:
|
||||
name = item.get("name")
|
||||
|
||||
allowlist_entries.add(name)
|
||||
|
||||
if "/" in name:
|
||||
parent = name.split("/")[0]
|
||||
parent_with_subs.add(parent)
|
||||
else:
|
||||
parent_entries.add(name)
|
||||
|
||||
# Parents with subs that DON'T have their own entry -> use index.qmd
|
||||
parent_index_only = parent_with_subs - parent_entries
|
||||
|
||||
generated = []
|
||||
seen_dirs = set() # Track which parent directories we've created index for
|
||||
|
||||
for item in allow:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
title = None
|
||||
else:
|
||||
name = item.get("name")
|
||||
title = item.get("title")
|
||||
|
||||
if not name:
|
||||
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
src_dir = EXAMPLES_DIR / name
|
||||
if not src_dir.exists() or not src_dir.is_dir():
|
||||
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
|
||||
continue
|
||||
|
||||
readme = find_readme(src_dir)
|
||||
if not readme:
|
||||
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
|
||||
continue
|
||||
|
||||
md = readme.read_text(encoding="utf-8")
|
||||
|
||||
# Determine output path first (needed for link rewriting)
|
||||
parts = name.split("/")
|
||||
if len(parts) == 1:
|
||||
# Simple case: no subdirectory
|
||||
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
|
||||
sidebar_path = parts[0]
|
||||
else:
|
||||
# Has subdirectory: e.g., magistral/think
|
||||
parent = parts[0]
|
||||
child = "-".join(parts[1:]) # handle nested subdirs
|
||||
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
|
||||
sidebar_path = f"{parent}/{child}"
|
||||
|
||||
# Remove the first H1 (we use frontmatter title instead)
|
||||
md, _ = remove_first_h1(md)
|
||||
# Rewrite links between README files
|
||||
md = rewrite_readme_links(
|
||||
md,
|
||||
src_dir,
|
||||
EXAMPLES_DIR,
|
||||
parent_index_only,
|
||||
name,
|
||||
allowlist_entries,
|
||||
sidebar_path,
|
||||
)
|
||||
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
|
||||
|
||||
# Handle parent page generation for sub-entries
|
||||
if len(parts) > 1:
|
||||
# Has subdirectory: e.g., magistral/think
|
||||
parent = parts[0]
|
||||
|
||||
# Create parent.qmd if not already done and parent doesn't have own entry
|
||||
if parent not in seen_dirs and parent in parent_index_only:
|
||||
parent_readme = find_readme(EXAMPLES_DIR / parent)
|
||||
if parent_readme:
|
||||
parent_md = parent_readme.read_text(encoding="utf-8")
|
||||
parent_md, _ = remove_first_h1(parent_md)
|
||||
parent_md = rewrite_readme_links(
|
||||
parent_md,
|
||||
EXAMPLES_DIR / parent,
|
||||
EXAMPLES_DIR,
|
||||
parent_index_only,
|
||||
parent,
|
||||
allowlist_entries,
|
||||
parent,
|
||||
)
|
||||
parent_md = rewrite_and_copy_assets(
|
||||
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
|
||||
)
|
||||
parent_title = parent.replace("-", " ").replace("_", " ").title()
|
||||
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
|
||||
generated.append((parent, parent, parent_title))
|
||||
seen_dirs.add(parent)
|
||||
|
||||
if not title:
|
||||
title = name.replace("/", " ").replace("-", " ").title()
|
||||
|
||||
write_qmd(out_path, title, md)
|
||||
generated.append((sidebar_path, name, title))
|
||||
|
||||
# Index page - preserve allowlist order
|
||||
if generated:
|
||||
listing = "\n".join(
|
||||
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
|
||||
)
|
||||
index_md = (
|
||||
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
|
||||
+ listing
|
||||
+ "\n"
|
||||
)
|
||||
index_fm = (
|
||||
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||
)
|
||||
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
|
||||
|
||||
# Auto-update _quarto.yml to keep sidebar in sync
|
||||
update_quarto_yml(generated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
scaling_softmax: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
77
examples/eaft/eaft-example.yml
Normal file
77
examples/eaft/eaft-example.yml
Normal file
@@ -0,0 +1,77 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
eot_tokens:
|
||||
- <end_of_turn>
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/eaft-gemma-3-1b
|
||||
|
||||
use_eaft: true
|
||||
eaft_alpha: 1.0
|
||||
eaft_k: 20
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
eval_batch_size: 1
|
||||
max_steps: 1000
|
||||
evaluation_strategy: "no"
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
debug:
|
||||
deepspeed:
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,7 +1,5 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -26,10 +24,15 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -26,10 +24,15 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
|
||||
# Need to set else transformers tries to load vision too
|
||||
model_type: Gemma3ForCausalLM
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
@@ -23,6 +20,11 @@ dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
@@ -32,8 +34,8 @@ sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -31,7 +31,7 @@ pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
72
examples/glm45/README.md
Normal file
72
examples/glm45/README.md
Normal file
@@ -0,0 +1,72 @@
|
||||
# Finetune Z.ai's GLM-4.5-Air with Axolotl
|
||||
|
||||
[GLM-4.5-Air](https://huggingface.co/zai-org/GLM-4.5-Air) is a MoE model by Z.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# QLoRA (1x80GB @ ~63.4GiB/GPU)
|
||||
axolotl train examples/glm45/glm-45-air-qlora.yaml
|
||||
```
|
||||
|
||||
### Dataset
|
||||
|
||||
In addition to the standard OpenAI Messages format, GLM-4.5 supports an extra parameter for thinking in the assistant section.
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "...", // or have </think>...</think> in `content`
|
||||
"content": "..."
|
||||
}
|
||||
```
|
||||
|
||||
Make sure you set the below extra attributes if needed:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
# tool_calls: tool_calls # uncomment if using tools
|
||||
# reasoning_content: reasoning_content # uncomment if have reasoning
|
||||
|
||||
# Uncomment if training on tool role (you would rarely if ever need this)
|
||||
# eot_tokens:
|
||||
# - <|observation|>
|
||||
```
|
||||
|
||||
### Tips
|
||||
|
||||
- The role name for tools in this template is `tool`.
|
||||
- You will see this Axolotl WARNING — this is expected as the template does not use EOS:
|
||||
```
|
||||
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
|
||||
```
|
||||
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config.
|
||||
- **LoRA kernels**: Incompatible with this model. Must be explicitly disabled (`lora_*_kernel: false`).
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [GLM-4.5-Air on HuggingFace](https://huggingface.co/zai-org/GLM-4.5-Air)
|
||||
- [GLM-4.5 Blog](https://z.ai/blog/glm-4.5)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
64
examples/glm45/glm-45-air-qlora.yaml
Normal file
64
examples/glm45/glm-45-air-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: zai-org/GLM-4.5-Air
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
quantize_moe_experts: true # important
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
44
examples/glm46v/README.md
Normal file
44
examples/glm46v/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Finetune GLM-4.6V with Axolotl
|
||||
|
||||
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
|
||||
|
||||
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
|
||||
3. Run the fine-tuning:
|
||||
|
||||
glm-4-6v-flash(9B)
|
||||
```bash
|
||||
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
## Tips
|
||||
|
||||
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
|
||||
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Supported Models
|
||||
|
||||
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
|
||||
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
53
examples/glm46v/glm-4-6v-flash-ddp.yaml
Normal file
53
examples/glm46v/glm-4-6v-flash-ddp.yaml
Normal file
@@ -0,0 +1,53 @@
|
||||
base_model: zai-org/GLM-4.6V-Flash
|
||||
trust_remote_code: true
|
||||
|
||||
processor_type: AutoProcessor
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
output_dir: ./outputs/glm-4-6v-flash-qlora
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
50
examples/glm46v/glm-4-6v-flash-qlora.yaml
Normal file
50
examples/glm46v/glm-4-6v-flash-qlora.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
base_model: zai-org/GLM-4.6V-Flash
|
||||
trust_remote_code: true
|
||||
|
||||
processor_type: AutoProcessor
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
output_dir: ./outputs/glm-4-6v-flash-qlora
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
65
examples/glm47-flash/README.md
Normal file
65
examples/glm47-flash/README.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
|
||||
|
||||
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# QLoRA
|
||||
# - no target experts (1x48GB @ ~24GiB/GPU)
|
||||
# - target experts (1x48GB @ ~34GiB/GPU)
|
||||
axolotl train examples/glm47-flash/qlora.yaml
|
||||
|
||||
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
|
||||
axolotl train examples/glm47-flash/qlora_fsdp.yaml
|
||||
```
|
||||
|
||||
```bash
|
||||
# LoRA
|
||||
# - no target experts (1x48GB @ ~35GiB/GPU)
|
||||
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
|
||||
axolotl train examples/glm47-flash/lora.yaml
|
||||
|
||||
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
|
||||
axolotl train examples/glm47-flash/lora_fsdp.yaml
|
||||
```
|
||||
|
||||
### MoE Expert Quantization & Expert LoRA
|
||||
|
||||
This model quantize expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
|
||||
|
||||
## Limitations
|
||||
|
||||
- **lora_target_linear**: Incompatible for this model.
|
||||
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
|
||||
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Z.ai team recommends these default settings (most tasks):
|
||||
- `temperature: 1.0`
|
||||
- `top_p: 0.95`
|
||||
- `max_new_tokens: 131072`
|
||||
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
|
||||
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
65
examples/glm47-flash/lora.yaml
Normal file
65
examples/glm47-flash/lora.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
75
examples/glm47-flash/lora_fsdp.yaml
Normal file
75
examples/glm47-flash/lora_fsdp.yaml
Normal file
@@ -0,0 +1,75 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
65
examples/glm47-flash/qlora.yaml
Normal file
65
examples/glm47-flash/qlora.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
75
examples/glm47-flash/qlora_fsdp.yaml
Normal file
75
examples/glm47-flash/qlora_fsdp.yaml
Normal file
@@ -0,0 +1,75 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
@@ -32,6 +32,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,6 +28,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -29,6 +29,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,6 +28,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,6 +41,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,6 +41,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
43
examples/internvl3_5/README.md
Normal file
43
examples/internvl3_5/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Finetune OpenGV's InternVL with Axolotl
|
||||
|
||||
[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGV. It features a ViT-style vision encoder and strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install `timm` for vision model support:
|
||||
|
||||
```bash
|
||||
pip install timm==1.0.19
|
||||
```
|
||||
|
||||
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [InternVL Paper](https://huggingface.co/papers/2508.18265)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
61
examples/internvl3_5/internvl3_5-8b-qlora.yml
Normal file
61
examples/internvl3_5/internvl3_5-8b-qlora.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: OpenGVLab/InternVL3_5-8B-HF
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -19,7 +19,6 @@ datasets:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: jamba-large-fsdp-qlora-ft
|
||||
save_safetensors: true
|
||||
adapter: qlora
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
47
examples/kimi-linear/README.md
Normal file
47
examples/kimi-linear/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# Finetune MoonshotAI's Kimi Linear with Axolotl
|
||||
|
||||
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install CCE via [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
|
||||
```
|
||||
|
||||
This config uses about 98.7GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning!
|
||||
|
||||
### TIPS
|
||||
|
||||
- Kimi Linear requires `trust_remote_code: true`.
|
||||
- You can run a full finetuning by removing the `adapter: lora` and `load_in_8bit: true`.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
This is not yet compatible with MoE kernels from transformers v5.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
|
||||
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
81
examples/kimi-linear/kimi-48b-lora.yaml
Normal file
81
examples/kimi-linear/kimi-48b-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
split: train
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.2
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -29,7 +29,6 @@ flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
save_strategy: no
|
||||
torch_compile: true
|
||||
|
||||
wandb_project:
|
||||
|
||||
65
examples/llama-3/3b-qat-mxfp4.yaml
Normal file
65
examples/llama-3/3b-qat-mxfp4.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: meta-llama/Llama-3.2-3B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
datasets:
|
||||
- path: yahma/alpaca-cleaned
|
||||
type: alpaca
|
||||
split: train[:95%]
|
||||
|
||||
output_dir: ./outputs/qat_out/
|
||||
dataset_prepared_path: ./outputs/dataset_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: mxfp4
|
||||
weight_dtype: mxfp4
|
||||
group_size: 32
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
|
||||
cosine_constant_lr_ratio: 0
|
||||
cosine_min_lr_ratio: 1.0
|
||||
learning_rate: 2e-5
|
||||
save_only_model: true
|
||||
bf16: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
68
examples/llama-3/qlora-1b-gdpo.yaml
Normal file
68
examples/llama-3/qlora-1b-gdpo.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: meta-llama/Llama-3.2-1B-Instruct
|
||||
|
||||
chat_template: llama3
|
||||
|
||||
rl: gdpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 128
|
||||
num_generations: 2
|
||||
temperature: 0.7
|
||||
top_p: 0.95
|
||||
|
||||
use_vllm: false
|
||||
|
||||
|
||||
multi_objective_aggregation: normalize_then_sum
|
||||
|
||||
reward_funcs:
|
||||
- rwd.format_reward
|
||||
- rwd.correctness_reward
|
||||
reward_weights: [1.0, 2.0]
|
||||
|
||||
log_completions: true
|
||||
num_completions_to_print: 3
|
||||
scale_rewards: true
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
split: train[:1000]
|
||||
type: rwd.gsm8k_transform
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/llama3-gdpo-out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
max_steps: 100
|
||||
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 10
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
flash_attention: true
|
||||
logging_steps: 1
|
||||
save_steps: 50
|
||||
save_safetensors: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
|
||||
seed: 42
|
||||
@@ -12,7 +12,6 @@ datasets:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out/qlora-llama3_1-405b
|
||||
save_safetensors: true
|
||||
|
||||
adapter: qlora
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mist
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -5,7 +5,8 @@ This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mist
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
- Installed Axolotl from source (see [main README](../README.md))
|
||||
|
||||
## Getting started
|
||||
|
||||
|
||||
@@ -47,6 +47,5 @@ saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
tokens:
|
||||
save_safetensors: False
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
39
examples/mimo/README.md
Normal file
39
examples/mimo/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Finetune Xiaomi's MiMo with Axolotl
|
||||
|
||||
[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/mimo/mimo-7b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MiMo Paper](https://arxiv.org/abs/2505.07608)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
67
examples/mimo/mimo-7b-qlora.yaml
Normal file
67
examples/mimo/mimo-7b-qlora.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: XiaomiMiMo/MiMo-7B-RL
|
||||
trust_remote_code: true
|
||||
revision_of_model: 6299b5a
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# CCE - N/A as of now
|
||||
# plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -59,6 +59,7 @@ gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
scaling_softmax: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -5,7 +5,8 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
- Installed Axolotl from source (see [main README](../README.md))
|
||||
|
||||
## Getting started
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||
|
||||
## Getting Started
|
||||
85
examples/mistral4/README.md
Normal file
85
examples/mistral4/README.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# Finetune Mistral Small 4 with Axolotl
|
||||
|
||||
Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
|
||||
## Getting started
|
||||
|
||||
Note: Training this model requires weights in BF16 which we will link to later.
|
||||
Users interested in training can convert / descale the existing FP8 weights.
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
3. Install transformers from main
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
4. Run one of the example configs:
|
||||
|
||||
```bash
|
||||
# text-only
|
||||
axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB
|
||||
axolotl train examples/mistral4/fft-text.yml
|
||||
|
||||
# text + vision
|
||||
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB
|
||||
axolotl train examples/mistral4/fft-vision.yml
|
||||
```
|
||||
|
||||
Note: FFT configs provided as reference. Please adjust hyperparameters as needed.
|
||||
|
||||
## Reasoning Effort
|
||||
|
||||
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
|
||||
|
||||
- `"none"` — instruct mode (default)
|
||||
- `"high"` — reasoning mode with explicit thinking steps
|
||||
|
||||
Pass it via `chat_template_kwargs` under your dataset config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: your/dataset
|
||||
type: chat_template
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
```
|
||||
|
||||
## Thinking Support
|
||||
|
||||
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
|
||||
|
||||
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: your/thinking-dataset
|
||||
type: chat_template
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
thinking: thinking
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
```
|
||||
|
||||
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
|
||||
|
||||
## Tips
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
58
examples/mistral4/fft-text.yml
Normal file
58
examples/mistral4/fft-text.yml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.kernels.KernelsPlugin
|
||||
use_kernels: true
|
||||
use_sonicmoe: true
|
||||
|
||||
# only train language model layers, freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- model.language_model.*
|
||||
- lm_head
|
||||
- embed_tokens
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user