Compare commits
74 Commits
upgrade-to
...
uv-fixup
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4a3b618e7 | ||
|
|
b6b8db805a | ||
|
|
653f90be25 | ||
|
|
945c8aeb10 | ||
|
|
e672d37f33 | ||
|
|
77828d3559 | ||
|
|
4272817109 | ||
|
|
474208b794 | ||
|
|
444020b332 | ||
|
|
aa88c2e30b | ||
|
|
f447bce1db | ||
|
|
7f23b302d1 | ||
|
|
18f26c19ef | ||
|
|
2b6f4a6c9b | ||
|
|
8f54b4eb25 | ||
|
|
a131e4d0e5 | ||
|
|
1791d87b6f | ||
|
|
b40803da51 | ||
|
|
68f1b7004c | ||
|
|
08441fed17 | ||
|
|
86ca1e27c0 | ||
|
|
5ed455715e | ||
|
|
3f30572d4a | ||
|
|
43d60c7439 | ||
|
|
0ea252d392 | ||
|
|
29722dec60 | ||
|
|
7fbedbd300 | ||
|
|
145ffc9be1 | ||
|
|
4f1b5ad29f | ||
|
|
d6a2532dd7 | ||
|
|
5eb265513c | ||
|
|
06ac407b92 | ||
|
|
4e22cf0651 | ||
|
|
a4ee56c315 | ||
|
|
c67cbcb0f5 | ||
|
|
a2da852576 | ||
|
|
37e9da7a53 | ||
|
|
ed7105dba7 | ||
|
|
b6d3653f74 | ||
|
|
fcc4cfdb63 | ||
|
|
97a4f28511 | ||
|
|
86a5803212 | ||
|
|
530a0c0bf0 | ||
|
|
0343a72cc9 | ||
|
|
236dad3bb7 | ||
|
|
be00978bc2 | ||
|
|
3738978394 | ||
|
|
6132a30cda | ||
|
|
3dd86d35b8 | ||
|
|
dd9ebaeba1 | ||
|
|
fc4e37920b | ||
|
|
a531e9d946 | ||
|
|
04328aeb97 | ||
|
|
d0d26d5064 | ||
|
|
8623dd8a72 | ||
|
|
8cd75cff9f | ||
|
|
8ab9d9ea88 | ||
|
|
6e42def14b | ||
|
|
c413480b35 | ||
|
|
8f25124269 | ||
|
|
790df757cb | ||
|
|
d282f32481 | ||
|
|
6331e4a130 | ||
|
|
1410e4474e | ||
|
|
dc77b5bf42 | ||
|
|
359b7ad85e | ||
|
|
258ce8d4fa | ||
|
|
3e0bbd33ec | ||
|
|
4ae6f766ad | ||
|
|
e7f0d4ba5b | ||
|
|
7bf6f70e96 | ||
|
|
8aab807e67 | ||
|
|
ee59e4de97 | ||
|
|
4e61b8aa23 |
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -15,6 +15,11 @@
|
||||
<!--- Include details of your testing environment, tests ran to see how -->
|
||||
<!--- your change affects other areas of the code, etc. -->
|
||||
|
||||
## AI Usage Disclaimer
|
||||
|
||||
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
|
||||
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
|
||||
|
||||
## Screenshots (if appropriate)
|
||||
|
||||
## Types of changes
|
||||
|
||||
96
.github/workflows/base.yml
vendored
96
.github/workflows/base.yml
vendored
@@ -21,6 +21,8 @@ jobs:
|
||||
timeout-minutes: 480
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -32,6 +34,7 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -39,6 +42,7 @@ jobs:
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -46,6 +50,31 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -53,6 +82,23 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
@@ -79,6 +125,7 @@ jobs:
|
||||
axolotlai/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -89,6 +136,7 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
@@ -103,6 +151,8 @@ jobs:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -114,6 +164,7 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -121,6 +172,7 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -128,6 +180,31 @@ jobs:
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-uv-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -135,6 +212,23 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -146,6 +240,7 @@ jobs:
|
||||
axolotlai/axolotl-base-uv
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -156,6 +251,7 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
209
.github/workflows/main.yml
vendored
209
.github/workflows/main.yml
vendored
@@ -20,22 +20,44 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -61,6 +83,7 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
@@ -75,6 +98,77 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-uv:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||
- name: Build and export to Docker
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
|
||||
file: ./docker/Dockerfile-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
@@ -87,22 +181,44 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -127,6 +243,7 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
@@ -137,6 +254,73 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-uv:
|
||||
needs: build-axolotl-uv
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-cloud-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-no-tmux:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
@@ -147,11 +331,11 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
is_latest: true
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
@@ -180,6 +364,7 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
|
||||
20
.github/workflows/multi-gpu-e2e.yml
vendored
20
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -35,14 +35,26 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras: fbgemm-gpu
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
@@ -64,8 +76,8 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run -m cicd.multigpu
|
||||
|
||||
6
.github/workflows/pypi.yml
vendored
6
.github/workflows/pypi.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging==23.2
|
||||
pip3 install wheel packaging==26.0
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
@@ -48,9 +48,9 @@ jobs:
|
||||
id: tag
|
||||
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||
|
||||
- name: Update version in setup.py
|
||||
- name: Update version in VERSION file
|
||||
run: |
|
||||
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
|
||||
|
||||
- name: Build a source dist
|
||||
run: |
|
||||
|
||||
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
|
||||
64
.github/workflows/tests.yml
vendored
64
.github/workflows/tests.yml
vendored
@@ -54,8 +54,13 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
# exclude:
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.8.0"
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.9.1"
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -70,7 +75,7 @@ jobs:
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
@@ -82,7 +87,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -110,10 +115,10 @@ jobs:
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
run: hf cache ls
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
@@ -127,7 +132,7 @@ jobs:
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
run: hf cache ls
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
@@ -144,8 +149,13 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
# exclude:
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.8.0"
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.9.1"
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -160,7 +170,7 @@ jobs:
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
@@ -172,7 +182,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -200,7 +210,7 @@ jobs:
|
||||
axolotl --help
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
run: hf cache ls
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
@@ -209,10 +219,10 @@ jobs:
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
run: hf cache ls
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
needs: [pre-commit]
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
skip: ${{ steps.compute.outputs.skip }}
|
||||
@@ -248,16 +258,16 @@ jobs:
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
@@ -316,6 +326,18 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -353,8 +375,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
|
||||
@@ -123,7 +123,7 @@ datasets:
|
||||
| --------------------------------- | -------------------------- | ----------------------------------- |
|
||||
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
|
||||
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
|
||||
| `dataset_processes` | `4` | Number of preprocessing processes |
|
||||
| `dataset_num_proc` | `4` | Number of preprocessing processes |
|
||||
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
|
||||
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
|
||||
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
|
||||
|
||||
@@ -39,7 +39,6 @@
|
||||
# type: # linear | dynamic
|
||||
# factor: # float
|
||||
|
||||
|
||||
# # Whether you are training a 4-bit GPTQ quantized model
|
||||
# gptq: true
|
||||
# gptq_groupsize: 128 # group size
|
||||
@@ -107,7 +106,7 @@
|
||||
# push_dataset_to_hub: # repo path
|
||||
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# # if not set.
|
||||
# dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# dataset_num_proc: # defaults to os.cpu_count() if not set
|
||||
# # push checkpoints to hub
|
||||
# hub_model_id: # repo path to push finetuned model
|
||||
# # how to push checkpoints to hub
|
||||
@@ -224,9 +223,6 @@
|
||||
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
# # Save model as safetensors (require safetensors package)
|
||||
# save_safetensors:
|
||||
|
||||
# # Whether to mask out or include the human's prompt from the training labels
|
||||
# train_on_inputs: false
|
||||
# # Group similarly sized data to minimize padding.
|
||||
@@ -352,8 +348,6 @@
|
||||
# # Allow overwrite yml config using from cli
|
||||
# strict:
|
||||
|
||||
|
||||
|
||||
base_model: ${BASE_MODEL}
|
||||
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
|
||||
base_model_config: ${BASE_MODEL_CONFIG}
|
||||
@@ -412,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
|
||||
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
|
||||
dataset_prepared_path: ${DATASET_PREPARED_PATH}
|
||||
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
|
||||
dataset_processes: ${DATASET_PROCESSES}
|
||||
dataset_num_proc: ${DATASET_NUM_PROC}
|
||||
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
|
||||
hub_model_id: ${HUB_MODEL_ID}
|
||||
hub_strategy: ${HUB_STRATEGY}
|
||||
@@ -512,7 +506,6 @@ profiler_steps: ${PROFILER_STEPS}
|
||||
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
|
||||
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
|
||||
|
||||
save_safetensors: ${SAVE_SAFETENSORS}
|
||||
train_on_inputs: ${TRAIN_ON_INPUTS}
|
||||
group_by_length: ${GROUP_BY_LENGTH}
|
||||
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
|
||||
|
||||
@@ -88,7 +88,7 @@ Features:
|
||||
#### Using pip
|
||||
|
||||
```bash
|
||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
|
||||
@@ -251,7 +251,6 @@ website:
|
||||
- docs/models/olmo3.qmd
|
||||
- docs/models/trinity.qmd
|
||||
- docs/models/arcee.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- section: "Ministral3"
|
||||
contents:
|
||||
- docs/models/ministral3.qmd
|
||||
@@ -266,6 +265,7 @@ website:
|
||||
- docs/models/mistral-small.qmd
|
||||
- docs/models/voxtral.qmd
|
||||
- docs/models/devstral.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- docs/models/llama-4.qmd
|
||||
- docs/models/llama-2.qmd
|
||||
- docs/models/qwen3-next.qmd
|
||||
@@ -320,6 +320,7 @@ website:
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/optimizers.qmd
|
||||
- docs/attention.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
|
||||
@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN uv pip install packaging==26.0 setuptools==75.8.0
|
||||
RUN uv pip install torchvision
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
|
||||
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
|
||||
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -17,7 +17,8 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
||||
df_template = template_env.get_template(dockerfile)
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
@@ -27,8 +28,11 @@ df_args = {
|
||||
"CUDA": os.environ.get("CUDA", "126"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
||||
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
set -e
|
||||
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v --durations=10 -n2 --maxfail=4 \
|
||||
pytest -v --durations=10 -n2 --maxfail=3 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
|
||||
@@ -6,6 +6,7 @@ ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
@@ -20,13 +21,17 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py | sh && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \ python scripts/unsloth_install.py | sh && \
|
||||
python scripts/cutcrossentropy_install.py | sh && \
|
||||
pip install pytest && \
|
||||
pip cache purge
|
||||
|
||||
@@ -2,14 +2,16 @@ ARG CUDA_VERSION="11.8.0"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG CUDA="118"
|
||||
ARG CUDA="128"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
@@ -22,11 +24,17 @@ RUN apt-get update \
|
||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||
&& rm -rf /var/cache/apt/archives \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
MINICONDA_ARCH="x86_64"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
MINICONDA_ARCH="aarch64"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi \
|
||||
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
@@ -35,7 +43,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip cache purge
|
||||
|
||||
@@ -51,8 +59,18 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="x86_64" ;; \
|
||||
arm64) ARCH_TAG="aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
|
||||
@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
|
||||
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||
|
||||
30
docker/Dockerfile-cloud-uv
Normal file
30
docker/Dockerfile-cloud-uv
Normal file
@@ -0,0 +1,30 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl-uv:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
EXPOSE 8888
|
||||
EXPOSE 22
|
||||
|
||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/motd /etc/motd
|
||||
|
||||
RUN uv pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||
chmod +x /root/cloud-entrypoint.sh && \
|
||||
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
|
||||
|
||||
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
||||
CMD ["sleep", "infinity"]
|
||||
47
docker/Dockerfile-uv
Normal file
47
docker/Dockerfile-uv
Normal file
@@ -0,0 +1,47 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base-uv:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
fi && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py --uv | sh && \
|
||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||
uv pip install pytest && \
|
||||
uv cache clean
|
||||
|
||||
# fix so that git fetch/pull from remote works with shallow clone
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch && \
|
||||
git config --global credential.helper store
|
||||
|
||||
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
|
||||
RUN chmod +x /root/.axolotl-complete.bash && \
|
||||
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
|
||||
@@ -2,9 +2,11 @@ ARG CUDA_VERSION="12.6.3"
|
||||
ARG CUDNN_VERSION=""
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.6.0"
|
||||
ARG CUDA="126"
|
||||
@@ -31,12 +33,25 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
|
||||
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||
fi
|
||||
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="x86_64" ;; \
|
||||
arm64) ARCH_TAG="aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
|
||||
@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
|
||||
Download a base model using the Hugging Face CLI:
|
||||
|
||||
```bash
|
||||
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
|
||||
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
|
||||
```
|
||||
|
||||
### 10. Create Axolotl Configuration
|
||||
|
||||
140
docs/attention.qmd
Normal file
140
docs/attention.qmd
Normal file
@@ -0,0 +1,140 @@
|
||||
---
|
||||
title: Attention
|
||||
description: Supported attention modules in Axolotl
|
||||
---
|
||||
|
||||
## SDP Attention
|
||||
|
||||
This is the default built-in attention in PyTorch.
|
||||
|
||||
```yaml
|
||||
sdp_attention: true
|
||||
```
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
## Flash Attention 2
|
||||
|
||||
Uses efficient kernels to compute attention.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
```
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
|
||||
### Nvidia
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
Note: For Turing GPUs or lower, please use other attention methods.
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
#### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/hopper
|
||||
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
|
||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
## Flex Attention
|
||||
|
||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
||||
|
||||
```yaml
|
||||
flex_attention: true
|
||||
|
||||
# recommended
|
||||
torch_compile: true
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We recommend using latest stable version of PyTorch for best performance.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
||||
|
||||
## SageAttention
|
||||
|
||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
||||
|
||||
```yaml
|
||||
sage_attention: true
|
||||
```
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
```bash
|
||||
pip install sageattention==2.2.0 --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
||||
|
||||
:::
|
||||
|
||||
|
||||
## xFormers
|
||||
|
||||
```yaml
|
||||
xformers_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
We recommend using with Turing GPUs or below (such as on Colab).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
||||
|
||||
## Shifted Sparse Attention
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
||||
|
||||
:::
|
||||
|
||||
Requirements: LLaMA model architecture
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
s2_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
No sample packing support!
|
||||
|
||||
:::
|
||||
@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
lm_eval_model: # model to evaluate (local or hf path)
|
||||
|
||||
# List of tasks to evaluate
|
||||
lm_eval_tasks:
|
||||
- arc_challenge
|
||||
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
|
||||
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
|
||||
|
||||
### delinearize-llama4
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
```
|
||||
4. (Optional) Login to Hugging Face:
|
||||
```{.bash}
|
||||
huggingface-cli login
|
||||
hf auth login
|
||||
```
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
@@ -89,6 +89,10 @@ lora_o_kernel: true
|
||||
Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
LoRA kernels do not support remote modeling code.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
|
||||
@@ -19,6 +19,7 @@ format:
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [GLM-4.6V](#sec-glm-4-6v)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
- [Intern-VL](#sec-intern-vl)
|
||||
@@ -183,6 +184,18 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### GLM-4.6V {#sec-glm-4-6v}
|
||||
|
||||
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
|
||||
|
||||
```yaml
|
||||
# GLM-4.6V (106B MoE version)
|
||||
base_model: zai-org/GLM-4.6V
|
||||
|
||||
# OR GLM-4.6V-Flash (9B version)
|
||||
base_model: zai-org/GLM-4.6V-Flash
|
||||
```
|
||||
|
||||
### SmolVLM2 {#sec-smolvlm2}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
||||
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
|
||||
|
||||
|
||||
## RLHF using Axolotl
|
||||
@@ -720,6 +721,102 @@ trl:
|
||||
|
||||
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
|
||||
|
||||
### GDPO
|
||||
|
||||
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
|
||||
|
||||
::: {.callout-tip}
|
||||
Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results.
|
||||
:::
|
||||
|
||||
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
|
||||
|
||||
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
tensor_parallel_size: 2
|
||||
gpu_memory_utilization: 0.85
|
||||
|
||||
rl: gdpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 256
|
||||
use_vllm: true
|
||||
num_generations: 4
|
||||
reward_funcs:
|
||||
- rewards.format_reward
|
||||
- rewards.correctness_reward
|
||||
reward_weights: [1.0, 2.0]
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
type: rewards.oai_gsm8k_transform
|
||||
```
|
||||
|
||||
You can also use GRPO with explicit aggregation control:
|
||||
|
||||
```yaml
|
||||
rl: grpo
|
||||
trl:
|
||||
multi_objective_aggregation: normalize_then_sum # GDPO behavior
|
||||
# or: sum_then_normalize # Default GRPO behavior
|
||||
```
|
||||
|
||||
#### GDPO vs GRPO
|
||||
|
||||
| Aspect | GRPO | GDPO |
|
||||
|--------|------|------|
|
||||
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
|
||||
| **Multi-reward** | May collapse advantages | Preserves reward signals |
|
||||
| **Single reward** | Standard behavior | Equivalent to GRPO |
|
||||
|
||||
#### Why GDPO?
|
||||
|
||||
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
|
||||
|
||||
```
|
||||
# Example: format + correctness rewards
|
||||
[format=0, correct=3] → sum=3
|
||||
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
|
||||
[format=2, correct=1] → sum=3
|
||||
[format=3, correct=0] → sum=3
|
||||
```
|
||||
|
||||
GDPO normalizes each reward independently, preserving their relative differences.
|
||||
|
||||
#### Reward Functions
|
||||
|
||||
GDPO uses the same reward function format as GRPO:
|
||||
|
||||
```python
|
||||
# rewards.py
|
||||
def format_reward(completions, **kwargs) -> list[float]:
|
||||
return [1.0 if len(c) > 10 else 0.0 for c in completions]
|
||||
|
||||
def correctness_reward(completions, answers, **kwargs) -> list[float]:
|
||||
rewards = []
|
||||
for completion, answer in zip(completions, answers):
|
||||
# Your scoring logic here
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
#### Sequence Parallelism
|
||||
|
||||
GDPO supports sequence parallelism for long-context training:
|
||||
|
||||
```yaml
|
||||
rl: gdpo
|
||||
context_parallel_size: 2
|
||||
```
|
||||
|
||||
### SimPO
|
||||
|
||||
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
||||
|
||||
@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
scaling_softmax: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
77
examples/eaft/eaft-example.yml
Normal file
77
examples/eaft/eaft-example.yml
Normal file
@@ -0,0 +1,77 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
eot_tokens:
|
||||
- <end_of_turn>
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/eaft-gemma-3-1b
|
||||
|
||||
use_eaft: true
|
||||
eaft_alpha: 1.0
|
||||
eaft_k: 20
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
eval_batch_size: 1
|
||||
max_steps: 1000
|
||||
evaluation_strategy: "no"
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
debug:
|
||||
deepspeed:
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,6 +1,7 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
|
||||
|
||||
# Need to set else transformers tries to load vision too
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -32,8 +33,8 @@ sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -31,7 +31,7 @@ pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
77
examples/glm4.7-flash/README.md
Normal file
77
examples/glm4.7-flash/README.md
Normal file
@@ -0,0 +1,77 @@
|
||||
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
|
||||
|
||||
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# QLoRA
|
||||
# - no target experts (1x48GB @ ~24GiB/GPU)
|
||||
# - target experts (1x48GB @ ~34GiB/GPU)
|
||||
axolotl train examples/glm4.7-flash/qlora.yaml
|
||||
|
||||
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
|
||||
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
|
||||
```
|
||||
|
||||
```bash
|
||||
# LoRA
|
||||
# - no target experts (1x48GB @ ~35GiB/GPU)
|
||||
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
|
||||
axolotl train examples/glm4.7-flash/lora.yaml
|
||||
|
||||
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
|
||||
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
|
||||
```
|
||||
|
||||
### Expert LoRA
|
||||
|
||||
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
|
||||
|
||||
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
|
||||
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- mlp.experts.gate_up_proj
|
||||
- mlp.experts.down_proj
|
||||
# - mlp.gate.weight # router, untested but should work, not normally targeted
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
|
||||
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
|
||||
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2 — causes hang otherwise.
|
||||
- **lora_target_linear**: Incompatible for this model.
|
||||
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
|
||||
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Z.ai team recommends these default settings (most tasks):
|
||||
- `temperature: 1.0`
|
||||
- `top_p: 0.95`
|
||||
- `max_new_tokens: 131072`
|
||||
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
|
||||
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
65
examples/glm4.7-flash/lora.yaml
Normal file
65
examples/glm4.7-flash/lora.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
75
examples/glm4.7-flash/lora_fsdp.yaml
Normal file
75
examples/glm4.7-flash/lora_fsdp.yaml
Normal file
@@ -0,0 +1,75 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
65
examples/glm4.7-flash/qlora.yaml
Normal file
65
examples/glm4.7-flash/qlora.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
75
examples/glm4.7-flash/qlora_fsdp.yaml
Normal file
75
examples/glm4.7-flash/qlora_fsdp.yaml
Normal file
@@ -0,0 +1,75 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
44
examples/glm46v/README.md
Normal file
44
examples/glm46v/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Finetune GLM-4.6V with Axolotl
|
||||
|
||||
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
|
||||
|
||||
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
|
||||
3. Run the fine-tuning:
|
||||
|
||||
glm-4-6v-flash(9B)
|
||||
```bash
|
||||
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
## Tips
|
||||
|
||||
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
|
||||
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Supported Models
|
||||
|
||||
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
|
||||
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
53
examples/glm46v/glm-4-6v-flash-ddp.yaml
Normal file
53
examples/glm46v/glm-4-6v-flash-ddp.yaml
Normal file
@@ -0,0 +1,53 @@
|
||||
base_model: zai-org/GLM-4.6V-Flash
|
||||
trust_remote_code: true
|
||||
|
||||
processor_type: AutoProcessor
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
output_dir: ./outputs/glm-4-6v-flash-qlora
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
50
examples/glm46v/glm-4-6v-flash-qlora.yaml
Normal file
50
examples/glm46v/glm-4-6v-flash-qlora.yaml
Normal file
@@ -0,0 +1,50 @@
|
||||
base_model: zai-org/GLM-4.6V-Flash
|
||||
trust_remote_code: true
|
||||
|
||||
processor_type: AutoProcessor
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
output_dir: ./outputs/glm-4-6v-flash-qlora
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
|
||||
@@ -19,7 +19,6 @@ datasets:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: jamba-large-fsdp-qlora-ft
|
||||
save_safetensors: true
|
||||
adapter: qlora
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
68
examples/llama-3/qlora-1b-gdpo.yaml
Normal file
68
examples/llama-3/qlora-1b-gdpo.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: meta-llama/Llama-3.2-1B-Instruct
|
||||
|
||||
chat_template: llama3
|
||||
|
||||
rl: gdpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 128
|
||||
num_generations: 2
|
||||
temperature: 0.7
|
||||
top_p: 0.95
|
||||
|
||||
use_vllm: false
|
||||
|
||||
|
||||
multi_objective_aggregation: normalize_then_sum
|
||||
|
||||
reward_funcs:
|
||||
- rwd.format_reward
|
||||
- rwd.correctness_reward
|
||||
reward_weights: [1.0, 2.0]
|
||||
|
||||
log_completions: true
|
||||
num_completions_to_print: 3
|
||||
scale_rewards: true
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
split: train[:1000]
|
||||
type: rwd.gsm8k_transform
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/llama3-gdpo-out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
max_steps: 100
|
||||
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 10
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
flash_attention: true
|
||||
logging_steps: 1
|
||||
save_steps: 50
|
||||
save_safetensors: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
|
||||
seed: 42
|
||||
@@ -12,7 +12,6 @@ datasets:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out/qlora-llama3_1-405b
|
||||
save_safetensors: true
|
||||
|
||||
adapter: qlora
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
@@ -47,6 +47,5 @@ saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
tokens:
|
||||
save_safetensors: False
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -59,6 +59,7 @@ gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
scaling_softmax: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Install Qwen3-Next transformers commit
|
||||
```bash
|
||||
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||
```
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 45.62 GiB VRAM.
|
||||
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ plugins:
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
@@ -25,7 +27,7 @@ sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- linear_attn.in_proj_ba
|
||||
- linear_attn.in_proj_qkvz
|
||||
@@ -34,12 +36,19 @@ lora_target_modules:
|
||||
- shared_expert.down_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert_gate
|
||||
- mlp.gate
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
|
||||
285
examples/swanlab/README.md
Normal file
285
examples/swanlab/README.md
Normal file
@@ -0,0 +1,285 @@
|
||||
# SwanLab Integration Examples
|
||||
|
||||
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
|
||||
|
||||
## Examples Overview
|
||||
|
||||
### 1. DPO with Completion Logging
|
||||
**File**: `dpo-swanlab-completions.yml`
|
||||
|
||||
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
|
||||
|
||||
**Features**:
|
||||
- Basic SwanLab experiment tracking
|
||||
- Completion table logging (prompts, chosen/rejected responses, rewards)
|
||||
- Memory-bounded buffer for long training runs
|
||||
- Cloud sync configuration
|
||||
|
||||
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
|
||||
|
||||
**Quick start**:
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. LoRA with Performance Profiling
|
||||
**File**: `lora-swanlab-profiling.yml`
|
||||
|
||||
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
|
||||
|
||||
**Features**:
|
||||
- SwanLab experiment tracking
|
||||
- Automatic profiling of trainer methods
|
||||
- Profiling metrics visualization
|
||||
- Performance optimization guidance
|
||||
|
||||
**Best for**: Engineers optimizing training performance and comparing different configurations
|
||||
|
||||
**Quick start**:
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. Full-Featured DPO Production Setup
|
||||
**File**: `dpo-swanlab-full-featured.yml`
|
||||
|
||||
Comprehensive production-ready configuration with ALL SwanLab features enabled.
|
||||
|
||||
**Features**:
|
||||
- Experiment tracking with team workspace
|
||||
- RLHF completion logging
|
||||
- Performance profiling
|
||||
- Lark (Feishu) team notifications
|
||||
- Private deployment support
|
||||
- Production checklist and troubleshooting
|
||||
|
||||
**Best for**: Production RLHF training with team collaboration
|
||||
|
||||
**Quick start**:
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. Custom Trainer Profiling (Python)
|
||||
**File**: `custom_trainer_profiling.py`
|
||||
|
||||
Python code examples showing how to add SwanLab profiling to custom trainers.
|
||||
|
||||
**Features**:
|
||||
- `@swanlab_profile` decorator examples
|
||||
- Context manager profiling for fine-grained timing
|
||||
- `ProfilingConfig` for advanced filtering and throttling
|
||||
- Multiple profiling patterns and best practices
|
||||
|
||||
**Best for**: Advanced users creating custom trainers
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from custom_trainer_profiling import CustomTrainerWithProfiling
|
||||
# See file for detailed examples and patterns
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Feature Matrix
|
||||
|
||||
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|
||||
|---------|----------|-------------------|-----------|-------------------|----------------|
|
||||
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
||||
| lora-swanlab-profiling.yml | ✅ | ➖ (disabled) | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
||||
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
|
||||
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
|
||||
|
||||
---
|
||||
|
||||
## Configuration Quick Reference
|
||||
|
||||
### Basic SwanLab Setup
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: my-project
|
||||
swanlab_experiment_name: my-experiment
|
||||
swanlab_mode: cloud # cloud, local, offline, disabled
|
||||
```
|
||||
|
||||
### RLHF Completion Logging
|
||||
```yaml
|
||||
swanlab_log_completions: true
|
||||
swanlab_completion_log_interval: 100 # Log every 100 steps
|
||||
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
|
||||
```
|
||||
|
||||
### Lark Team Notifications
|
||||
```yaml
|
||||
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
||||
swanlab_lark_secret: your-webhook-secret # Required for production
|
||||
```
|
||||
|
||||
### Team Workspace
|
||||
```yaml
|
||||
swanlab_workspace: my-research-team
|
||||
```
|
||||
|
||||
### Private Deployment
|
||||
```yaml
|
||||
swanlab_web_host: https://swanlab.yourcompany.com
|
||||
swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
### Recommended: Environment Variable
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
```
|
||||
|
||||
### Alternative: Config File (less secure)
|
||||
```yaml
|
||||
swanlab_api_key: your-api-key
|
||||
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
||||
swanlab_lark_secret: your-webhook-secret
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Use Case 1: Migrate from WandB to SwanLab
|
||||
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
|
||||
```yaml
|
||||
use_swanlab: true
|
||||
use_wandb: false
|
||||
```
|
||||
|
||||
### Use Case 2: Analyze DPO Model Outputs
|
||||
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
|
||||
```yaml
|
||||
swanlab_completion_log_interval: 50 # More frequent for short training
|
||||
swanlab_completion_log_interval: 200 # Less frequent for long training
|
||||
```
|
||||
|
||||
### Use Case 3: Optimize Training Performance
|
||||
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
|
||||
- Baseline: `flash_attention: false, gradient_checkpointing: false`
|
||||
- Flash Attention: `flash_attention: true`
|
||||
- Gradient Checkpointing: `gradient_checkpointing: true`
|
||||
- Both: `flash_attention: true, gradient_checkpointing: true`
|
||||
|
||||
Compare profiling metrics in SwanLab dashboard.
|
||||
|
||||
### Use Case 4: Production RLHF with Team Collaboration
|
||||
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
|
||||
```yaml
|
||||
swanlab_workspace: ml-team
|
||||
swanlab_lark_webhook_url: ...
|
||||
swanlab_lark_secret: ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Viewing Your Experiments
|
||||
|
||||
### Cloud Mode
|
||||
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
|
||||
|
||||
**Dashboard sections**:
|
||||
- **Metrics**: Training loss, learning rate, profiling metrics
|
||||
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
|
||||
- **Config**: Hyperparameters and configuration
|
||||
- **System**: Resource usage (GPU, memory, CPU)
|
||||
- **Files**: Logged artifacts
|
||||
|
||||
### Local Mode
|
||||
```bash
|
||||
swanlab watch ./swanlog
|
||||
# Open browser to http://localhost:5092
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### SwanLab not initializing
|
||||
```bash
|
||||
# Check API key
|
||||
echo $SWANLAB_API_KEY
|
||||
|
||||
# Verify SwanLab is installed
|
||||
pip show swanlab
|
||||
|
||||
# Check config
|
||||
grep -A 5 "use_swanlab" your-config.yml
|
||||
```
|
||||
|
||||
### Completions not appearing
|
||||
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
|
||||
- Check `swanlab_log_completions: true`
|
||||
- Wait for `swanlab_completion_log_interval` steps
|
||||
- Look for "Registered SwanLab RLHF completion logging" in logs
|
||||
|
||||
### Lark notifications not working
|
||||
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
|
||||
- Verify `SWANLAB_LARK_SECRET` is set correctly
|
||||
- Check bot is added to Lark group chat
|
||||
- Look for "Registered Lark notification callback" in logs
|
||||
|
||||
### Profiling metrics not appearing
|
||||
- Verify `use_swanlab: true`
|
||||
- Check SwanLab is initialized (look for init log message)
|
||||
- Profiling metrics are under "profiling/" namespace
|
||||
- Profiling auto-enabled when SwanLab is enabled
|
||||
|
||||
---
|
||||
|
||||
## Performance Notes
|
||||
|
||||
### Overhead Comparison
|
||||
|
||||
| Feature | Overhead per Step | Memory Usage |
|
||||
|---------|------------------|--------------|
|
||||
| Basic tracking | < 0.1% | ~10 MB |
|
||||
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
|
||||
| Profiling | < 0.1% | ~1 KB |
|
||||
| **Total** | **< 0.7%** | **~10 MB** |
|
||||
|
||||
### Best Practices
|
||||
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
|
||||
2. Adjust completion log interval based on training length (100-200 steps)
|
||||
3. Keep completion buffer size reasonable (128-512)
|
||||
4. Profile critical path methods first (training_step, compute_loss)
|
||||
5. Use ProfilingConfig to throttle high-frequency operations
|
||||
|
||||
---
|
||||
|
||||
## Further Reading
|
||||
|
||||
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
|
||||
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
|
||||
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
|
||||
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
Found an issue or have an improvement? Please submit a PR or open an issue:
|
||||
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
|
||||
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)
|
||||
299
examples/swanlab/custom_trainer_profiling.py
Normal file
299
examples/swanlab/custom_trainer_profiling.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""Example: Custom Trainer with SwanLab Profiling
|
||||
|
||||
This example demonstrates how to add SwanLab profiling to your custom trainer.
|
||||
|
||||
Features:
|
||||
- @swanlab_profile decorator for automatic profiling
|
||||
- swanlab_profiling_context for fine-grained profiling
|
||||
- ProfilingConfig for advanced filtering and throttling
|
||||
|
||||
Usage:
|
||||
1. Create your custom trainer extending AxolotlTrainer
|
||||
2. Add @swanlab_profile decorators to methods you want to profile
|
||||
3. Use swanlab_profiling_context for fine-grained profiling within methods
|
||||
4. Enable SwanLab in your config (use_swanlab: true)
|
||||
|
||||
See also:
|
||||
- examples/swanlab/lora-swanlab-profiling.yml for config
|
||||
- src/axolotl/integrations/swanlab/profiling.py for implementation
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.integrations.swanlab.profiling import (
|
||||
ProfilingConfig,
|
||||
swanlab_profile,
|
||||
swanlab_profiling_context,
|
||||
swanlab_profiling_context_advanced,
|
||||
)
|
||||
|
||||
|
||||
class CustomTrainerWithProfiling(AxolotlTrainer):
|
||||
"""Custom trainer with SwanLab profiling enabled.
|
||||
|
||||
This trainer demonstrates three profiling patterns:
|
||||
1. Decorator-based profiling (@swanlab_profile)
|
||||
2. Context manager profiling (swanlab_profiling_context)
|
||||
3. Advanced profiling with filtering (ProfilingConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Create custom profiling config for high-frequency operations
|
||||
self.fast_op_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=0.5, # Only log if duration > 0.5ms
|
||||
log_interval=50, # Log every 50th call
|
||||
)
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 1: Decorator-based Profiling
|
||||
# ========================================================================
|
||||
# Best for: Methods you always want to profile
|
||||
# Overhead: ~2-5 microseconds per call (negligible)
|
||||
|
||||
@swanlab_profile
|
||||
def training_step(self, model, inputs):
|
||||
"""Main training step - always profile.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
|
||||
"""
|
||||
return super().training_step(model, inputs)
|
||||
|
||||
@swanlab_profile
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
"""Loss computation - always profile.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
|
||||
"""
|
||||
return super().compute_loss(model, inputs, return_outputs)
|
||||
|
||||
@swanlab_profile
|
||||
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||
"""Prediction step - always profile.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
|
||||
"""
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 2: Fine-grained Context Manager Profiling
|
||||
# ========================================================================
|
||||
# Best for: Profiling specific code blocks within a method
|
||||
# Use case: When you want to profile forward vs backward separately
|
||||
|
||||
def complex_training_step(self, model, inputs):
|
||||
"""Training step with fine-grained profiling.
|
||||
|
||||
Profiling metrics:
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
|
||||
"""
|
||||
# Profile just the forward pass
|
||||
with swanlab_profiling_context(self, "forward_pass"):
|
||||
outputs = model(**inputs)
|
||||
loss = outputs.loss
|
||||
|
||||
# Profile just the backward pass
|
||||
with swanlab_profiling_context(self, "backward_pass"):
|
||||
loss.backward()
|
||||
|
||||
# Profile optimizer step
|
||||
with swanlab_profiling_context(self, "optimizer_step"):
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
return outputs
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 3: Advanced Profiling with Filtering
|
||||
# ========================================================================
|
||||
# Best for: High-frequency operations where you want to throttle logging
|
||||
# Use case: Methods called 100+ times per step
|
||||
|
||||
def _prepare_inputs(self, inputs):
|
||||
"""Prepare inputs - throttled profiling.
|
||||
|
||||
This method is called frequently (once per batch), so we throttle
|
||||
profiling to reduce overhead:
|
||||
- Only log if duration > 0.5ms (skip very fast operations)
|
||||
- Only log every 50th call (reduce logging frequency)
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
|
||||
"""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "prepare_inputs", config=self.fast_op_config
|
||||
):
|
||||
return super()._prepare_inputs(inputs)
|
||||
|
||||
def _prepare_input_for_model(self, input_ids):
|
||||
"""Another high-frequency operation - throttled profiling.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
|
||||
"""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "prepare_input_for_model", config=self.fast_op_config
|
||||
):
|
||||
# Your custom input preparation logic
|
||||
return input_ids
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 4: Exception-safe Profiling
|
||||
# ========================================================================
|
||||
# Profiling is exception-safe: duration is logged even if method raises
|
||||
|
||||
@swanlab_profile
|
||||
def potentially_failing_method(self):
|
||||
"""This method may raise an exception.
|
||||
|
||||
SwanLab profiling will still log the duration before re-raising.
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
|
||||
"""
|
||||
# Do some work
|
||||
result = self._do_risky_computation()
|
||||
|
||||
# If this raises, profiling duration is still logged
|
||||
if result < 0:
|
||||
raise ValueError("Invalid result")
|
||||
|
||||
return result
|
||||
|
||||
def _do_risky_computation(self):
|
||||
"""Placeholder for risky computation."""
|
||||
return 42
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Advanced Example: Custom ProfilingConfig Per Method
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdvancedProfilingTrainer(AxolotlTrainer):
|
||||
"""Trainer with method-specific profiling configurations."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Different profiling configs for different method types
|
||||
self.critical_path_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=0.0, # Log everything on critical path
|
||||
log_interval=1, # Log every call
|
||||
)
|
||||
|
||||
self.fast_path_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=1.0, # Only log if > 1ms
|
||||
log_interval=100, # Log every 100th call
|
||||
)
|
||||
|
||||
self.debug_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=0.0, # Log everything
|
||||
log_interval=1, # Log every call
|
||||
)
|
||||
|
||||
def training_step(self, model, inputs):
|
||||
"""Critical path - log everything."""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "training_step", config=self.critical_path_config
|
||||
):
|
||||
return super().training_step(model, inputs)
|
||||
|
||||
def _prepare_inputs(self, inputs):
|
||||
"""Fast path - throttle logging."""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "prepare_inputs", config=self.fast_path_config
|
||||
):
|
||||
return super()._prepare_inputs(inputs)
|
||||
|
||||
def _debug_method(self, data):
|
||||
"""Debug-only method - verbose logging."""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "debug_method", config=self.debug_config
|
||||
):
|
||||
# Your debug logic
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# How to Use This Custom Trainer
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
To use this custom trainer:
|
||||
|
||||
1. Save this file to your project (e.g., my_custom_trainer.py)
|
||||
|
||||
2. Create a config file that uses your custom trainer:
|
||||
|
||||
# config.yml
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
|
||||
# ... other config ...
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: my-profiling-experiment
|
||||
|
||||
# Optional: Specify custom trainer
|
||||
# (Or modify axolotl to use your custom trainer class)
|
||||
|
||||
3. Run training:
|
||||
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
accelerate launch -m axolotl.cli.train config.yml
|
||||
|
||||
4. View profiling metrics in SwanLab dashboard:
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.training_step
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
||||
- etc.
|
||||
|
||||
5. Compare profiling metrics across runs:
|
||||
- Run baseline without optimizations
|
||||
- Run with flash_attention enabled
|
||||
- Run with gradient_checkpointing enabled
|
||||
- Compare profiling metrics to see performance impact
|
||||
"""
|
||||
|
||||
# ============================================================================
|
||||
# Tips for Effective Profiling
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
1. Profile the critical path first:
|
||||
- training_step, compute_loss, prediction_step
|
||||
- These methods are called most frequently and have biggest impact
|
||||
|
||||
2. Use throttling for high-frequency operations:
|
||||
- Methods called 100+ times per step
|
||||
- Use log_interval=50 or log_interval=100
|
||||
- Reduces profiling overhead and dashboard clutter
|
||||
|
||||
3. Filter noise with min_duration_ms:
|
||||
- Set min_duration_ms=1.0 to skip very fast operations
|
||||
- Focus on operations that actually take time
|
||||
|
||||
4. Compare across runs:
|
||||
- Run same config multiple times to check consistency
|
||||
- Compare different optimization strategies
|
||||
- Track profiling trends over time
|
||||
|
||||
5. Monitor distributed training:
|
||||
- Check for per-rank timing differences
|
||||
- Look for stragglers (slower ranks)
|
||||
- Identify synchronization bottlenecks
|
||||
|
||||
6. Disable profiling in production:
|
||||
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
|
||||
- DEFAULT_PROFILING_CONFIG.enabled = False
|
||||
|
||||
7. Exception handling:
|
||||
- Profiling is exception-safe
|
||||
- Duration logged even if method raises
|
||||
- Useful for debugging methods that fail intermittently
|
||||
"""
|
||||
168
examples/swanlab/dpo-swanlab-completions.yml
Normal file
168
examples/swanlab/dpo-swanlab-completions.yml
Normal file
@@ -0,0 +1,168 @@
|
||||
# SwanLab DPO Training Example with Completion Logging
|
||||
#
|
||||
# This example demonstrates DPO (Direct Preference Optimization) training
|
||||
# with SwanLab integration for experiment tracking and completion table logging.
|
||||
#
|
||||
# Features enabled:
|
||||
# - SwanLab experiment tracking
|
||||
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
|
||||
# - Lark (Feishu) team notifications (optional)
|
||||
#
|
||||
# To run:
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
||||
|
||||
# Model Configuration
|
||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot_id|>
|
||||
|
||||
# Quantization
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
# LoRA Configuration
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
# DPO Configuration
|
||||
chat_template: llama3
|
||||
rl: dpo
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
# Dataset and Output
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/dpo-swanlab-out
|
||||
|
||||
# Training Configuration
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
num_epochs: 4
|
||||
|
||||
# Optimization
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
# Precision
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
# Performance
|
||||
gradient_checkpointing: true
|
||||
flash_attention: true
|
||||
|
||||
# Checkpointing and Logging
|
||||
logging_steps: 1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
# ============================================================================
|
||||
# SwanLab Integration
|
||||
# ============================================================================
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
# Basic SwanLab Configuration
|
||||
use_swanlab: true
|
||||
swanlab_project: dpo-training
|
||||
swanlab_experiment_name: llama-3-dpo-completions-demo
|
||||
swanlab_description: "DPO training with completion table logging"
|
||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||
|
||||
# SwanLab Authentication
|
||||
# Recommended: Set via environment variable
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# Or set in config (less secure):
|
||||
# swanlab_api_key: your-api-key
|
||||
|
||||
# Optional: Team workspace
|
||||
# swanlab_workspace: my-research-team
|
||||
|
||||
# ============================================================================
|
||||
# RLHF Completion Table Logging
|
||||
# ============================================================================
|
||||
#
|
||||
# Automatically logs model completions to SwanLab for qualitative analysis:
|
||||
# - Prompts from your DPO dataset
|
||||
# - Chosen responses (preferred)
|
||||
# - Rejected responses (non-preferred)
|
||||
# - Reward differences
|
||||
#
|
||||
# View the table in SwanLab dashboard under "rlhf_completions"
|
||||
|
||||
swanlab_log_completions: true
|
||||
swanlab_completion_log_interval: 100 # Log every 100 training steps
|
||||
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
|
||||
|
||||
# Memory Usage Notes:
|
||||
# - Buffer size 128: ~64 KB (default, recommended)
|
||||
# - Buffer size 512: ~256 KB (for more historical completions)
|
||||
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
|
||||
|
||||
# Performance Notes:
|
||||
# - Completion logging overhead: < 0.5% per training step
|
||||
# - Only logs every N steps to minimize impact
|
||||
# - Memory-bounded buffer prevents memory leaks
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Lark (Feishu) Team Notifications
|
||||
# ============================================================================
|
||||
#
|
||||
# Get real-time training notifications in your team chat
|
||||
# Uncomment to enable:
|
||||
|
||||
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||
# swanlab_lark_secret: your-webhook-secret # Recommended for production
|
||||
|
||||
# Notifications sent for:
|
||||
# - Training start
|
||||
# - Training completion
|
||||
# - Training errors
|
||||
# - Metric milestones (if configured)
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Private SwanLab Deployment
|
||||
# ============================================================================
|
||||
#
|
||||
# For enterprise users with private SwanLab deployment:
|
||||
|
||||
# swanlab_web_host: https://swanlab.yourcompany.com
|
||||
# swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||
|
||||
# ============================================================================
|
||||
# Disable WandB if you're migrating from it
|
||||
# ============================================================================
|
||||
|
||||
# wandb_project:
|
||||
# wandb_entity:
|
||||
# use_wandb: false
|
||||
329
examples/swanlab/dpo-swanlab-full-featured.yml
Normal file
329
examples/swanlab/dpo-swanlab-full-featured.yml
Normal file
@@ -0,0 +1,329 @@
|
||||
# SwanLab Full-Featured DPO Training Example
|
||||
#
|
||||
# This example demonstrates ALL SwanLab integration features:
|
||||
# - Experiment tracking with cloud sync
|
||||
# - RLHF completion table logging
|
||||
# - Performance profiling
|
||||
# - Lark (Feishu) team notifications
|
||||
# - Team workspace collaboration
|
||||
#
|
||||
# Use this as a reference for production RLHF training setups.
|
||||
#
|
||||
# To run:
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
# export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
||||
|
||||
# ============================================================================
|
||||
# Model Configuration
|
||||
# ============================================================================
|
||||
|
||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot_id|>
|
||||
|
||||
# Quantization for efficient training
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
# ============================================================================
|
||||
# LoRA Configuration
|
||||
# ============================================================================
|
||||
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true # Target all linear layers
|
||||
|
||||
# ============================================================================
|
||||
# DPO (Direct Preference Optimization) Configuration
|
||||
# ============================================================================
|
||||
|
||||
chat_template: llama3
|
||||
rl: dpo # Enable DPO trainer
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
# ============================================================================
|
||||
# Dataset and Output Configuration
|
||||
# ============================================================================
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/dpo-swanlab-full-featured-out
|
||||
|
||||
# ============================================================================
|
||||
# Training Configuration
|
||||
# ============================================================================
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
num_epochs: 4
|
||||
|
||||
# ============================================================================
|
||||
# Optimization
|
||||
# ============================================================================
|
||||
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
# ============================================================================
|
||||
# Precision and Performance
|
||||
# ============================================================================
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
flash_attention: true
|
||||
|
||||
# ============================================================================
|
||||
# Checkpointing and Logging
|
||||
# ============================================================================
|
||||
|
||||
logging_steps: 1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
# ============================================================================
|
||||
# SwanLab Integration - Full Configuration
|
||||
# ============================================================================
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Basic SwanLab Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: dpo-production
|
||||
swanlab_experiment_name: llama-3-dpo-full-featured-v1
|
||||
swanlab_description: |
|
||||
Production DPO training with all SwanLab features enabled:
|
||||
- Completion table logging for qualitative analysis
|
||||
- Performance profiling for optimization
|
||||
- Lark notifications for team collaboration
|
||||
|
||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Team Collaboration
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Workspace for team collaboration (shared experiments)
|
||||
swanlab_workspace: ml-research-team
|
||||
|
||||
# Authentication (recommended: use environment variable)
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# Or set in config (less secure):
|
||||
# swanlab_api_key: your-api-key
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# RLHF Completion Table Logging
|
||||
# ------------------------------------------------------------------------------
|
||||
# Automatically logs model completions for qualitative analysis:
|
||||
# - Prompts from your DPO dataset
|
||||
# - Chosen responses (preferred)
|
||||
# - Rejected responses (non-preferred)
|
||||
# - Reward differences
|
||||
#
|
||||
# View in SwanLab dashboard under "rlhf_completions" table
|
||||
|
||||
swanlab_log_completions: true
|
||||
swanlab_completion_log_interval: 100 # Log every 100 steps
|
||||
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
|
||||
|
||||
# Buffer size recommendations:
|
||||
# - 128: Default, ~64 KB memory (recommended for most cases)
|
||||
# - 256: ~128 KB memory (this config, good for longer training)
|
||||
# - 512: ~256 KB memory (maximum for very long runs)
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Lark (Feishu) Team Notifications
|
||||
# ------------------------------------------------------------------------------
|
||||
# Get real-time training notifications in your team chat
|
||||
#
|
||||
# Notifications sent for:
|
||||
# - Training start
|
||||
# - Training completion
|
||||
# - Training errors
|
||||
# - Metric milestones (if configured)
|
||||
|
||||
# Recommended: Set via environment variables
|
||||
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
# export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
|
||||
# Or set in config (less secure):
|
||||
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
|
||||
|
||||
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
|
||||
# unauthorized parties from sending fake notifications to your team chat.
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Performance Profiling
|
||||
# ------------------------------------------------------------------------------
|
||||
# Profiling is automatically enabled when SwanLab is enabled.
|
||||
# Metrics logged to SwanLab under "profiling/" namespace:
|
||||
# profiling/Time taken: AxolotlTrainer.training_step
|
||||
# profiling/Time taken: AxolotlTrainer.compute_loss
|
||||
# profiling/Time taken: AxolotlTrainer.prediction_step
|
||||
#
|
||||
# Use these metrics to:
|
||||
# - Identify bottlenecks in training loop
|
||||
# - Compare performance across different configurations
|
||||
# - Monitor performance regressions over time
|
||||
# - Debug unexpected slowdowns
|
||||
|
||||
# For custom profiling in your own trainer, see:
|
||||
# examples/swanlab/custom_trainer_profiling.py
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Optional: Private SwanLab Deployment
|
||||
# ------------------------------------------------------------------------------
|
||||
# For enterprise users with private SwanLab deployment:
|
||||
|
||||
# swanlab_web_host: https://swanlab.yourcompany.com
|
||||
# swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Optional: Model Checkpointing to SwanLab
|
||||
# ------------------------------------------------------------------------------
|
||||
# Log model checkpoints to SwanLab (coming soon)
|
||||
|
||||
swanlab_log_model: false
|
||||
|
||||
# ============================================================================
|
||||
# Disable Other Logging Tools (Recommended)
|
||||
# ============================================================================
|
||||
# Using multiple logging tools simultaneously can impact performance:
|
||||
# - Expected overhead: ~1-2% per logger
|
||||
# - Potential config/callback conflicts
|
||||
#
|
||||
# For production training, use ONLY SwanLab:
|
||||
|
||||
# wandb_project:
|
||||
# use_wandb: false
|
||||
#
|
||||
# use_mlflow: false
|
||||
#
|
||||
# use_comet: false
|
||||
|
||||
# ============================================================================
|
||||
# Expected Training Behavior
|
||||
# ============================================================================
|
||||
|
||||
# With this configuration, you should see:
|
||||
#
|
||||
# 1. SwanLab Initialization (rank 0 only):
|
||||
# INFO: SwanLab initialized for project: dpo-production
|
||||
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
|
||||
# INFO: SwanLab mode: cloud
|
||||
# INFO: SwanLab workspace: ml-research-team
|
||||
#
|
||||
# 2. Completion Logging (rank 0 only):
|
||||
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
|
||||
# (log_interval=100, max_buffer=256)
|
||||
#
|
||||
# 3. Lark Notifications (rank 0 only):
|
||||
# INFO: Registered Lark notification callback with HMAC authentication
|
||||
#
|
||||
# 4. Distributed Training Detection (if multi-GPU):
|
||||
# INFO: Distributed training detected (world_size=N)
|
||||
# INFO: Only rank 0 will initialize SwanLab
|
||||
# INFO: Other ranks will skip SwanLab to avoid conflicts
|
||||
#
|
||||
# 5. Training Start Notification (Lark):
|
||||
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
|
||||
#
|
||||
# 6. Periodic Completion Logging:
|
||||
# Every 100 steps, completion table is updated in SwanLab dashboard
|
||||
#
|
||||
# 7. Training Complete Notification (Lark):
|
||||
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
|
||||
# With link to SwanLab dashboard and final metrics
|
||||
#
|
||||
# 8. SwanLab Dashboard Shows:
|
||||
# - Training metrics (loss, learning rate, etc.)
|
||||
# - Completion table (rlhf_completions)
|
||||
# - Profiling metrics (profiling/Time taken: ...)
|
||||
# - Hyperparameters and configuration
|
||||
# - System resource usage
|
||||
|
||||
# ============================================================================
|
||||
# Production Checklist
|
||||
# ============================================================================
|
||||
|
||||
# Before deploying to production, verify:
|
||||
# ✅ SwanLab API key is set via environment variable (not in config)
|
||||
# ✅ Lark webhook secret is set (required for HMAC authentication)
|
||||
# ✅ Workspace is set to your team's workspace
|
||||
# ✅ Experiment name is descriptive and unique
|
||||
# ✅ Only SwanLab is enabled (other loggers disabled)
|
||||
# ✅ Completion logging buffer size is appropriate for your training duration
|
||||
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
|
||||
# ✅ Test run completes successfully and shows up in SwanLab dashboard
|
||||
# ✅ Lark notifications are received in team chat
|
||||
# ✅ Profiling metrics are logged correctly
|
||||
|
||||
# ============================================================================
|
||||
# Troubleshooting
|
||||
# ============================================================================
|
||||
|
||||
# If SwanLab initialization fails:
|
||||
# 1. Check SWANLAB_API_KEY environment variable is set
|
||||
# 2. Verify swanlab_project is set in config
|
||||
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
|
||||
# 4. Verify internet connectivity (for cloud mode)
|
||||
|
||||
# If Lark notifications not received:
|
||||
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
|
||||
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
|
||||
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
|
||||
# 4. Check training logs for "Registered Lark notification callback"
|
||||
# 5. Verify bot is added to the target Lark group chat
|
||||
|
||||
# If completions not appearing in SwanLab:
|
||||
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
|
||||
# 2. Check swanlab_log_completions is true
|
||||
# 3. Wait for log_interval steps (default: 100)
|
||||
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
|
||||
|
||||
# If profiling metrics not appearing:
|
||||
# 1. Verify use_swanlab is true
|
||||
# 2. Check SwanLab is initialized (check logs)
|
||||
# 3. Look under "profiling/" namespace in dashboard
|
||||
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
|
||||
|
||||
# For more help:
|
||||
# - SwanLab docs: https://docs.swanlab.cn
|
||||
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
|
||||
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues
|
||||
178
examples/swanlab/lora-swanlab-profiling.yml
Normal file
178
examples/swanlab/lora-swanlab-profiling.yml
Normal file
@@ -0,0 +1,178 @@
|
||||
# SwanLab LoRA Training Example with Performance Profiling
|
||||
#
|
||||
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
|
||||
# for performance profiling and optimization.
|
||||
#
|
||||
# Features enabled:
|
||||
# - SwanLab experiment tracking
|
||||
# - Performance profiling (training step, forward/backward pass timing)
|
||||
# - Real-time metrics visualization
|
||||
#
|
||||
# To run:
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
|
||||
|
||||
# Model Configuration
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
|
||||
# Dataset Configuration
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-swanlab-profiling-out
|
||||
|
||||
# LoRA Configuration
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Training Configuration
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 1
|
||||
|
||||
# Optimization
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
# Precision
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
# Performance
|
||||
gradient_checkpointing: true
|
||||
flash_attention: true
|
||||
|
||||
# Checkpointing and Logging
|
||||
logging_steps: 1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Loss Monitoring
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
# ============================================================================
|
||||
# SwanLab Integration
|
||||
# ============================================================================
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
# Basic SwanLab Configuration
|
||||
use_swanlab: true
|
||||
swanlab_project: lora-profiling
|
||||
swanlab_experiment_name: llama-3.2-1b-profiling-demo
|
||||
swanlab_description: "LoRA fine-tuning with performance profiling"
|
||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||
|
||||
# SwanLab Authentication
|
||||
# Recommended: Set via environment variable
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# Or set in config (less secure):
|
||||
# swanlab_api_key: your-api-key
|
||||
|
||||
# Optional: Team workspace
|
||||
# swanlab_workspace: my-ml-team
|
||||
|
||||
# ============================================================================
|
||||
# Performance Profiling
|
||||
# ============================================================================
|
||||
#
|
||||
# SwanLab automatically profiles trainer methods when enabled.
|
||||
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
|
||||
#
|
||||
# Built-in profiling:
|
||||
# - Minimal overhead (< 0.1% per step)
|
||||
# - High-precision timing (microsecond accuracy)
|
||||
# - Exception-safe (logs duration even if method fails)
|
||||
#
|
||||
# View profiling metrics in SwanLab dashboard:
|
||||
# profiling/Time taken: AxolotlTrainer.training_step
|
||||
# profiling/Time taken: AxolotlTrainer.compute_loss
|
||||
# profiling/Time taken: AxolotlTrainer.prediction_step
|
||||
#
|
||||
# For custom profiling in your own trainer, see:
|
||||
# examples/swanlab/custom_trainer_profiling.py
|
||||
|
||||
# Completion logging is disabled for non-RLHF trainers
|
||||
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Compare with Multiple Runs
|
||||
# ============================================================================
|
||||
#
|
||||
# To compare profiling metrics across different configurations:
|
||||
#
|
||||
# 1. Run baseline without flash attention:
|
||||
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
|
||||
# flash_attention: false
|
||||
#
|
||||
# 2. Run with gradient checkpointing:
|
||||
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
|
||||
# gradient_checkpointing: true
|
||||
#
|
||||
# 3. Run with both:
|
||||
# swanlab_experiment_name: llama-3.2-1b-optimized
|
||||
# flash_attention: true
|
||||
# gradient_checkpointing: true
|
||||
#
|
||||
# Then compare profiling metrics in SwanLab dashboard to see performance impact
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Lark (Feishu) Team Notifications
|
||||
# ============================================================================
|
||||
#
|
||||
# Get notified when profiling experiments complete:
|
||||
|
||||
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||
# swanlab_lark_secret: your-webhook-secret
|
||||
|
||||
# ============================================================================
|
||||
# Profiling Best Practices
|
||||
# ============================================================================
|
||||
#
|
||||
# 1. Run multiple epochs to see profiling trends over time
|
||||
# 2. Ignore first ~10 steps (warmup period, slower)
|
||||
# 3. Look for outliers (steps that take significantly longer)
|
||||
# 4. Compare profiling metrics before/after optimization changes
|
||||
# 5. Monitor per-rank profiling in distributed training
|
||||
#
|
||||
# Common bottlenecks to profile:
|
||||
# - training_step: Overall step time (should be consistent)
|
||||
# - compute_loss: Loss computation (scales with sequence length)
|
||||
# - prediction_step: Evaluation time (can be slow for large val sets)
|
||||
#
|
||||
# If you see inconsistent timing:
|
||||
# - Check for data loading bottlenecks
|
||||
# - Monitor GPU utilization (may be CPU-bound)
|
||||
# - Check for gradient accumulation effects
|
||||
# - Verify CUDA kernel synchronization
|
||||
|
||||
# ============================================================================
|
||||
# Disable WandB if you're migrating from it
|
||||
# ============================================================================
|
||||
|
||||
# wandb_project:
|
||||
# use_wandb: false
|
||||
@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Run the finetuning example:
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 24.9 GiB VRAM.
|
||||
This config uses about 24.9 GiB VRAM (w/o CCE).
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: arcee-ai/Trinity-Nano-Preview
|
||||
trust_remote_code: true
|
||||
revision_of_model: 2ee94b0
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
|
||||
@@ -12,7 +12,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
@@ -24,6 +24,9 @@ Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = { file = "VERSION" }
|
||||
|
||||
[tool.setuptools.cmdclass]
|
||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
||||
|
||||
@@ -57,3 +60,8 @@ indent-style = "space"
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
docstring-code-format = false
|
||||
|
||||
[tool.uv.extra-build-dependencies]
|
||||
axolotl = ["huggingface_hub"]
|
||||
flash-attn = [{ requirement = "torch", match-runtime = true }]
|
||||
deepspeed = [{ requirement = "torch", match-runtime = true }]
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.48.2
|
||||
triton>=3.0.0
|
||||
bitsandbytes==0.49.1
|
||||
triton>=3.4.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
liger-kernel==0.6.4
|
||||
liger-kernel==0.7.0
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.36.0
|
||||
peft>=0.18.0
|
||||
packaging==26.0
|
||||
huggingface_hub>=1.1.7
|
||||
peft>=0.18.1
|
||||
tokenizers>=0.22.1
|
||||
transformers==4.57.1
|
||||
transformers==5.2.0
|
||||
accelerate==1.12.0
|
||||
datasets==4.4.2
|
||||
datasets==4.5.0
|
||||
deepspeed>=0.18.3
|
||||
trl==0.25.1
|
||||
trl==0.28.0
|
||||
hf_xet==1.2.0
|
||||
kernels==0.11.5
|
||||
trackio>=0.13.0
|
||||
kernels==0.12.1
|
||||
|
||||
trackio>=0.16.1
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
optimum==1.16.2
|
||||
@@ -63,7 +63,7 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.13.0
|
||||
torchao==0.16.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
@@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.6
|
||||
# telemetry
|
||||
posthog==6.7.11
|
||||
|
||||
mistral-common==1.8.6
|
||||
mistral-common==1.8.8
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
|
||||
)
|
||||
|
||||
67
setup.py
67
setup.py
@@ -1,6 +1,5 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import ast
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
@@ -26,6 +25,12 @@ def parse_requirements(extras_require_map):
|
||||
_install_requires.append(line)
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
install_xformers = platform.machine() != "aarch64"
|
||||
if platform.machine() == "aarch64":
|
||||
# skip torchao on ARM64
|
||||
_install_requires = [
|
||||
req for req in _install_requires if "torchao" not in req
|
||||
]
|
||||
if "Darwin" in platform.system():
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
@@ -62,44 +67,68 @@ def parse_requirements(extras_require_map):
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
torch_parts = torch_version.split("+")
|
||||
if len(torch_parts) == 2:
|
||||
torch_cuda_version = torch_parts[1]
|
||||
_dependency_links.append(
|
||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||
)
|
||||
|
||||
if (major, minor) >= (2, 9):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.4.0",
|
||||
"fbgemm-gpu-genai==1.4.2",
|
||||
]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
if patch == 0:
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
else:
|
||||
extras_require_map["vllm"] = ["vllm==0.14.0"]
|
||||
elif (major, minor) >= (2, 8):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
elif (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.30")
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.30")
|
||||
# vllm 0.9.x is incompatible with latest transformers
|
||||
extras_require_map.pop("vllm")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.31")
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.31")
|
||||
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
if install_xformers:
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
# since we only support 2.6.0+cu126
|
||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
if install_xformers:
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 4):
|
||||
extras_require_map.pop("vllm")
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
if install_xformers:
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
else:
|
||||
raise ValueError("axolotl requires torch>=2.4")
|
||||
|
||||
@@ -110,15 +139,11 @@ def parse_requirements(extras_require_map):
|
||||
|
||||
def get_package_version():
|
||||
with open(
|
||||
Path(os.path.dirname(os.path.abspath(__file__)))
|
||||
/ "src"
|
||||
/ "axolotl"
|
||||
/ "__init__.py",
|
||||
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
|
||||
"r",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
|
||||
version_ = ast.literal_eval(version_match.group(1))
|
||||
version_ = fin.read().strip()
|
||||
return version_
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
"""Axolotl - Train and fine-tune large language models"""
|
||||
|
||||
import pkgutil
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.13.0.dev"
|
||||
try:
|
||||
__version__ = version("axolotl")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "unknown"
|
||||
|
||||
@@ -5,6 +5,6 @@ import os
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
|
||||
|
||||
configure_logging()
|
||||
|
||||
@@ -44,7 +44,7 @@ def check_user_token() -> bool:
|
||||
return bool(user_info)
|
||||
except LocalTokenNotFoundError:
|
||||
LOG.warning(
|
||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
except HTTPError:
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Union
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
|
||||
"""Coerce a string CLI value to its most likely Python type.
|
||||
|
||||
If an existing value is present in the config, its type is used to guide
|
||||
casting. Otherwise, YAML-style inference is applied: booleans, ints,
|
||||
floats, and None literals are recognised automatically.
|
||||
|
||||
Args:
|
||||
value: The raw value (typically a string from the CLI).
|
||||
existing: An optional existing config value whose type guides coercion.
|
||||
|
||||
Returns:
|
||||
The value cast to the inferred or expected type.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
# If the config already has a typed value, cast to match
|
||||
if existing is not None:
|
||||
if isinstance(existing, bool):
|
||||
return value.lower() in ("true", "1", "yes")
|
||||
if isinstance(existing, int):
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
if isinstance(existing, float):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
# For other types (str, list, dict, etc.), return as-is
|
||||
return value
|
||||
|
||||
# No existing value -- use YAML-style inference
|
||||
lower = value.lower()
|
||||
if lower in ("true", "yes"):
|
||||
return True
|
||||
if lower in ("false", "no"):
|
||||
return False
|
||||
if lower in ("null", "none", "~"):
|
||||
return None
|
||||
|
||||
# Try int then float
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
API_KEY_FIELDS = {"comet_api_key"}
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
@@ -208,13 +265,37 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
|
||||
# Separate nested (dot-notation) kwargs from flat kwargs
|
||||
nested_kwargs: dict[str, dict[str, Any]] = {}
|
||||
flat_kwargs: dict[str, Any] = {}
|
||||
for key, value in kwargs.items():
|
||||
if "__" in key:
|
||||
parent, child = key.split("__", 1)
|
||||
nested_kwargs.setdefault(parent, {})[child] = value
|
||||
else:
|
||||
flat_kwargs[key] = value
|
||||
|
||||
# Apply flat kwargs
|
||||
for key, value in flat_kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[key] = value
|
||||
cfg[key] = _coerce_value(value, cfg.get(key))
|
||||
|
||||
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
|
||||
for parent, children in nested_kwargs.items():
|
||||
if parent not in cfg_keys and cfg.strict:
|
||||
continue
|
||||
if cfg[parent] is None:
|
||||
cfg[parent] = {}
|
||||
if not isinstance(cfg[parent], dict):
|
||||
LOG.warning(
|
||||
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
|
||||
)
|
||||
cfg[parent] = {}
|
||||
for child_key, child_value in children.items():
|
||||
existing_child = cfg[parent].get(child_key)
|
||||
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
|
||||
@@ -24,7 +24,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
"""
|
||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
LOG.info("Running merge of LoRA with base model...")
|
||||
model = model.merge_and_unload(progressbar=True)
|
||||
@@ -42,7 +41,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
progressbar=True,
|
||||
)
|
||||
tokenizer.save_pretrained(
|
||||
|
||||
@@ -14,8 +14,6 @@ from accelerate import PartialState
|
||||
from accelerate.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_torch_version,
|
||||
)
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
@@ -40,17 +38,15 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
||||
def _distributed_checkpoint_to_merged_weights(
|
||||
checkpoint_dir: Union[str, Path],
|
||||
save_path: str,
|
||||
safe_serialization: bool = False,
|
||||
max_shard_size: str = "5GB",
|
||||
) -> Path:
|
||||
"""
|
||||
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
|
||||
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
|
||||
save under `save_path` as `model.safetensors`.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Directory where distributed checkpoint is saved.
|
||||
save_path: Path to save model to.
|
||||
safe_serialization: Whether to save in safetensors format.
|
||||
max_shard_size: Max size of model shards to save.
|
||||
|
||||
Returns:
|
||||
@@ -76,11 +72,7 @@ def _distributed_checkpoint_to_merged_weights(
|
||||
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
|
||||
state_dict[key] = value.to(torch.bfloat16)
|
||||
|
||||
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
|
||||
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
@@ -98,19 +90,12 @@ def _distributed_checkpoint_to_merged_weights(
|
||||
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {tensor: state_dict[tensor] for tensor in tensors}
|
||||
|
||||
if safe_serialization:
|
||||
safe_save_file(
|
||||
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
|
||||
)
|
||||
else:
|
||||
torch.save(shard, os.path.join(save_path_, shard_file))
|
||||
safe_save_file(
|
||||
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
|
||||
)
|
||||
|
||||
if index is not None:
|
||||
save_index_file = (
|
||||
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
)
|
||||
save_index_file = os.path.join(save_path_, save_index_file)
|
||||
save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as fout:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
@@ -123,13 +108,11 @@ def _distributed_checkpoint_to_merged_weights(
|
||||
def merge_fsdp_weights(
|
||||
checkpoint_dir: str,
|
||||
output_path: str,
|
||||
safe_serialization: bool = False,
|
||||
remove_checkpoint_dir: bool = False,
|
||||
):
|
||||
"""
|
||||
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
|
||||
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
|
||||
`safe_serialization` else `pytorch_model.bin`.
|
||||
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.
|
||||
|
||||
Note: this is a CPU-bound process.
|
||||
|
||||
@@ -138,8 +121,6 @@ def merge_fsdp_weights(
|
||||
The directory containing the FSDP checkpoints (can be either the model or optimizer).
|
||||
output_path (`str`):
|
||||
The path to save the merged checkpoint.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the merged weights with safetensors (recommended).
|
||||
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
|
||||
Whether to remove the checkpoint directory after merging.
|
||||
|
||||
@@ -177,7 +158,7 @@ def merge_fsdp_weights(
|
||||
if state.is_main_process:
|
||||
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
|
||||
save_path = _distributed_checkpoint_to_merged_weights(
|
||||
checkpoint_dir_, output_path, safe_serialization
|
||||
checkpoint_dir_, output_path
|
||||
)
|
||||
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
|
||||
if remove_checkpoint_dir:
|
||||
@@ -210,7 +191,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
merge_fsdp_weights(
|
||||
checkpoint_dir=str(fsdp_dir),
|
||||
output_path=output_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
state = PartialState()
|
||||
state.wait_for_everyone()
|
||||
|
||||
@@ -102,12 +102,10 @@ def do_quantize(
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
|
||||
model.save_pretrained(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
progressbar=True,
|
||||
)
|
||||
tokenizer.save_pretrained(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
progressbar=True,
|
||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||
)
|
||||
@@ -121,7 +119,7 @@ def do_quantize(
|
||||
hub_model_id.rstrip("-")
|
||||
+ f"-{quantization_config_to_str[type(quantization_config)]}"
|
||||
)
|
||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||
model.push_to_hub(hub_model_id)
|
||||
tokenizer.push_to_hub(hub_model_id)
|
||||
if processor:
|
||||
processor.push_to_hub(hub_model_id)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType
|
||||
from types import NoneType, UnionType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
@@ -20,7 +20,8 @@ def _strip_optional_type(field_type: type | str | None):
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
|
||||
if is_union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
@@ -87,10 +88,70 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_pydantic_model(field_type: type) -> bool:
|
||||
"""Check if a type is a Pydantic BaseModel subclass."""
|
||||
try:
|
||||
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_field_description(field) -> str | None:
|
||||
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
|
||||
if field.description:
|
||||
return field.description
|
||||
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
|
||||
return field.json_schema_extra.get("description")
|
||||
return None
|
||||
|
||||
|
||||
def _add_nested_model_options(
|
||||
function: Callable, parent_name: str, model_class: Type[BaseModel]
|
||||
) -> Callable:
|
||||
"""
|
||||
Add Click options for all fields of a nested Pydantic model using dot-notation.
|
||||
|
||||
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
|
||||
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
|
||||
|
||||
Args:
|
||||
function: Click command function to add options to.
|
||||
parent_name: Parent field name (e.g., "trl").
|
||||
model_class: Nested Pydantic model class.
|
||||
|
||||
Returns:
|
||||
Function with added Click options.
|
||||
"""
|
||||
for sub_name, sub_field in reversed(model_class.model_fields.items()):
|
||||
sub_type = _strip_optional_type(sub_field.annotation)
|
||||
# Use dot notation: --parent.sub_field
|
||||
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
|
||||
# The kwarg name uses double-underscore as separator
|
||||
param_name = f"{parent_name}__{sub_name}"
|
||||
description = _get_field_description(sub_field)
|
||||
|
||||
if sub_type is bool:
|
||||
option_name = f"--{cli_name}/--no-{cli_name}"
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, help=description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{cli_name}"
|
||||
click_type = {str: str, int: int, float: float}.get(sub_type)
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, type=click_type, help=description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
|
||||
generated for each sub-field (e.g., ``--trl.beta=0.1``).
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
@@ -103,6 +164,11 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
# Handle nested Pydantic models with dot-notation options
|
||||
if _is_pydantic_model(field_type):
|
||||
function = _add_nested_model_options(function, name, field_type)
|
||||
continue
|
||||
|
||||
if field_type is bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
|
||||
@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
"afmoe": "AfmoeMoE",
|
||||
"glm4_moe": "Glm4MoeDecoderLayer",
|
||||
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
|
||||
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
|
||||
}
|
||||
|
||||
@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
def _configure_warmup_and_logging(
|
||||
self, total_num_steps: int, training_args_kwargs: dict
|
||||
):
|
||||
warmup_steps = 0
|
||||
warmup_steps: int | float = 0
|
||||
warmup_ratio = 0.0
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
@@ -230,6 +230,10 @@ class TrainerBuilderBase(abc.ABC):
|
||||
else:
|
||||
warmup_ratio = 0.03
|
||||
|
||||
# transformers v5
|
||||
if warmup_ratio > 0.0 and warmup_steps == 0:
|
||||
warmup_steps = warmup_ratio
|
||||
|
||||
if warmup_steps == 1:
|
||||
warmup_steps = 2
|
||||
|
||||
@@ -242,7 +246,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_args_kwargs["warmup_ratio"] = warmup_ratio
|
||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||
|
||||
def _configure_precision_settings(self, training_args_kwargs: dict):
|
||||
@@ -406,6 +409,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.hub_strategy:
|
||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.hub_revision:
|
||||
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
|
||||
|
||||
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
|
||||
# save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
@@ -530,9 +536,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"loraplus_lr_ratio",
|
||||
"loraplus_lr_embedding",
|
||||
"output_dir",
|
||||
"save_safetensors",
|
||||
"save_only_model",
|
||||
"include_tokens_per_second",
|
||||
"weight_decay",
|
||||
"seed",
|
||||
"dion_momentum",
|
||||
@@ -545,6 +549,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
arg_map = {
|
||||
"dion_learning_rate": "dion_lr",
|
||||
"include_num_input_tokens_seen": "include_tokens_per_second",
|
||||
}
|
||||
for kwarg, cfg_arg in arg_map.items():
|
||||
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
|
||||
|
||||
@@ -122,6 +122,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
ColabCallback = colab_inference_post_train_callback(trainer)
|
||||
callbacks.append(ColabCallback(self.cfg))
|
||||
|
||||
if getattr(self.cfg, "generate_samples", False):
|
||||
from axolotl.utils.callbacks.generation import SFTGenerationCallback
|
||||
|
||||
callbacks.append(SFTGenerationCallback(trainer))
|
||||
LOG.info("SFT sample generation enabled")
|
||||
|
||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||
return callbacks
|
||||
|
||||
@@ -246,7 +252,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
ddp_find_unused_parameters
|
||||
)
|
||||
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
if self.cfg.group_by_length:
|
||||
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
|
||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||
@@ -373,6 +380,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
|
||||
if self.cfg.use_eaft:
|
||||
from functools import partial
|
||||
|
||||
from axolotl.monkeypatch.loss.eaft import eaft_loss
|
||||
|
||||
configured_eaft_loss = partial(
|
||||
eaft_loss,
|
||||
alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
|
||||
k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
|
||||
)
|
||||
trainer_kwargs["compute_loss_func"] = configured_eaft_loss
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
@@ -437,7 +456,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
or self.cfg.micro_batch_size > 1
|
||||
):
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
||||
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
|
||||
self.cfg.micro_batch_size == 1 and is_eval is False
|
||||
):
|
||||
return None
|
||||
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
|
||||
@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
|
||||
)
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
@@ -52,12 +51,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_cls = None
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
@@ -134,19 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
# Handle when max_prompt_length == max_length from defaults
|
||||
# CPOTrainer requires strictly less than
|
||||
if (
|
||||
training_args_kwargs["max_prompt_length"]
|
||||
== training_args_kwargs["max_length"]
|
||||
):
|
||||
training_args_kwargs["max_prompt_length"] -= 1
|
||||
blocklist_args_kwargs.append("max_prompt_length")
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
|
||||
blocklist_args_kwargs.append("max_prompt_length")
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
|
||||
blocklist_args_kwargs.append("max_prompt_length")
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
self.cfg.kto_desirable_weight or 1.0
|
||||
@@ -155,10 +153,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
if self.cfg.rl is RLType.GDPO:
|
||||
training_args_kwargs.setdefault(
|
||||
"multi_objective_aggregation", "normalize_then_sum"
|
||||
)
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
|
||||
@@ -25,7 +25,7 @@ from torch.utils.data import (
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -660,11 +660,10 @@ class AxolotlTrainer(
|
||||
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
if (
|
||||
hasattr(self.state, "total_tokens")
|
||||
and self.state.total_tokens is not None
|
||||
):
|
||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
||||
if "total" in self.state.tokens:
|
||||
logs["tokens/total"] = int(self.state.tokens["total"].item())
|
||||
if "trainable" in self.state.tokens:
|
||||
logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
@@ -720,6 +719,20 @@ class AxolotlTrainer(
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
# fix for Context Parallel save: CP eval invalidates tensor storage
|
||||
# pointers, so clone to CPU to get fresh valid storage for safetensors
|
||||
if (
|
||||
state_dict is not None
|
||||
and self.axolotl_cfg
|
||||
and self.axolotl_cfg.context_parallel_size
|
||||
and self.axolotl_cfg.context_parallel_size > 1
|
||||
):
|
||||
state_dict = {
|
||||
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
@@ -730,6 +743,7 @@ class AxolotlTrainer(
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
@@ -739,43 +753,35 @@ class AxolotlTrainer(
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
||||
)
|
||||
if self.args.save_safetensors:
|
||||
safetensors.torch.save_file(
|
||||
state_dict,
|
||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
safetensors.torch.save_file(
|
||||
state_dict,
|
||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
elif (
|
||||
self.data_collator is not None
|
||||
and hasattr(self.data_collator, "tokenizer")
|
||||
and self.data_collator.tokenizer is not None
|
||||
):
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
save_jinja_files = True
|
||||
if self.axolotl_cfg:
|
||||
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
|
||||
self.data_collator.tokenizer.save_pretrained(
|
||||
output_dir, save_jinja_files=save_jinja_files
|
||||
)
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
elif (
|
||||
self.data_collator is not None
|
||||
and hasattr(self.data_collator, "tokenizer")
|
||||
and self.data_collator.tokenizer is not None
|
||||
):
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
|
||||
def tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
max_prompt_length: int | None = None,
|
||||
max_completion_length: int | None = None,
|
||||
add_special_tokens: bool = True,
|
||||
is_chat: bool = False,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
max_prompt_length=max_prompt_length,
|
||||
max_completion_length=max_completion_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
is_chat=is_chat,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
|
||||
@@ -126,8 +126,10 @@ class GRPOStrategy:
|
||||
if trl.use_liger_loss is not None:
|
||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||
|
||||
if trl.rollout_func:
|
||||
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
|
||||
if trl.multi_objective_aggregation is not None:
|
||||
grpo_args_kwargs["multi_objective_aggregation"] = (
|
||||
trl.multi_objective_aggregation
|
||||
)
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@@ -149,6 +151,8 @@ class GRPOStrategy:
|
||||
trainer_kwargs["reward_processing_classes"] = (
|
||||
cfg.trl.reward_processing_classes
|
||||
)
|
||||
if cfg.trl and cfg.trl.rollout_func:
|
||||
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
|
||||
|
||||
return trainer_kwargs
|
||||
|
||||
@@ -159,7 +163,12 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls) -> list[str]:
|
||||
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
|
||||
return [
|
||||
"dataset_num_proc",
|
||||
"max_length",
|
||||
"include_tokens_per_second",
|
||||
"max_prompt_length",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||
|
||||
@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
|
||||
) -> LRScheduler:
|
||||
"""
|
||||
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
@@ -45,6 +45,13 @@ class SchedulerMixin(Trainer):
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
if optimizer is None:
|
||||
if self.optimizer is None:
|
||||
raise ValueError(
|
||||
"Optimizer must be set before calling create_scheduler or passed as an argument."
|
||||
)
|
||||
optimizer = self.optimizer
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore
|
||||
# fmt: on
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""Module for TRL RL trainers"""
|
||||
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
from trl import RewardTrainer
|
||||
from trl.experimental.cpo import CPOTrainer
|
||||
from trl.experimental.kto import KTOTrainer
|
||||
from trl.experimental.orpo import ORPOTrainer
|
||||
from trl.experimental.prm import PRMTrainer
|
||||
|
||||
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
|
||||
@@ -8,7 +8,11 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
from trl import RewardConfig
|
||||
from trl.experimental.cpo import CPOConfig
|
||||
from trl.experimental.kto import KTOConfig
|
||||
from trl.experimental.orpo import ORPOConfig
|
||||
from trl.experimental.prm import PRMConfig
|
||||
|
||||
from axolotl.integrations.config import merge_training_args
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -31,11 +31,13 @@ plugins:
|
||||
|
||||
## Supported Models
|
||||
|
||||
- afmoe
|
||||
- apertus
|
||||
- arcee
|
||||
- cohere
|
||||
- cohere2
|
||||
- deepseek_v3
|
||||
- exaone4
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3
|
||||
@@ -45,13 +47,17 @@ plugins:
|
||||
- glm
|
||||
- glm4
|
||||
- glm4_moe
|
||||
- glm4_moe_lite
|
||||
- glm46v
|
||||
- glm4v
|
||||
- glm4v_moe
|
||||
- glm_image
|
||||
- glm_moe_dsa
|
||||
- gpt_oss
|
||||
- granite
|
||||
- granitemoe
|
||||
- granitemoeshared
|
||||
- granitemoehybrid
|
||||
- granitemoeshared
|
||||
- hunyuan_v1_dense
|
||||
- hunyuan_v1_moe
|
||||
- internvl
|
||||
@@ -72,20 +78,26 @@ plugins:
|
||||
- olmo
|
||||
- olmo2
|
||||
- olmo3
|
||||
- olmoe
|
||||
- phi
|
||||
- phi3
|
||||
- phi4_multimodal
|
||||
- qwen2
|
||||
- qwen2_vl
|
||||
- qwen2_moe
|
||||
- qwen2_5_vl
|
||||
- qwen2_moe
|
||||
- qwen2_vl
|
||||
- qwen3
|
||||
- qwen3_5
|
||||
- qwen3_5_text
|
||||
- qwen3_5_moe
|
||||
- qwen3_5_moe_text
|
||||
- qwen3_moe
|
||||
- qwen3_next
|
||||
- qwen3_vl
|
||||
- qwen3_vl_moe
|
||||
- qwen3_next
|
||||
- smollm3
|
||||
- seed_oss
|
||||
- smollm3
|
||||
- step3p5
|
||||
- voxtral
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
|
||||
)
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
|
||||
def patch_llama_like(
|
||||
self,
|
||||
model_type: str,
|
||||
model_type_to_patch: str,
|
||||
) -> None:
|
||||
"""
|
||||
Generic patch for model architectures with causal lm similar to llama
|
||||
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||
|
||||
def patch_generic(
|
||||
maybe_model, patch_options, model_type: str, remote_model_id: str | None
|
||||
maybe_model,
|
||||
patch_options,
|
||||
remote_model_id: str | None,
|
||||
model_type: str,
|
||||
):
|
||||
import cut_cross_entropy.transformers.llama
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
if model_type not in PATCH_FNS:
|
||||
if model_type_to_patch not in PATCH_FNS:
|
||||
LOG.warning_once(
|
||||
"Setting up generic cce patch for model type: %s", model_type
|
||||
"Setting up generic cce patch for model type: %s", model_type_to_patch
|
||||
)
|
||||
LOG.warning_once(
|
||||
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
|
||||
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
|
||||
)
|
||||
PATCH_FNS[model_type_to_patch] = partial(
|
||||
patch_generic, model_type=model_type_to_patch
|
||||
)
|
||||
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
|
||||
|
||||
46
src/axolotl/integrations/kernels/README.md
Normal file
46
src/axolotl/integrations/kernels/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Kernels Integration
|
||||
|
||||
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
|
||||
|
||||
```python
|
||||
class ExpertsInterface(GeneralInterface):
|
||||
_global_mapping = {
|
||||
"batched_mm": batched_mm_experts_forward,
|
||||
"grouped_mm": grouped_mm_experts_forward,
|
||||
}
|
||||
```
|
||||
|
||||
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
|
||||
|
||||
## Usage
|
||||
|
||||
Add the following to your axolotl YAML config:
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.kernels.KernelsPlugin
|
||||
|
||||
use_kernels: true
|
||||
use_scattermoe: true
|
||||
```
|
||||
|
||||
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
|
||||
|
||||
## How It Works
|
||||
|
||||
The `KernelsPlugin` runs before model loading and:
|
||||
|
||||
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
|
||||
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
|
||||
|
||||
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
|
||||
|
||||
## Limitations
|
||||
|
||||
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
|
||||
|
||||
ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm.
|
||||
|
||||
## Note on MegaBlocks
|
||||
|
||||
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
|
||||
7
src/axolotl/integrations/kernels/__init__.py
Normal file
7
src/axolotl/integrations/kernels/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .args import KernelsArgs
|
||||
from .plugin import KernelsPlugin
|
||||
|
||||
__all__ = [
|
||||
"KernelsArgs",
|
||||
"KernelsPlugin",
|
||||
]
|
||||
48
src/axolotl/integrations/kernels/args.py
Normal file
48
src/axolotl/integrations/kernels/args.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class KernelsArgs(BaseModel):
|
||||
use_scattermoe: bool | None = True
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_use_kernels(cls, data):
|
||||
if data.get("use_kernels") is not True:
|
||||
LOG.warning(
|
||||
"`use_kernels` must be set to True to use this. Automatically setting it to True."
|
||||
)
|
||||
data["use_kernels"] = True
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_experts_implementation(cls, data):
|
||||
experts_implementation = data.get("experts_implementation")
|
||||
if experts_implementation is None:
|
||||
# transformers may default to batched_mm when unset
|
||||
data["experts_implementation"] = "eager"
|
||||
elif experts_implementation != "eager":
|
||||
LOG.warning(
|
||||
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
|
||||
)
|
||||
data["experts_implementation"] = "eager"
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def disable_mlp_kernel_scattermoe(cls, data):
|
||||
if data.get("use_scattermoe") is True:
|
||||
if data.get("lora_mlp_kernel") is True:
|
||||
LOG.warning(
|
||||
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
|
||||
)
|
||||
data["lora_mlp_kernel"] = False
|
||||
data["mlp_kernel"] = False
|
||||
|
||||
return data
|
||||
0
src/axolotl/integrations/kernels/libs/__init__.py
Normal file
0
src/axolotl/integrations/kernels/libs/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
from . import layers
|
||||
from .lora_ops import ParallelExperts
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||
|
||||
__all__ = [
|
||||
"layers",
|
||||
"ParallelExperts",
|
||||
"flatten_sort_count",
|
||||
"parallel_linear",
|
||||
"ScatterMoELoRA",
|
||||
"parallel_linear_lora",
|
||||
"lora_ops",
|
||||
]
|
||||
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
#
|
||||
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
from . import lora_ops, ops
|
||||
|
||||
__all__ = ["ops", "lora_ops"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,645 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
BLOCK_M = 128
|
||||
ALLOW_TF32 = True
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_expert_block(
|
||||
E_idx,
|
||||
E_mask,
|
||||
M_in_idx,
|
||||
N_block,
|
||||
N_mask,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk,
|
||||
stride_wn,
|
||||
K,
|
||||
acc,
|
||||
no_k_mask,
|
||||
BLOCK_K,
|
||||
allow_tf32=True,
|
||||
):
|
||||
K_block = tl.arange(0, BLOCK_K)
|
||||
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
|
||||
W_blk_ptrs = (
|
||||
W_ptr
|
||||
+ K_block[:, None] * stride_wk
|
||||
+ N_block[None, :] * stride_wn
|
||||
+ E_idx * stride_we
|
||||
)
|
||||
iters = tl.cdiv(K, BLOCK_K)
|
||||
|
||||
for K_block_id in range(iters):
|
||||
if no_k_mask:
|
||||
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
|
||||
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
|
||||
else:
|
||||
K_mask = (K_block_id * BLOCK_K + K_block) < K
|
||||
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
|
||||
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
|
||||
|
||||
X_blk_ptrs += BLOCK_K * stride_xk
|
||||
W_blk_ptrs += BLOCK_K * stride_wk
|
||||
acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
|
||||
return acc
|
||||
|
||||
|
||||
def _scatter2scatter_configs():
|
||||
return [
|
||||
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_scatter2scatter_configs(),
|
||||
key=["M", "N", "K"],
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
|
||||
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _scatter2scatter(
|
||||
X_ptr,
|
||||
stride_xm: tl.constexpr,
|
||||
stride_xk: tl.constexpr,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk: tl.constexpr,
|
||||
stride_wn: tl.constexpr,
|
||||
Y_ptr,
|
||||
stride_ym: tl.constexpr,
|
||||
stride_yn: tl.constexpr,
|
||||
B_ptr,
|
||||
stride_be: tl.constexpr,
|
||||
stride_bn: tl.constexpr,
|
||||
grouped_idx_ptr,
|
||||
expert_idxs_ptr,
|
||||
# block_start_idx_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
M,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
E: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
# OUT_M,
|
||||
allow_tf32: tl.constexpr,
|
||||
x_grouped: tl.constexpr,
|
||||
y_grouped: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr,
|
||||
NO_N_MASK: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
|
||||
M_block_id = pid // N_BLOCK_COUNT
|
||||
N_block_id = pid % N_BLOCK_COUNT
|
||||
|
||||
M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_block < N
|
||||
M_boundary_mask = M_block < (FAN_OUT * M)
|
||||
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
|
||||
|
||||
no_k_mask = K % BLOCK_K == 0
|
||||
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
E_first_idx = tl.min(E_idxs)
|
||||
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
|
||||
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
|
||||
for E_idx in range(E_first_idx, E_last_idx + 1):
|
||||
E_mask = E_idxs == E_idx
|
||||
E_M_idx = M_idx
|
||||
if x_grouped:
|
||||
M_in_idx = M_block
|
||||
else:
|
||||
M_in_idx = E_M_idx // FAN_OUT
|
||||
acc = _compute_expert_block(
|
||||
E_idx,
|
||||
E_mask,
|
||||
M_in_idx,
|
||||
N_block,
|
||||
N_mask,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk,
|
||||
stride_wn,
|
||||
K,
|
||||
acc,
|
||||
no_k_mask,
|
||||
BLOCK_K,
|
||||
allow_tf32=allow_tf32,
|
||||
)
|
||||
|
||||
if B_ptr is not None:
|
||||
B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
|
||||
acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
if y_grouped:
|
||||
M_out_idx = M_block
|
||||
else:
|
||||
M_out_idx = M_idx
|
||||
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
|
||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def scatter2scatter(
|
||||
X,
|
||||
W,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
k,
|
||||
b=None,
|
||||
x_grouped=False,
|
||||
y_grouped=False,
|
||||
out=None,
|
||||
):
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
# Pre-kernel setup
|
||||
y_dim = W.size(-1)
|
||||
L_scattered = sorted_expert_idxs.size(0)
|
||||
if out is None:
|
||||
output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
|
||||
else:
|
||||
assert out.size(0) == L_scattered and out.size(1) == y_dim
|
||||
output = out
|
||||
|
||||
scatter2scatter_compileable(
|
||||
output,
|
||||
W,
|
||||
X,
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
b,
|
||||
x_grouped,
|
||||
y_grouped,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
|
||||
def scatter2scatter_compileable(
|
||||
output: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
k: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
b: Optional[torch.Tensor],
|
||||
x_grouped: bool,
|
||||
y_grouped: bool,
|
||||
) -> None:
|
||||
def grid(META):
|
||||
grid_num = (
|
||||
triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
|
||||
* triton.cdiv(META["N"], META["BLOCK_N"]),
|
||||
)
|
||||
return grid_num
|
||||
|
||||
if b is None:
|
||||
b = None
|
||||
stride_be = stride_bn = 0
|
||||
else:
|
||||
stride_be, stride_bn = b.stride()
|
||||
|
||||
_scatter2scatter[grid](
|
||||
# X_ptr, stride_xm, stride_xk,
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
# W_ptr, stride_we, stride_wk, stride_wn,
|
||||
W,
|
||||
W.stride(0),
|
||||
W.stride(1),
|
||||
W.stride(2),
|
||||
# Y_ptr, stride_ym, stride_yn,
|
||||
output,
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
# B_ptr, stride_be, stride_bn
|
||||
b,
|
||||
stride_be,
|
||||
stride_bn,
|
||||
grouped_idx_ptr=sorted_scattered_idxs,
|
||||
expert_idxs_ptr=sorted_expert_idxs,
|
||||
# block_start_idx_ptr=padded_block_idxs,
|
||||
FAN_OUT=k,
|
||||
M=X.size(0),
|
||||
K=X.size(1),
|
||||
N=output.size(1),
|
||||
E=W.size(0),
|
||||
BLOCK_M=BLOCK_M,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=y_grouped,
|
||||
)
|
||||
|
||||
|
||||
def _config_XtY():
|
||||
return [
|
||||
triton.Config(
|
||||
{"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
|
||||
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
|
||||
DW = DWt.permute(0, 2, 1)
|
||||
if has_bias:
|
||||
Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
|
||||
else:
|
||||
Db = None
|
||||
groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
|
||||
return DW, Db
|
||||
|
||||
|
||||
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
|
||||
def groupXtY_compileable(
|
||||
E: int,
|
||||
DW: torch.Tensor,
|
||||
Db: Optional[torch.Tensor],
|
||||
DY: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
) -> None:
|
||||
def grid(META):
|
||||
grid = (
|
||||
E * triton.cdiv(META["K"], META["BLOCK_K"]),
|
||||
triton.cdiv(META["N"], META["BLOCK_N"]),
|
||||
)
|
||||
return grid
|
||||
|
||||
if Db is None:
|
||||
stride_dbe = 0
|
||||
stride_dbn = 0
|
||||
else:
|
||||
stride_dbe, stride_dbn = Db.stride()
|
||||
|
||||
_groupXtY[grid](
|
||||
# DY_ptr, stride_dym, stride_dyk,
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
# X_ptr, stride_xm, stride_xn,
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
|
||||
DW,
|
||||
DW.stride(0),
|
||||
DW.stride(1),
|
||||
DW.stride(2),
|
||||
# Db_ptr, stride_dwe, stride_dbn,
|
||||
Db,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
# expert_offsets_ptr,
|
||||
expert_offsets,
|
||||
# K: tl.constexpr, N: tl.constexpr,
|
||||
M=DY.size(0),
|
||||
N=DY.size(-1),
|
||||
K=X.size(-1),
|
||||
# ACC_TYPE: tl.constexpr,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_config_XtY(),
|
||||
key=["M", "N", "K"],
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
|
||||
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _groupXtY(
|
||||
DY_ptr,
|
||||
stride_dym,
|
||||
stride_dyk,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xn,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
expert_offsets_ptr,
|
||||
M,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr,
|
||||
NO_N_MASK: tl.constexpr,
|
||||
):
|
||||
pid0 = tl.program_id(axis=0)
|
||||
pid1 = tl.program_id(axis=1)
|
||||
num0 = tl.num_programs(0)
|
||||
num1 = tl.num_programs(1)
|
||||
# pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
|
||||
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
|
||||
|
||||
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
|
||||
E_idx = pid0 // K_BLOCK_COUNT
|
||||
K_block_id = pid0 % K_BLOCK_COUNT
|
||||
N_block_id = pid1
|
||||
|
||||
if E_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||
|
||||
if end_idx > start_idx:
|
||||
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
|
||||
|
||||
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
K_mask = K_block < K
|
||||
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
|
||||
|
||||
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_block < N
|
||||
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
M_idxs = M_block
|
||||
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
|
||||
dy_blk_ptrs = (
|
||||
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
|
||||
)
|
||||
if (Db_ptr is not None) and (K_block_id == 0):
|
||||
_xty_and_bias(
|
||||
E_idx,
|
||||
start_idx,
|
||||
end_idx,
|
||||
M_block,
|
||||
K_block,
|
||||
K_mask,
|
||||
N_block,
|
||||
N_mask,
|
||||
dy_blk_ptrs,
|
||||
stride_dym,
|
||||
xt_blk_ptrs,
|
||||
stride_xm,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
ACC_TYPE,
|
||||
allow_tf32,
|
||||
NO_K_MASK,
|
||||
NO_N_MASK,
|
||||
compute_bias=True,
|
||||
)
|
||||
else:
|
||||
_xty_and_bias(
|
||||
E_idx,
|
||||
start_idx,
|
||||
end_idx,
|
||||
M_block,
|
||||
K_block,
|
||||
K_mask,
|
||||
N_block,
|
||||
N_mask,
|
||||
dy_blk_ptrs,
|
||||
stride_dym,
|
||||
xt_blk_ptrs,
|
||||
stride_xm,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
ACC_TYPE,
|
||||
allow_tf32,
|
||||
NO_K_MASK,
|
||||
NO_N_MASK,
|
||||
compute_bias=False,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _xty_and_bias(
|
||||
E_idx,
|
||||
start_idx,
|
||||
end_idx,
|
||||
M_block,
|
||||
K_block,
|
||||
K_mask,
|
||||
N_block,
|
||||
N_mask,
|
||||
dy_blk_ptrs,
|
||||
stride_dym,
|
||||
xt_blk_ptrs,
|
||||
stride_xm,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
ACC_TYPE,
|
||||
allow_tf32,
|
||||
NO_K_MASK,
|
||||
NO_N_MASK,
|
||||
compute_bias: tl.constexpr,
|
||||
):
|
||||
if compute_bias:
|
||||
db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
|
||||
else:
|
||||
db_acc = None
|
||||
|
||||
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
|
||||
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
|
||||
for i in range(0, iters):
|
||||
M_mask = (i * BLOCK_M + M_block) < end_idx
|
||||
if NO_K_MASK:
|
||||
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
|
||||
else:
|
||||
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
|
||||
if NO_N_MASK:
|
||||
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
|
||||
else:
|
||||
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
|
||||
|
||||
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
|
||||
|
||||
xt_blk_ptrs += BLOCK_M * stride_xm
|
||||
dy_blk_ptrs += BLOCK_M * stride_dym
|
||||
|
||||
if compute_bias:
|
||||
db_acc += tl.sum(dy, axis=0)
|
||||
|
||||
DW_blk_ptrs = (
|
||||
DW_ptr
|
||||
+ E_idx * stride_dwe
|
||||
+ K_block[:, None] * stride_dwk
|
||||
+ N_block[None, :] * stride_dwn
|
||||
)
|
||||
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
|
||||
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
|
||||
if compute_bias:
|
||||
Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
|
||||
tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
|
||||
|
||||
|
||||
def _config_grouping():
|
||||
return [
|
||||
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
|
||||
]
|
||||
|
||||
|
||||
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
|
||||
N = sorted_expert_idxs.size(0)
|
||||
K = A.size(1)
|
||||
assert A.size(0) * fan_out == N
|
||||
if out is not None:
|
||||
Y = out
|
||||
else:
|
||||
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
|
||||
group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
|
||||
return Y
|
||||
|
||||
|
||||
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
|
||||
def group_compileable(
|
||||
A: torch.Tensor,
|
||||
K: int,
|
||||
N: int,
|
||||
Y: torch.Tensor,
|
||||
coeff: Optional[torch.Tensor],
|
||||
has_coeff: bool,
|
||||
fan_out: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
) -> None:
|
||||
def grid(META):
|
||||
grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),)
|
||||
return grid_num
|
||||
|
||||
_group[grid](
|
||||
# A_ptr, stride_an, stride_ai,
|
||||
A,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
has_coeff,
|
||||
coeff,
|
||||
fan_out,
|
||||
# Y_ptr, stride_yn, stride_yk,
|
||||
Y,
|
||||
Y.stride(0),
|
||||
Y.stride(1),
|
||||
# grouped_idx_ptr,
|
||||
sorted_expert_idxs,
|
||||
# N: tl.constexpr, K: tl.constexpr,
|
||||
N,
|
||||
K,
|
||||
)
|
||||
|
||||
|
||||
@triton.autotune(configs=_config_grouping(), key=["K"])
|
||||
@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0})
|
||||
@triton.jit
|
||||
def _group(
|
||||
src_ptr,
|
||||
stride_sn,
|
||||
stride_sk,
|
||||
has_coeff: tl.constexpr,
|
||||
coeff_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
tgt_ptr,
|
||||
stride_tn,
|
||||
stride_ti,
|
||||
grouped_idx_ptr,
|
||||
N,
|
||||
K: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
N_block_id = pid
|
||||
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_blk < N
|
||||
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
|
||||
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
|
||||
|
||||
K_blk = tl.arange(0, BLOCK_K)
|
||||
src_blk_ptrs = (
|
||||
src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
|
||||
)
|
||||
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
|
||||
|
||||
if has_coeff:
|
||||
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
|
||||
|
||||
iters = tl.cdiv(K, BLOCK_K)
|
||||
for i in range(0, iters):
|
||||
if NO_K_MASK or i < iters - 1:
|
||||
block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
|
||||
if has_coeff:
|
||||
block *= c
|
||||
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
|
||||
|
||||
else:
|
||||
K_mask = (i * BLOCK_K + K_blk) < K
|
||||
mask = N_mask[:, None] & K_mask[None, :]
|
||||
block = tl.load(src_blk_ptrs, mask=mask)
|
||||
if has_coeff:
|
||||
block *= c
|
||||
tl.store(tgt_blk_ptrs, block, mask=mask)
|
||||
src_blk_ptrs += BLOCK_K * stride_sk
|
||||
tgt_blk_ptrs += BLOCK_K * stride_ti
|
||||
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _single2scatter(
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk,
|
||||
stride_wn,
|
||||
Y_ptr,
|
||||
stride_ym,
|
||||
stride_yn,
|
||||
expert_idxs_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
E: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
):
|
||||
pid0 = tl.program_id(axis=0)
|
||||
pid1 = tl.program_id(axis=1)
|
||||
|
||||
N_block_id = pid0
|
||||
if FAN_OUT == 1:
|
||||
in_idx = pid1
|
||||
else:
|
||||
in_idx = 0
|
||||
out_idx = pid1
|
||||
|
||||
K_block = tl.arange(0, BLOCK_K)
|
||||
N_block = tl.max_contiguous(
|
||||
tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
|
||||
BLOCK_N,
|
||||
)
|
||||
E_idx = tl.load(expert_idxs_ptr + pid1)
|
||||
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
|
||||
W_blk_ptrs = (
|
||||
W_ptr
|
||||
+ E_idx * stride_we
|
||||
+ K_block[:, None] * stride_wk
|
||||
+ N_block[None, :] * stride_wn
|
||||
)
|
||||
N_mask = N_block < N
|
||||
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
|
||||
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
K_mask = K_block < K
|
||||
x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)
|
||||
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)
|
||||
acc += tl.sum(x * w, axis=0)[None, :]
|
||||
X_blk_ptrs += BLOCK_K * stride_xk
|
||||
W_blk_ptrs += BLOCK_K * stride_wk
|
||||
K_block += BLOCK_K
|
||||
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
|
||||
tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])
|
||||
|
||||
|
||||
def single2scatter(X, W, expert_idxs):
|
||||
E, xdim, ydim = W.size()
|
||||
k = expert_idxs.size(1)
|
||||
assert X.size(0) == k or X.size(0) == 1
|
||||
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 128
|
||||
grid = triton.cdiv(ydim, BLOCK_N), k
|
||||
_single2scatter[grid](
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
W,
|
||||
W.stride(0),
|
||||
W.stride(1),
|
||||
W.stride(2),
|
||||
Y,
|
||||
Y.stride(0),
|
||||
Y.stride(1),
|
||||
expert_idxs,
|
||||
FAN_OUT=Y.size(0) // X.size(0),
|
||||
K=xdim,
|
||||
N=ydim,
|
||||
E=E,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_K=BLOCK_K,
|
||||
ACC_TYPE=tl.float32,
|
||||
)
|
||||
return Y
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user