Compare commits
37 Commits
coderabbit
...
dft
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a0115493d | ||
|
|
7a4f33802d | ||
|
|
170dca9bb9 | ||
|
|
d282f32481 | ||
|
|
6331e4a130 | ||
|
|
1410e4474e | ||
|
|
dc77b5bf42 | ||
|
|
359b7ad85e | ||
|
|
258ce8d4fa | ||
|
|
3e0bbd33ec | ||
|
|
4ae6f766ad | ||
|
|
e7f0d4ba5b | ||
|
|
7bf6f70e96 | ||
|
|
8aab807e67 | ||
|
|
ee59e4de97 | ||
|
|
4e61b8aa23 | ||
|
|
b26ba3a5cb | ||
|
|
afe18ace35 | ||
|
|
2b199f9915 | ||
|
|
e73dab6df9 | ||
|
|
f45a97a9ff | ||
|
|
11c0b5b256 | ||
|
|
66a3de3629 | ||
|
|
a6080df73c | ||
|
|
4f5e8a328a | ||
|
|
418933f0d1 | ||
|
|
372f664c63 | ||
|
|
97f1b1758d | ||
|
|
f2155eaf79 | ||
|
|
92ee4256f7 | ||
|
|
efeb5a4e41 | ||
|
|
faaff6c792 | ||
|
|
43cef27458 | ||
|
|
07c41a6c2a | ||
|
|
bbd3486f57 | ||
|
|
3750d7dd64 | ||
|
|
2197b0bf89 |
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -15,6 +15,11 @@
|
|||||||
<!--- Include details of your testing environment, tests ran to see how -->
|
<!--- Include details of your testing environment, tests ran to see how -->
|
||||||
<!--- your change affects other areas of the code, etc. -->
|
<!--- your change affects other areas of the code, etc. -->
|
||||||
|
|
||||||
|
## AI Usage Disclaimer
|
||||||
|
|
||||||
|
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
|
||||||
|
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
|
||||||
|
|
||||||
## Screenshots (if appropriate)
|
## Screenshots (if appropriate)
|
||||||
|
|
||||||
## Types of changes
|
## Types of changes
|
||||||
|
|||||||
65
.github/workflows/base.yml
vendored
65
.github/workflows/base.yml
vendored
@@ -21,31 +21,12 @@ jobs:
|
|||||||
timeout-minutes: 480
|
timeout-minutes: 480
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
|
env:
|
||||||
|
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "126"
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-base"
|
|
||||||
- cuda: "126"
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-base"
|
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-base"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -53,6 +34,15 @@ jobs:
|
|||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
|
platforms: "linux/amd64"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -60,6 +50,7 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "130"
|
- cuda: "130"
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -67,6 +58,7 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
torch_cuda_arch_list: "9.0+PTX"
|
torch_cuda_arch_list: "9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
# - cuda: "128"
|
# - cuda: "128"
|
||||||
# cuda_version: 12.8.1
|
# cuda_version: 12.8.1
|
||||||
# cudnn_version: ""
|
# cudnn_version: ""
|
||||||
@@ -93,6 +85,7 @@ jobs:
|
|||||||
axolotlai/axolotl-base
|
axolotlai/axolotl-base
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
|
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
@@ -103,6 +96,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./docker/${{ matrix.dockerfile }}
|
file: ./docker/${{ matrix.dockerfile }}
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
@@ -117,24 +111,12 @@ jobs:
|
|||||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||||
timeout-minutes: 480
|
timeout-minutes: 480
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
|
env:
|
||||||
|
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "126"
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-uv-base"
|
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-uv-base"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -142,6 +124,7 @@ jobs:
|
|||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -149,6 +132,15 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "130"
|
- cuda: "130"
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -156,6 +148,7 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
torch_cuda_arch_list: "9.0+PTX"
|
torch_cuda_arch_list: "9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -167,6 +160,7 @@ jobs:
|
|||||||
axolotlai/axolotl-base-uv
|
axolotlai/axolotl-base-uv
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
|
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
@@ -177,6 +171,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./docker/${{ matrix.dockerfile }}
|
file: ./docker/${{ matrix.dockerfile }}
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|||||||
3
.github/workflows/docs.yml
vendored
3
.github/workflows/docs.yml
vendored
@@ -12,6 +12,9 @@ jobs:
|
|||||||
build-deploy:
|
build-deploy:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
- name: cleanup node
|
||||||
|
run: |
|
||||||
|
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
- name: Set up Quarto
|
- name: Set up Quarto
|
||||||
|
|||||||
81
.github/workflows/main.yml
vendored
81
.github/workflows/main.yml
vendored
@@ -15,37 +15,31 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras: vllm
|
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
platforms: "linux/amd64"
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.0
|
pytorch: 2.9.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
is_latest: true
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -71,6 +65,7 @@ jobs:
|
|||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
@@ -92,43 +87,31 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras:
|
|
||||||
is_latest:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras: vllm
|
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
platforms: "linux/amd64"
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.0
|
pytorch: 2.9.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -153,6 +136,7 @@ jobs:
|
|||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
|
platforms: ${{ matrix.platforms }}
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
@@ -170,22 +154,16 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras:
|
|
||||||
is_latest:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras: vllm
|
|
||||||
is_latest: true
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest:
|
is_latest:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
@@ -212,6 +190,7 @@ jobs:
|
|||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
|
platforms: linux/amd64,linux/arm64
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
|
|||||||
23
.github/workflows/multi-gpu-e2e.yml
vendored
23
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -19,6 +19,9 @@ concurrency:
|
|||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||||
|
|
||||||
|
env:
|
||||||
|
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-axolotl-multigpu:
|
test-axolotl-multigpu:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||||
@@ -26,13 +29,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras: vllm
|
|
||||||
num_gpus: 2
|
|
||||||
nightly_build: "true"
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -43,7 +39,14 @@ jobs:
|
|||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.0
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras: fbgemm-gpu
|
||||||
|
num_gpus: 2
|
||||||
|
nightly_build: "true"
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
axolotl_extras: fbgemm-gpu
|
axolotl_extras: fbgemm-gpu
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
@@ -59,7 +62,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==1.0.2 jinja2
|
pip install modal==1.3.0.post1 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
@@ -72,4 +75,4 @@ jobs:
|
|||||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.multigpu
|
modal run -m cicd.multigpu
|
||||||
|
|||||||
20
.github/workflows/nightlies.yml
vendored
20
.github/workflows/nightlies.yml
vendored
@@ -12,16 +12,16 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -64,16 +64,16 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
5
.github/workflows/preview-docs.yml
vendored
5
.github/workflows/preview-docs.yml
vendored
@@ -11,6 +11,7 @@ on:
|
|||||||
- '_quarto.yml'
|
- '_quarto.yml'
|
||||||
- docs/scripts/generate_config_docs.py
|
- docs/scripts/generate_config_docs.py
|
||||||
- src/axolotl/utils/schemas/**.py
|
- src/axolotl/utils/schemas/**.py
|
||||||
|
- .github/workflows/preview-docs.yml
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
checks: write
|
checks: write
|
||||||
@@ -27,6 +28,10 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
if: ${{ !github.event.pull_request.draft }}
|
||||||
steps:
|
steps:
|
||||||
|
- name: cleanup node
|
||||||
|
run: |
|
||||||
|
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||||
|
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
20
.github/workflows/tests-nightly.yml
vendored
20
.github/workflows/tests-nightly.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
max-parallel: 2
|
max-parallel: 2
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.7.1", "2.8.0"]
|
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -99,17 +99,17 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 128
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.8.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
@@ -123,7 +123,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==1.0.2 jinja2
|
pip install modal==1.3.0.post1 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
@@ -148,10 +148,10 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 128
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
@@ -165,7 +165,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==1.0.2 jinja2
|
pip install modal==1.3.0.post1 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
|||||||
75
.github/workflows/tests.yml
vendored
75
.github/workflows/tests.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -66,12 +66,13 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
# - name: Restore Cache from S3
|
- name: Restore Cache from S3
|
||||||
# id: hf-cache-restore-s3
|
id: hf-cache-restore-s3
|
||||||
# run: |
|
run: |
|
||||||
# mkdir -p ~/.cache/huggingface/hub
|
mkdir -p ~/.cache/huggingface/hub
|
||||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||||
#
|
ls -ltr ~/.cache/huggingface/hub/
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
@@ -111,10 +112,13 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
|
|
||||||
|
- name: Show HF cache
|
||||||
|
run: hf cache scan
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
df -h
|
df -h
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||||
df -h
|
df -h
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
df -h
|
df -h
|
||||||
@@ -122,6 +126,9 @@ jobs:
|
|||||||
df -h
|
df -h
|
||||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
|
|
||||||
|
- name: Show HF cache
|
||||||
|
run: hf cache scan
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
@@ -138,7 +145,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -149,12 +156,13 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
# - name: Restore Cache from S3
|
- name: Restore Cache from S3
|
||||||
# id: hf-cache-restore-s3
|
id: hf-cache-restore-s3
|
||||||
# run: |
|
run: |
|
||||||
# mkdir -p ~/.cache/huggingface/hub
|
mkdir -p ~/.cache/huggingface/hub
|
||||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||||
#
|
ls -ltr ~/.cache/huggingface/hub/
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
@@ -196,10 +204,13 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
|
- name: Show HF cache
|
||||||
|
run: hf cache scan
|
||||||
|
|
||||||
gate-skip-e2e:
|
gate-skip-e2e:
|
||||||
needs: [pre-commit, pytest, pytest-sdist]
|
needs: [pre-commit, pytest, pytest-sdist]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -260,7 +271,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==1.0.2 jinja2
|
pip install modal==1.3.0.post1 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
@@ -292,18 +303,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
# - cuda: 128
|
|
||||||
# cuda_version: 12.8.1
|
|
||||||
# python_version: "3.11"
|
|
||||||
# pytorch: 2.7.1
|
|
||||||
# num_gpus: 1
|
|
||||||
# axolotl_extras:
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -314,7 +313,13 @@ jobs:
|
|||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.0
|
pytorch: 2.9.1
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 130
|
||||||
|
cuda_version: 13.0.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
@@ -327,7 +332,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==1.0.2 jinja2
|
pip install modal==1.3.0.post1 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
@@ -354,10 +359,10 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 128
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
@@ -370,7 +375,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==1.0.2 jinja2
|
pip install modal==1.3.0.post1 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ repos:
|
|||||||
- id: no-commit-to-branch
|
- id: no-commit-to-branch
|
||||||
args: ['--branch', 'main']
|
args: ['--branch', 'main']
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.14.7
|
rev: v0.14.10
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.19.0
|
rev: v1.19.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
|
|||||||
14
README.md
14
README.md
@@ -29,15 +29,15 @@
|
|||||||
|
|
||||||
## 🎉 Latest Updates
|
## 🎉 Latest Updates
|
||||||
|
|
||||||
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
|
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
|
||||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
|
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
|
||||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||||
- 2025/07:
|
- 2025/07:
|
||||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
|
||||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||||
|
|
||||||
@@ -46,8 +46,8 @@
|
|||||||
<summary>Expand older updates</summary>
|
<summary>Expand older updates</summary>
|
||||||
|
|
||||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
|
||||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||||
@@ -77,7 +77,7 @@ Features:
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python 3.11
|
- Python 3.11
|
||||||
- PyTorch ≥2.7.1
|
- PyTorch ≥2.8.0
|
||||||
|
|
||||||
### Google Colab
|
### Google Colab
|
||||||
|
|
||||||
|
|||||||
44
_quarto.yml
44
_quarto.yml
@@ -1,6 +1,8 @@
|
|||||||
project:
|
project:
|
||||||
type: website
|
type: website
|
||||||
pre-render: docs/scripts/generate_config_docs.py
|
pre-render:
|
||||||
|
- docs/scripts/generate_config_docs.py
|
||||||
|
- docs/scripts/generate_examples_docs.py
|
||||||
|
|
||||||
quartodoc:
|
quartodoc:
|
||||||
dir: docs/api
|
dir: docs/api
|
||||||
@@ -240,6 +242,46 @@ website:
|
|||||||
- docs/getting-started.qmd
|
- docs/getting-started.qmd
|
||||||
- docs/installation.qmd
|
- docs/installation.qmd
|
||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
|
- section: "Model Guides"
|
||||||
|
contents:
|
||||||
|
- docs/models/kimi-linear.qmd
|
||||||
|
- docs/models/plano.qmd
|
||||||
|
- docs/models/mimo.qmd
|
||||||
|
- docs/models/internvl3_5.qmd
|
||||||
|
- docs/models/olmo3.qmd
|
||||||
|
- docs/models/trinity.qmd
|
||||||
|
- docs/models/arcee.qmd
|
||||||
|
- docs/models/mistral.qmd
|
||||||
|
- section: "Ministral3"
|
||||||
|
contents:
|
||||||
|
- docs/models/ministral3.qmd
|
||||||
|
- docs/models/ministral3/think.qmd
|
||||||
|
- docs/models/ministral3/vision.qmd
|
||||||
|
- section: "Magistral"
|
||||||
|
contents:
|
||||||
|
- docs/models/magistral.qmd
|
||||||
|
- docs/models/magistral/think.qmd
|
||||||
|
- docs/models/magistral/vision.qmd
|
||||||
|
- docs/models/ministral.qmd
|
||||||
|
- docs/models/mistral-small.qmd
|
||||||
|
- docs/models/voxtral.qmd
|
||||||
|
- docs/models/devstral.qmd
|
||||||
|
- docs/models/llama-4.qmd
|
||||||
|
- docs/models/llama-2.qmd
|
||||||
|
- docs/models/qwen3-next.qmd
|
||||||
|
- docs/models/qwen3.qmd
|
||||||
|
- docs/models/gemma3n.qmd
|
||||||
|
- docs/models/apertus.qmd
|
||||||
|
- docs/models/gpt-oss.qmd
|
||||||
|
- docs/models/seed-oss.qmd
|
||||||
|
- docs/models/phi.qmd
|
||||||
|
- docs/models/smolvlm2.qmd
|
||||||
|
- docs/models/granite4.qmd
|
||||||
|
- docs/models/LiquidAI.qmd
|
||||||
|
- docs/models/hunyuan.qmd
|
||||||
|
- docs/models/jamba.qmd
|
||||||
|
- docs/models/orpheus.qmd
|
||||||
|
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/telemetry.qmd
|
- docs/telemetry.qmd
|
||||||
- docs/config-reference.qmd
|
- docs/config-reference.qmd
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||||
pytest -v --durations=10 -n2 \
|
pytest -v --durations=10 -n2 --maxfail=4 \
|
||||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ ARG AXOLOTL_EXTRAS=""
|
|||||||
ARG AXOLOTL_ARGS=""
|
ARG AXOLOTL_ARGS=""
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
|
ARG TARGETARCH
|
||||||
|
|
||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
@@ -20,13 +21,17 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
|||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||||
fi && \
|
fi && \
|
||||||
python scripts/unsloth_install.py | sh && \
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
|
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
|
else \
|
||||||
|
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
|
fi && \ python scripts/unsloth_install.py | sh && \
|
||||||
python scripts/cutcrossentropy_install.py | sh && \
|
python scripts/cutcrossentropy_install.py | sh && \
|
||||||
pip install pytest && \
|
pip install pytest && \
|
||||||
pip cache purge
|
pip cache purge
|
||||||
|
|||||||
@@ -2,14 +2,16 @@ ARG CUDA_VERSION="11.8.0"
|
|||||||
ARG CUDNN_VERSION="8"
|
ARG CUDNN_VERSION="8"
|
||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
|
ARG TARGETARCH
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
ARG PYTHON_VERSION="3.10"
|
ARG TARGETARCH
|
||||||
|
ARG PYTHON_VERSION="3.11"
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
ARG CUDA="118"
|
ARG CUDA="128"
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
@@ -22,11 +24,17 @@ RUN apt-get update \
|
|||||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||||
&& rm -rf /var/cache/apt/archives \
|
&& rm -rf /var/cache/apt/archives \
|
||||||
&& rm -rf /var/lib/apt/lists/* \
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
&& wget \
|
&& if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
MINICONDA_ARCH="x86_64"; \
|
||||||
|
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
|
MINICONDA_ARCH="aarch64"; \
|
||||||
|
else \
|
||||||
|
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||||
|
fi \
|
||||||
|
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
|
||||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||||
@@ -51,8 +59,34 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
|
RUN case "$PYTORCH_VERSION" in \
|
||||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
2.9.[0-9]*) \
|
||||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
if [ "$CUDA" = "128" ]; then \
|
||||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
fi
|
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
||||||
|
WHL_VERSION="v0.5.4"; \
|
||||||
|
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
|
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
||||||
|
WHL_VERSION="v0.6.4"; \
|
||||||
|
else \
|
||||||
|
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||||
|
fi; \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
||||||
|
pip3 install --no-cache-dir ${WHL_FILE}; \
|
||||||
|
rm ${WHL_FILE}; \
|
||||||
|
elif [ "$CUDA" = "130" ]; then \
|
||||||
|
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
|
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
||||||
|
WHL_VERSION="v0.5.4"; \
|
||||||
|
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
|
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
||||||
|
WHL_VERSION="v0.6.4"; \
|
||||||
|
else \
|
||||||
|
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||||
|
fi; \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
||||||
|
pip3 install --no-cache-dir ${WHL_FILE}; \
|
||||||
|
rm ${WHL_FILE}; \
|
||||||
|
fi \
|
||||||
|
;; \
|
||||||
|
esac
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ ARG CUDA_VERSION="12.6.3"
|
|||||||
ARG CUDNN_VERSION=""
|
ARG CUDNN_VERSION=""
|
||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
|
ARG TARGETARCH
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
@@ -31,12 +32,35 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
|||||||
|
|
||||||
RUN uv pip install packaging setuptools wheel psutil \
|
RUN uv pip install packaging setuptools wheel psutil \
|
||||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
|
||||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
|
||||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
RUN case "$PYTORCH_VERSION" in \
|
||||||
|
2.9.[0-9]*) \
|
||||||
|
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
|
if [ "$CUDA" = "128" ]; then \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
elif [ "$CUDA" = "130" ]; then \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
fi \
|
||||||
|
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
|
if [ "$CUDA" = "128" ]; then \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||||
|
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||||
|
elif [ "$CUDA" = "130" ]; then \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||||
|
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||||
|
fi \
|
||||||
|
fi \
|
||||||
|
;; \
|
||||||
|
esac
|
||||||
|
|||||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -3,3 +3,5 @@ _site/
|
|||||||
/api/*.qmd
|
/api/*.qmd
|
||||||
/api/*.html
|
/api/*.html
|
||||||
config-reference.qmd
|
config-reference.qmd
|
||||||
|
models/**/*.qmd
|
||||||
|
models/**/*.html
|
||||||
|
|||||||
86
docs/checkpoint_saving.qmd
Normal file
86
docs/checkpoint_saving.qmd
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
---
|
||||||
|
title: "Checkpoint Saving"
|
||||||
|
format:
|
||||||
|
html:
|
||||||
|
toc: true
|
||||||
|
toc-depth: 2
|
||||||
|
number-sections: true
|
||||||
|
execute:
|
||||||
|
enabled: false
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
|
||||||
|
|
||||||
|
## File-Based Checkpoint Trigger
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
Enable in your config:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
dynamic_checkpoint:
|
||||||
|
enabled: true
|
||||||
|
check_interval: 100 # Optional: check every N steps (default: 100)
|
||||||
|
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `enabled`: `true` to enable (required)
|
||||||
|
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
|
||||||
|
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
|
||||||
|
2. When detected, file is deleted and checkpoint is saved
|
||||||
|
3. In distributed training, rank 0 broadcasts to synchronize all ranks
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
**Command line:**
|
||||||
|
```bash
|
||||||
|
touch /path/to/output_dir/axolotl_checkpoint.save
|
||||||
|
```
|
||||||
|
|
||||||
|
**Programmatic:**
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
|
||||||
|
```
|
||||||
|
|
||||||
|
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
|
||||||
|
|
||||||
|
**Custom filename:**
|
||||||
|
```yaml
|
||||||
|
dynamic_checkpoint:
|
||||||
|
enabled: true
|
||||||
|
trigger_file_path: "my_trigger.save"
|
||||||
|
```
|
||||||
|
```bash
|
||||||
|
touch /path/to/output_dir/my_trigger.save
|
||||||
|
```
|
||||||
|
|
||||||
|
## Control+C (SIGINT) Checkpoint
|
||||||
|
|
||||||
|
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
|
||||||
|
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
|
||||||
|
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
save_steps: 500 # Scheduled checkpoints
|
||||||
|
|
||||||
|
dynamic_checkpoint:
|
||||||
|
enabled: true
|
||||||
|
check_interval: 50
|
||||||
|
```
|
||||||
|
|
||||||
|
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).
|
||||||
@@ -32,11 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-base-py3.11-cu128-2.7.1`
|
- `main-base-py3.11-cu128-2.8.0`
|
||||||
- `main-base-py3.11-cu126-2.7.1`
|
- `main-base-py3.11-cu128-2.9.1`
|
||||||
- `main-base-py3.11-cu126-2.7.0`
|
|
||||||
- `main-base-py3.11-cu126-2.6.0`
|
|
||||||
- `main-base-py3.11-cu124-2.6.0`
|
|
||||||
|
|
||||||
## Main
|
## Main
|
||||||
|
|
||||||
@@ -74,15 +71,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-py3.11-cu128-2.7.1`
|
- `main-py3.11-cu128-2.8.0`
|
||||||
- `main-py3.11-cu126-2.7.1`
|
- `main-py3.11-cu128-2.9.1`
|
||||||
- `main-py3.11-cu126-2.7.0`
|
|
||||||
- `main-py3.11-cu126-2.6.0`
|
|
||||||
- `main-py3.11-cu124-2.6.0`
|
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
- `main-20250303-py3.11-cu126-2.6.0`
|
- `main-20250303-py3.11-cu126-2.6.0`
|
||||||
- `0.10.1`
|
- `0.12.0`
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
|
|||||||
:::
|
:::
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
|
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
### PyPI Installation (Recommended) {#sec-pypi}
|
### PyPI Installation (Recommended) {#sec-pypi}
|
||||||
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
|||||||
:::
|
:::
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
|
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ format:
|
|||||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||||
- [SmolVLM2](#sec-smolvlm2)
|
- [SmolVLM2](#sec-smolvlm2)
|
||||||
- [LFM2-VL](#sec-lfm2-vl)
|
- [LFM2-VL](#sec-lfm2-vl)
|
||||||
|
- [Intern-VL](#sec-intern-vl)
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
@@ -202,6 +203,16 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
|
|||||||
base_model: LiquidAI/LFM2-VL-450M
|
base_model: LiquidAI/LFM2-VL-450M
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Intern-VL {#sec-intern-vl}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Please make sure to install `timm` via `pip3 install timm==1.0.19`
|
||||||
|
:::
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: OpenGVLab/InternVL3_5-8B
|
||||||
|
```
|
||||||
|
|
||||||
## Dataset Format
|
## Dataset Format
|
||||||
|
|
||||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||||
|
|||||||
90
docs/scripts/examples-allowlist.yml
Normal file
90
docs/scripts/examples-allowlist.yml
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
examples:
|
||||||
|
# December 2025
|
||||||
|
- name: kimi-linear
|
||||||
|
title: Kimi Linear
|
||||||
|
- name: plano
|
||||||
|
title: Plano Orchestrator
|
||||||
|
- name: mimo
|
||||||
|
title: MiMo
|
||||||
|
- name: internvl3_5
|
||||||
|
title: InternVL 3.5
|
||||||
|
|
||||||
|
# AllenAI
|
||||||
|
- name: olmo3
|
||||||
|
title: OLMo 3
|
||||||
|
|
||||||
|
# ArceeAI
|
||||||
|
- name: trinity
|
||||||
|
title: Trinity
|
||||||
|
- name: arcee
|
||||||
|
title: Arcee AFM
|
||||||
|
|
||||||
|
# MistralAI
|
||||||
|
- name: ministral3/think
|
||||||
|
title: Ministral 3 Thinking
|
||||||
|
- name: ministral3/vision
|
||||||
|
title: Ministral 3 Vision
|
||||||
|
- name: magistral/think
|
||||||
|
title: Magistral Thinking
|
||||||
|
- name: magistral/vision
|
||||||
|
title: Magistral Vision
|
||||||
|
- name: ministral
|
||||||
|
title: Ministral
|
||||||
|
- name: mistral-small
|
||||||
|
title: Mistral Small 3.1/3.2
|
||||||
|
- name: voxtral
|
||||||
|
title: Voxtral
|
||||||
|
- name: devstral
|
||||||
|
title: Devstral
|
||||||
|
- name: mistral
|
||||||
|
title: Mistral 7B
|
||||||
|
|
||||||
|
# Meta
|
||||||
|
- name: llama-4
|
||||||
|
title: Llama 4
|
||||||
|
- name: llama-2
|
||||||
|
title: Llama 2
|
||||||
|
|
||||||
|
# Alibaba
|
||||||
|
- name: qwen3-next
|
||||||
|
title: Qwen 3 Next
|
||||||
|
- name: qwen3
|
||||||
|
title: Qwen 3
|
||||||
|
|
||||||
|
# Google
|
||||||
|
- name: gemma3n
|
||||||
|
title: Gemma 3n
|
||||||
|
|
||||||
|
# Swiss AI
|
||||||
|
- name: apertus
|
||||||
|
title: Apertus
|
||||||
|
|
||||||
|
# GPT-OSS
|
||||||
|
- name: gpt-oss
|
||||||
|
title: GPT-OSS
|
||||||
|
- name: seed-oss
|
||||||
|
title: Seed-OSS
|
||||||
|
|
||||||
|
# Microsoft
|
||||||
|
- name: phi
|
||||||
|
title: Phi
|
||||||
|
|
||||||
|
# SmolVLM
|
||||||
|
- name: smolvlm2
|
||||||
|
title: SmolVLM 2
|
||||||
|
|
||||||
|
# IBM
|
||||||
|
- name: granite4
|
||||||
|
title: Granite 4
|
||||||
|
|
||||||
|
# LiquidAI
|
||||||
|
- name: LiquidAI
|
||||||
|
title: Liquid Foundation Models 2
|
||||||
|
|
||||||
|
# Other
|
||||||
|
- name: hunyuan
|
||||||
|
title: Hunyuan
|
||||||
|
- name: jamba
|
||||||
|
title: Jamba
|
||||||
|
- name: orpheus
|
||||||
|
title: Orpheus
|
||||||
424
docs/scripts/generate_examples_docs.py
Executable file
424
docs/scripts/generate_examples_docs.py
Executable file
@@ -0,0 +1,424 @@
|
|||||||
|
"""
|
||||||
|
auto generate example docs from allowlist
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
THIS = Path(__file__).resolve()
|
||||||
|
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
|
||||||
|
EXAMPLES_DIR = ROOT / "examples"
|
||||||
|
OUTPUT_DIR = ROOT / "docs" / "models"
|
||||||
|
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
|
||||||
|
|
||||||
|
|
||||||
|
def slugify(name: str) -> str:
|
||||||
|
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
|
||||||
|
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
|
||||||
|
s = re.sub(r"\s+", "-", s).strip("-").lower()
|
||||||
|
return s or "example"
|
||||||
|
|
||||||
|
|
||||||
|
def read_allowlist():
|
||||||
|
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
|
||||||
|
data = yaml.safe_load(f) or {}
|
||||||
|
items = data.get("examples", [])
|
||||||
|
if not isinstance(items, list):
|
||||||
|
raise ValueError("`examples` must be a list in examples-allowlist.yml")
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def find_readme(folder: Path) -> Path | None:
|
||||||
|
for name in ("README.md", "Readme.md", "readme.md"):
|
||||||
|
p = folder / name
|
||||||
|
if p.exists():
|
||||||
|
return p
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remove_first_h1(md: str) -> tuple[str, str | None]:
|
||||||
|
"""
|
||||||
|
Remove the first H1 from markdown and return (modified_md, h1_title).
|
||||||
|
The H1 is removed since we use the frontmatter title instead.
|
||||||
|
"""
|
||||||
|
lines = md.splitlines()
|
||||||
|
result = []
|
||||||
|
h1_title = None
|
||||||
|
skipped_first = False
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if not skipped_first and line.startswith("# "):
|
||||||
|
h1_title = line[2:].strip()
|
||||||
|
skipped_first = True
|
||||||
|
continue
|
||||||
|
result.append(line)
|
||||||
|
|
||||||
|
return "\n".join(result), h1_title
|
||||||
|
|
||||||
|
|
||||||
|
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
|
||||||
|
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
|
||||||
|
"""
|
||||||
|
Copy local image assets referenced in markdown to
|
||||||
|
docs/examples/assets/... and rewrite the links.
|
||||||
|
"""
|
||||||
|
dest_assets = dest_assets_root / "assets"
|
||||||
|
|
||||||
|
def repl(m):
|
||||||
|
url = m.group(1).strip()
|
||||||
|
if re.match(r"^(https?:)?//", url):
|
||||||
|
return m.group(0) # leave remote URLs
|
||||||
|
src_path = (src_dir / url).resolve()
|
||||||
|
if not src_path.exists():
|
||||||
|
return m.group(0) # leave as-is if not found
|
||||||
|
rel = src_path.relative_to(src_dir)
|
||||||
|
# Create a unique asset path based on source directory name
|
||||||
|
asset_name = src_dir.name.replace("/", "-")
|
||||||
|
dest_path = dest_assets / asset_name / rel
|
||||||
|
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copy2(src_path, dest_path)
|
||||||
|
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
|
||||||
|
return m.group(0).replace(url, new_rel)
|
||||||
|
|
||||||
|
return IMG_RE.sub(repl, md)
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite_readme_links(
|
||||||
|
md: str,
|
||||||
|
src_dir: Path,
|
||||||
|
examples_dir: Path,
|
||||||
|
parent_index_only: set,
|
||||||
|
current_src_path: str,
|
||||||
|
allowlist_entries: set,
|
||||||
|
current_output_path: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Rewrite links between README.md files to point to the correct .qmd files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def repl(m):
|
||||||
|
text = m.group(1)
|
||||||
|
url = m.group(2).strip()
|
||||||
|
|
||||||
|
# Skip remote URLs and anchor links
|
||||||
|
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
|
||||||
|
return m.group(0)
|
||||||
|
|
||||||
|
# Skip non-markdown files
|
||||||
|
if not url.lower().endswith(".md"):
|
||||||
|
return m.group(0)
|
||||||
|
|
||||||
|
# Resolve the target path
|
||||||
|
try:
|
||||||
|
target_path = (src_dir / url).resolve()
|
||||||
|
|
||||||
|
# Check if target is outside examples_dir
|
||||||
|
try:
|
||||||
|
rel_path = target_path.relative_to(examples_dir)
|
||||||
|
except ValueError:
|
||||||
|
# Target is outside examples_dir, leave as-is
|
||||||
|
return m.group(0)
|
||||||
|
|
||||||
|
parts = list(rel_path.parts)
|
||||||
|
|
||||||
|
# Determine the output path for the target
|
||||||
|
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
|
||||||
|
# This is a README link
|
||||||
|
if len(parts) == 1:
|
||||||
|
# Link to root README -> index.qmd
|
||||||
|
target_output = "index.qmd"
|
||||||
|
elif len(parts) == 2:
|
||||||
|
if parts[0] == ".":
|
||||||
|
# Current directory README
|
||||||
|
target_output = "index.qmd"
|
||||||
|
else:
|
||||||
|
# subdir/README.md
|
||||||
|
parent_dir = parts[0]
|
||||||
|
if parent_dir in parent_index_only:
|
||||||
|
target_output = f"{parent_dir}/index.qmd"
|
||||||
|
else:
|
||||||
|
target_output = f"{parent_dir}.qmd"
|
||||||
|
else:
|
||||||
|
# Deeper nesting: parent/subdir/README.md
|
||||||
|
# Build the full path like "parent/subdir"
|
||||||
|
full_path = "/".join(parts[:-1]) # Remove README.md
|
||||||
|
# Check if this exact path is in allowlist
|
||||||
|
if full_path in allowlist_entries:
|
||||||
|
# This is a sub-entry with its own entry -> use .qmd
|
||||||
|
target_output = f"{full_path}.qmd"
|
||||||
|
elif parts[0] == ".":
|
||||||
|
# ./subdir/README.md -> check if subdir has own entry
|
||||||
|
subdir = parts[1]
|
||||||
|
if subdir in parent_index_only:
|
||||||
|
target_output = f"{subdir}/index.qmd"
|
||||||
|
else:
|
||||||
|
target_output = f"{subdir}.qmd"
|
||||||
|
else:
|
||||||
|
# parent/subdir where parent doesn't have own entry
|
||||||
|
target_output = f"{full_path}/index.qmd"
|
||||||
|
else:
|
||||||
|
# Regular .md file -> convert to .qmd, keep path structure
|
||||||
|
target_output = "/".join(parts)[:-2] + "qmd"
|
||||||
|
|
||||||
|
# Compute relative path from current output file to target
|
||||||
|
current_parts = current_output_path.split("/")
|
||||||
|
target_parts = target_output.split("/")
|
||||||
|
|
||||||
|
# Special case: if current is a subdir file and target is a single-component file at root
|
||||||
|
# Example: current="magistral/vision", target="magistral.qmd"
|
||||||
|
if len(current_parts) > 1 and len(target_parts) == 1:
|
||||||
|
# Current is in subdir, target is at root level
|
||||||
|
# Go up to root: ../ for each level
|
||||||
|
up_count = len(current_parts) - 1
|
||||||
|
rel_parts = [".."] * up_count + [target_parts[0]]
|
||||||
|
new_url = "/".join(rel_parts)
|
||||||
|
else:
|
||||||
|
# Find common prefix
|
||||||
|
i = 0
|
||||||
|
while (
|
||||||
|
i < min(len(current_parts) - 1, len(target_parts))
|
||||||
|
and current_parts[i] == target_parts[i]
|
||||||
|
):
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# Build relative path: go up (../) then down to target
|
||||||
|
up_count = len(current_parts) - 1 - i
|
||||||
|
rel_parts = [".."] * up_count + target_parts[i:]
|
||||||
|
|
||||||
|
if not rel_parts or rel_parts == [".."]:
|
||||||
|
# Points to same directory or parent
|
||||||
|
new_url = "/".join(rel_parts) if rel_parts else "."
|
||||||
|
else:
|
||||||
|
new_url = "/".join(rel_parts)
|
||||||
|
|
||||||
|
return f"[{text}]({new_url})"
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return m.group(0)
|
||||||
|
|
||||||
|
return LINK_RE.sub(repl, md)
|
||||||
|
|
||||||
|
|
||||||
|
def write_qmd(out_path: Path, title: str, body_md: str):
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||||
|
out_path.write_text(fm + body_md, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def update_quarto_yml(generated: list[tuple[str, str, str]]):
|
||||||
|
"""
|
||||||
|
Update _quarto.yml with the generated example files in the correct order.
|
||||||
|
This keeps the sidebar in sync with the allowlist.
|
||||||
|
|
||||||
|
Model Guides is now nested under "Getting Started" section.
|
||||||
|
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
|
||||||
|
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
|
||||||
|
"""
|
||||||
|
quarto_yml = ROOT / "_quarto.yml"
|
||||||
|
if not quarto_yml.exists():
|
||||||
|
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
|
||||||
|
return
|
||||||
|
|
||||||
|
content = quarto_yml.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# First pass: find all parents that have sub-entries
|
||||||
|
parents_with_subs = set()
|
||||||
|
for path, _name, _title in generated:
|
||||||
|
if "/" in path:
|
||||||
|
parent = path.split("/")[0]
|
||||||
|
parents_with_subs.add(parent)
|
||||||
|
|
||||||
|
# Build the YAML contents while preserving allowlist order
|
||||||
|
lines = []
|
||||||
|
processed_sections = set()
|
||||||
|
|
||||||
|
for path, _name, title in generated:
|
||||||
|
# Check if this is a parent page that has sub-pages
|
||||||
|
if path in parents_with_subs:
|
||||||
|
# This is a parent page with sub-pages - create a nested section
|
||||||
|
if path not in processed_sections:
|
||||||
|
processed_sections.add(path)
|
||||||
|
section_title = (
|
||||||
|
title or path.replace("-", " ").replace("_", " ").title()
|
||||||
|
)
|
||||||
|
lines.append(f' - section: "{section_title}"')
|
||||||
|
lines.append(" contents:")
|
||||||
|
# Add the parent page first
|
||||||
|
lines.append(f" - docs/models/{path}.qmd")
|
||||||
|
# Then add all sub-pages
|
||||||
|
for sub_path, _sub_name, _sub_title in generated:
|
||||||
|
if "/" in sub_path and sub_path.split("/")[0] == path:
|
||||||
|
lines.append(
|
||||||
|
f" - docs/models/{sub_path}.qmd"
|
||||||
|
)
|
||||||
|
elif "/" not in path:
|
||||||
|
# This is a flat item with no sub-pages
|
||||||
|
# Skip if it was already included as part of a parent section
|
||||||
|
if path not in processed_sections:
|
||||||
|
lines.append(f" - docs/models/{path}.qmd")
|
||||||
|
|
||||||
|
yaml_content = "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
# Pattern to match only the Model Guides contents, stopping at the next item
|
||||||
|
# in Getting Started (lines starting with 12 spaces: same level as the section)
|
||||||
|
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
|
||||||
|
|
||||||
|
def replacement(match):
|
||||||
|
prefix = match.group(1)
|
||||||
|
return prefix + "\n" + yaml_content
|
||||||
|
|
||||||
|
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
if new_content != content:
|
||||||
|
quarto_yml.write_text(new_content, encoding="utf-8")
|
||||||
|
print(f"Updated {quarto_yml}")
|
||||||
|
else:
|
||||||
|
print(f"No changes needed for {quarto_yml}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
allow = read_allowlist()
|
||||||
|
if not EXAMPLES_DIR.exists():
|
||||||
|
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
|
||||||
|
return
|
||||||
|
|
||||||
|
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# First pass: identify which parents have their own entry vs only sub-entries
|
||||||
|
parent_entries = set() # Parents that have their own entry
|
||||||
|
parent_with_subs = set() # Parents that have sub-entries
|
||||||
|
allowlist_entries = set() # All entries in allowlist
|
||||||
|
|
||||||
|
for item in allow:
|
||||||
|
if isinstance(item, str):
|
||||||
|
name = item
|
||||||
|
else:
|
||||||
|
name = item.get("name")
|
||||||
|
|
||||||
|
allowlist_entries.add(name)
|
||||||
|
|
||||||
|
if "/" in name:
|
||||||
|
parent = name.split("/")[0]
|
||||||
|
parent_with_subs.add(parent)
|
||||||
|
else:
|
||||||
|
parent_entries.add(name)
|
||||||
|
|
||||||
|
# Parents with subs that DON'T have their own entry -> use index.qmd
|
||||||
|
parent_index_only = parent_with_subs - parent_entries
|
||||||
|
|
||||||
|
generated = []
|
||||||
|
seen_dirs = set() # Track which parent directories we've created index for
|
||||||
|
|
||||||
|
for item in allow:
|
||||||
|
if isinstance(item, str):
|
||||||
|
name = item
|
||||||
|
title = None
|
||||||
|
else:
|
||||||
|
name = item.get("name")
|
||||||
|
title = item.get("title")
|
||||||
|
|
||||||
|
if not name:
|
||||||
|
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
src_dir = EXAMPLES_DIR / name
|
||||||
|
if not src_dir.exists() or not src_dir.is_dir():
|
||||||
|
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
readme = find_readme(src_dir)
|
||||||
|
if not readme:
|
||||||
|
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
md = readme.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# Determine output path first (needed for link rewriting)
|
||||||
|
parts = name.split("/")
|
||||||
|
if len(parts) == 1:
|
||||||
|
# Simple case: no subdirectory
|
||||||
|
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
|
||||||
|
sidebar_path = parts[0]
|
||||||
|
else:
|
||||||
|
# Has subdirectory: e.g., magistral/think
|
||||||
|
parent = parts[0]
|
||||||
|
child = "-".join(parts[1:]) # handle nested subdirs
|
||||||
|
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
|
||||||
|
sidebar_path = f"{parent}/{child}"
|
||||||
|
|
||||||
|
# Remove the first H1 (we use frontmatter title instead)
|
||||||
|
md, _ = remove_first_h1(md)
|
||||||
|
# Rewrite links between README files
|
||||||
|
md = rewrite_readme_links(
|
||||||
|
md,
|
||||||
|
src_dir,
|
||||||
|
EXAMPLES_DIR,
|
||||||
|
parent_index_only,
|
||||||
|
name,
|
||||||
|
allowlist_entries,
|
||||||
|
sidebar_path,
|
||||||
|
)
|
||||||
|
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
|
||||||
|
|
||||||
|
# Handle parent page generation for sub-entries
|
||||||
|
if len(parts) > 1:
|
||||||
|
# Has subdirectory: e.g., magistral/think
|
||||||
|
parent = parts[0]
|
||||||
|
|
||||||
|
# Create parent.qmd if not already done and parent doesn't have own entry
|
||||||
|
if parent not in seen_dirs and parent in parent_index_only:
|
||||||
|
parent_readme = find_readme(EXAMPLES_DIR / parent)
|
||||||
|
if parent_readme:
|
||||||
|
parent_md = parent_readme.read_text(encoding="utf-8")
|
||||||
|
parent_md, _ = remove_first_h1(parent_md)
|
||||||
|
parent_md = rewrite_readme_links(
|
||||||
|
parent_md,
|
||||||
|
EXAMPLES_DIR / parent,
|
||||||
|
EXAMPLES_DIR,
|
||||||
|
parent_index_only,
|
||||||
|
parent,
|
||||||
|
allowlist_entries,
|
||||||
|
parent,
|
||||||
|
)
|
||||||
|
parent_md = rewrite_and_copy_assets(
|
||||||
|
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
|
||||||
|
)
|
||||||
|
parent_title = parent.replace("-", " ").replace("_", " ").title()
|
||||||
|
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
|
||||||
|
generated.append((parent, parent, parent_title))
|
||||||
|
seen_dirs.add(parent)
|
||||||
|
|
||||||
|
if not title:
|
||||||
|
title = name.replace("/", " ").replace("-", " ").title()
|
||||||
|
|
||||||
|
write_qmd(out_path, title, md)
|
||||||
|
generated.append((sidebar_path, name, title))
|
||||||
|
|
||||||
|
# Index page - preserve allowlist order
|
||||||
|
if generated:
|
||||||
|
listing = "\n".join(
|
||||||
|
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
|
||||||
|
)
|
||||||
|
index_md = (
|
||||||
|
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
|
||||||
|
+ listing
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
index_fm = (
|
||||||
|
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||||
|
)
|
||||||
|
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
|
||||||
|
|
||||||
|
# Auto-update _quarto.yml to keep sidebar in sync
|
||||||
|
update_quarto_yml(generated)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ gradient_checkpointing: true
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
scaling_softmax: true
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
loss_watchdog_threshold: 5.0
|
||||||
loss_watchdog_patience: 3
|
loss_watchdog_patience: 3
|
||||||
|
|||||||
53
examples/gemma3/gemma-3-1b-fft-dft.yml
Normal file
53
examples/gemma3/gemma-3-1b-fft-dft.yml
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
base_model: google/gemma-3-1b-it
|
||||||
|
|
||||||
|
model_type: Gemma3ForCausalLM
|
||||||
|
cls_model_config: Gemma3TextConfig
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
eot_tokens:
|
||||||
|
- <end_of_turn>
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/gemma-3-1b-fft-dft
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
use_dynamic_finetuning: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 5e-5
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
base_model: google/gemma-3-1b-it
|
base_model: google/gemma-3-1b-it
|
||||||
|
|
||||||
model_type: Gemma3ForCausalLM
|
model_type: Gemma3ForCausalLM
|
||||||
|
cls_model_config: Gemma3TextConfig
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
|
|||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0
|
||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
base_model: google/gemma-3-270m-it
|
base_model: google/gemma-3-270m-it
|
||||||
|
|
||||||
model_type: Gemma3ForCausalLM
|
model_type: Gemma3ForCausalLM
|
||||||
|
cls_model_config: Gemma3TextConfig
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
|
|||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0
|
||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
|
|||||||
|
|
||||||
# Need to set else transformers tries to load vision too
|
# Need to set else transformers tries to load vision too
|
||||||
model_type: Gemma3ForCausalLM
|
model_type: Gemma3ForCausalLM
|
||||||
|
cls_model_config: Gemma3TextConfig
|
||||||
|
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
|
|
||||||
@@ -32,8 +33,8 @@ sample_packing: true
|
|||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0
|
||||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_linear: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ pad_to_sequence_len: false
|
|||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0
|
||||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
|
|||||||
@@ -32,6 +32,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -29,6 +29,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
|
trackio_project_name:
|
||||||
|
trackio_run_name:
|
||||||
|
trackio_space_id:
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
43
examples/internvl3_5/README.md
Normal file
43
examples/internvl3_5/README.md
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# Finetune OpenGV's InternVL with Axolotl
|
||||||
|
|
||||||
|
[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGV. It features a ViT-style vision encoder and strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Install `timm` for vision model support:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install timm==1.0.19
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
|
4. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
### Tips
|
||||||
|
|
||||||
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [InternVL Paper](https://huggingface.co/papers/2508.18265)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
61
examples/internvl3_5/internvl3_5-8b-qlora.yml
Normal file
61
examples/internvl3_5/internvl3_5-8b-qlora.yml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
base_model: OpenGVLab/InternVL3_5-8B-HF
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
47
examples/kimi-linear/README.md
Normal file
47
examples/kimi-linear/README.md
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Finetune MoonshotAI's Kimi Linear with Axolotl
|
||||||
|
|
||||||
|
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||||
|
|
||||||
|
**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Install CCE via [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 98.7GiB VRAM.
|
||||||
|
|
||||||
|
Let us know how it goes. Happy finetuning!
|
||||||
|
|
||||||
|
### TIPS
|
||||||
|
|
||||||
|
- Kimi Linear requires `trust_remote_code: true`.
|
||||||
|
- You can run a full finetuning by removing the `adapter: lora` and `load_in_8bit: true`.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)
|
||||||
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
This is not yet compatible with MoE kernels from transformers v5.
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
|
||||||
|
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
81
examples/kimi-linear/kimi-48b-lora.yaml
Normal file
81
examples/kimi-linear/kimi-48b-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.2
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -29,7 +29,6 @@ flex_attention: true
|
|||||||
flex_attn_compile_kwargs:
|
flex_attn_compile_kwargs:
|
||||||
dynamic: false
|
dynamic: false
|
||||||
mode: max-autotune-no-cudagraphs
|
mode: max-autotune-no-cudagraphs
|
||||||
save_strategy: no
|
|
||||||
torch_compile: true
|
torch_compile: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mist
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
|
||||||
- Installed Axolotl (see [main README](../README.md))
|
- Installed Axolotl (see [main README](../README.md))
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mist
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
|
||||||
|
- Installed Axolotl from source (see [main README](../README.md))
|
||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
|
|||||||
39
examples/mimo/README.md
Normal file
39
examples/mimo/README.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Finetune Xiaomi's MiMo with Axolotl
|
||||||
|
|
||||||
|
[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/mimo/mimo-7b-qlora.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
### Tips
|
||||||
|
|
||||||
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [MiMo Paper](https://arxiv.org/abs/2505.07608)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
67
examples/mimo/mimo-7b-qlora.yaml
Normal file
67
examples/mimo/mimo-7b-qlora.yaml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
base_model: XiaomiMiMo/MiMo-7B-RL
|
||||||
|
trust_remote_code: true
|
||||||
|
revision_of_model: 6299b5a
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
# CCE - N/A as of now
|
||||||
|
# plugins:
|
||||||
|
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
@@ -59,6 +59,7 @@ gradient_checkpointing: true
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
scaling_softmax: true
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
|
||||||
- Installed Axolotl (see [main README](../README.md))
|
- Installed Axolotl (see [main README](../README.md))
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
|
||||||
|
- Installed Axolotl from source (see [main README](../README.md))
|
||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
|
||||||
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
@@ -16,7 +16,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
### TIPS
|
### TIPS
|
||||||
|
|
||||||
|
|||||||
@@ -42,10 +42,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
|||||||
42
examples/plano/README.md
Normal file
42
examples/plano/README.md
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Finetune Katanemo's Plano-Orchestrator with Axolotl
|
||||||
|
|
||||||
|
[Plano-Orchestrator](https://huggingface.co/collections/katanemo/plano-orchestrator) is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/plano/plano-4b-qlora.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
### Orchestration Prompt
|
||||||
|
|
||||||
|
Plano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the [official model card](https://huggingface.co/katanemo/Plano-Orchestrator-4B) for proper prompt formatting and the `ORCHESTRATION_PROMPT` template.
|
||||||
|
|
||||||
|
### Tips
|
||||||
|
|
||||||
|
- To use the larger [Plano-Orchestrator-30B-A3B](https://huggingface.co/katanemo/Plano-Orchestrator-30B-A3B) MoE model, simply change `base_model: katanemo/Plano-Orchestrator-30B-A3B` in the config and enable multi-GPU training if needed.
|
||||||
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [Plano GitHub](https://github.com/katanemo/plano)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
65
examples/plano/plano-4b-qlora.yaml
Normal file
65
examples/plano/plano-4b-qlora.yaml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: katanemo/Plano-Orchestrator-4B
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
chat_template: qwen3
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-0.5B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
# Use random initialization for fair comparison
|
||||||
|
reinit_weights: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Pretraining dataset
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: allenai/c4
|
||||||
|
name: en
|
||||||
|
type: pretrain
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/compare-adamw-pretrain
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project: dist_muon
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name: adamw
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 305
|
||||||
|
|
||||||
|
# AdamW optimizer settings (standard LR for AdamW)
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
learning_rate: 0.0002
|
||||||
|
weight_decay: 0.01
|
||||||
|
lr_scheduler: cosine
|
||||||
|
|
||||||
|
train_on_inputs: true
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 0
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-0.5B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
# Use random initialization for fair comparison
|
||||||
|
reinit_weights: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Pretraining dataset
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: allenai/c4
|
||||||
|
name: en
|
||||||
|
type: pretrain
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/compare-muon-pretrain
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project: dist_muon
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name: muon
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 305
|
||||||
|
|
||||||
|
# Muon optimizer settings
|
||||||
|
optimizer: muon
|
||||||
|
learning_rate: 0.02
|
||||||
|
weight_decay: 0.01
|
||||||
|
lr_scheduler: cosine
|
||||||
|
|
||||||
|
train_on_inputs: true
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 0
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
285
examples/swanlab/README.md
Normal file
285
examples/swanlab/README.md
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
# SwanLab Integration Examples
|
||||||
|
|
||||||
|
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
|
||||||
|
|
||||||
|
## Examples Overview
|
||||||
|
|
||||||
|
### 1. DPO with Completion Logging
|
||||||
|
**File**: `dpo-swanlab-completions.yml`
|
||||||
|
|
||||||
|
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
|
||||||
|
|
||||||
|
**Features**:
|
||||||
|
- Basic SwanLab experiment tracking
|
||||||
|
- Completion table logging (prompts, chosen/rejected responses, rewards)
|
||||||
|
- Memory-bounded buffer for long training runs
|
||||||
|
- Cloud sync configuration
|
||||||
|
|
||||||
|
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
|
||||||
|
|
||||||
|
**Quick start**:
|
||||||
|
```bash
|
||||||
|
export SWANLAB_API_KEY=your-api-key
|
||||||
|
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. LoRA with Performance Profiling
|
||||||
|
**File**: `lora-swanlab-profiling.yml`
|
||||||
|
|
||||||
|
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
|
||||||
|
|
||||||
|
**Features**:
|
||||||
|
- SwanLab experiment tracking
|
||||||
|
- Automatic profiling of trainer methods
|
||||||
|
- Profiling metrics visualization
|
||||||
|
- Performance optimization guidance
|
||||||
|
|
||||||
|
**Best for**: Engineers optimizing training performance and comparing different configurations
|
||||||
|
|
||||||
|
**Quick start**:
|
||||||
|
```bash
|
||||||
|
export SWANLAB_API_KEY=your-api-key
|
||||||
|
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Full-Featured DPO Production Setup
|
||||||
|
**File**: `dpo-swanlab-full-featured.yml`
|
||||||
|
|
||||||
|
Comprehensive production-ready configuration with ALL SwanLab features enabled.
|
||||||
|
|
||||||
|
**Features**:
|
||||||
|
- Experiment tracking with team workspace
|
||||||
|
- RLHF completion logging
|
||||||
|
- Performance profiling
|
||||||
|
- Lark (Feishu) team notifications
|
||||||
|
- Private deployment support
|
||||||
|
- Production checklist and troubleshooting
|
||||||
|
|
||||||
|
**Best for**: Production RLHF training with team collaboration
|
||||||
|
|
||||||
|
**Quick start**:
|
||||||
|
```bash
|
||||||
|
export SWANLAB_API_KEY=your-api-key
|
||||||
|
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||||
|
export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||||
|
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. Custom Trainer Profiling (Python)
|
||||||
|
**File**: `custom_trainer_profiling.py`
|
||||||
|
|
||||||
|
Python code examples showing how to add SwanLab profiling to custom trainers.
|
||||||
|
|
||||||
|
**Features**:
|
||||||
|
- `@swanlab_profile` decorator examples
|
||||||
|
- Context manager profiling for fine-grained timing
|
||||||
|
- `ProfilingConfig` for advanced filtering and throttling
|
||||||
|
- Multiple profiling patterns and best practices
|
||||||
|
|
||||||
|
**Best for**: Advanced users creating custom trainers
|
||||||
|
|
||||||
|
**Usage**:
|
||||||
|
```python
|
||||||
|
from custom_trainer_profiling import CustomTrainerWithProfiling
|
||||||
|
# See file for detailed examples and patterns
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Feature Matrix
|
||||||
|
|
||||||
|
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|
||||||
|
|---------|----------|-------------------|-----------|-------------------|----------------|
|
||||||
|
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
||||||
|
| lora-swanlab-profiling.yml | ✅ | ➖ (disabled) | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
||||||
|
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
|
||||||
|
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration Quick Reference
|
||||||
|
|
||||||
|
### Basic SwanLab Setup
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||||
|
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_project: my-project
|
||||||
|
swanlab_experiment_name: my-experiment
|
||||||
|
swanlab_mode: cloud # cloud, local, offline, disabled
|
||||||
|
```
|
||||||
|
|
||||||
|
### RLHF Completion Logging
|
||||||
|
```yaml
|
||||||
|
swanlab_log_completions: true
|
||||||
|
swanlab_completion_log_interval: 100 # Log every 100 steps
|
||||||
|
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lark Team Notifications
|
||||||
|
```yaml
|
||||||
|
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
||||||
|
swanlab_lark_secret: your-webhook-secret # Required for production
|
||||||
|
```
|
||||||
|
|
||||||
|
### Team Workspace
|
||||||
|
```yaml
|
||||||
|
swanlab_workspace: my-research-team
|
||||||
|
```
|
||||||
|
|
||||||
|
### Private Deployment
|
||||||
|
```yaml
|
||||||
|
swanlab_web_host: https://swanlab.yourcompany.com
|
||||||
|
swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Authentication
|
||||||
|
|
||||||
|
### Recommended: Environment Variable
|
||||||
|
```bash
|
||||||
|
export SWANLAB_API_KEY=your-api-key
|
||||||
|
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||||
|
export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||||
|
```
|
||||||
|
|
||||||
|
### Alternative: Config File (less secure)
|
||||||
|
```yaml
|
||||||
|
swanlab_api_key: your-api-key
|
||||||
|
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
||||||
|
swanlab_lark_secret: your-webhook-secret
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Use Cases
|
||||||
|
|
||||||
|
### Use Case 1: Migrate from WandB to SwanLab
|
||||||
|
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
|
||||||
|
```yaml
|
||||||
|
use_swanlab: true
|
||||||
|
use_wandb: false
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use Case 2: Analyze DPO Model Outputs
|
||||||
|
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
|
||||||
|
```yaml
|
||||||
|
swanlab_completion_log_interval: 50 # More frequent for short training
|
||||||
|
swanlab_completion_log_interval: 200 # Less frequent for long training
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use Case 3: Optimize Training Performance
|
||||||
|
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
|
||||||
|
- Baseline: `flash_attention: false, gradient_checkpointing: false`
|
||||||
|
- Flash Attention: `flash_attention: true`
|
||||||
|
- Gradient Checkpointing: `gradient_checkpointing: true`
|
||||||
|
- Both: `flash_attention: true, gradient_checkpointing: true`
|
||||||
|
|
||||||
|
Compare profiling metrics in SwanLab dashboard.
|
||||||
|
|
||||||
|
### Use Case 4: Production RLHF with Team Collaboration
|
||||||
|
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
|
||||||
|
```yaml
|
||||||
|
swanlab_workspace: ml-team
|
||||||
|
swanlab_lark_webhook_url: ...
|
||||||
|
swanlab_lark_secret: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Viewing Your Experiments
|
||||||
|
|
||||||
|
### Cloud Mode
|
||||||
|
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
|
||||||
|
|
||||||
|
**Dashboard sections**:
|
||||||
|
- **Metrics**: Training loss, learning rate, profiling metrics
|
||||||
|
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
|
||||||
|
- **Config**: Hyperparameters and configuration
|
||||||
|
- **System**: Resource usage (GPU, memory, CPU)
|
||||||
|
- **Files**: Logged artifacts
|
||||||
|
|
||||||
|
### Local Mode
|
||||||
|
```bash
|
||||||
|
swanlab watch ./swanlog
|
||||||
|
# Open browser to http://localhost:5092
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### SwanLab not initializing
|
||||||
|
```bash
|
||||||
|
# Check API key
|
||||||
|
echo $SWANLAB_API_KEY
|
||||||
|
|
||||||
|
# Verify SwanLab is installed
|
||||||
|
pip show swanlab
|
||||||
|
|
||||||
|
# Check config
|
||||||
|
grep -A 5 "use_swanlab" your-config.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Completions not appearing
|
||||||
|
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
|
||||||
|
- Check `swanlab_log_completions: true`
|
||||||
|
- Wait for `swanlab_completion_log_interval` steps
|
||||||
|
- Look for "Registered SwanLab RLHF completion logging" in logs
|
||||||
|
|
||||||
|
### Lark notifications not working
|
||||||
|
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
|
||||||
|
- Verify `SWANLAB_LARK_SECRET` is set correctly
|
||||||
|
- Check bot is added to Lark group chat
|
||||||
|
- Look for "Registered Lark notification callback" in logs
|
||||||
|
|
||||||
|
### Profiling metrics not appearing
|
||||||
|
- Verify `use_swanlab: true`
|
||||||
|
- Check SwanLab is initialized (look for init log message)
|
||||||
|
- Profiling metrics are under "profiling/" namespace
|
||||||
|
- Profiling auto-enabled when SwanLab is enabled
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Performance Notes
|
||||||
|
|
||||||
|
### Overhead Comparison
|
||||||
|
|
||||||
|
| Feature | Overhead per Step | Memory Usage |
|
||||||
|
|---------|------------------|--------------|
|
||||||
|
| Basic tracking | < 0.1% | ~10 MB |
|
||||||
|
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
|
||||||
|
| Profiling | < 0.1% | ~1 KB |
|
||||||
|
| **Total** | **< 0.7%** | **~10 MB** |
|
||||||
|
|
||||||
|
### Best Practices
|
||||||
|
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
|
||||||
|
2. Adjust completion log interval based on training length (100-200 steps)
|
||||||
|
3. Keep completion buffer size reasonable (128-512)
|
||||||
|
4. Profile critical path methods first (training_step, compute_loss)
|
||||||
|
5. Use ProfilingConfig to throttle high-frequency operations
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Further Reading
|
||||||
|
|
||||||
|
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
|
||||||
|
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
|
||||||
|
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
|
||||||
|
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Found an issue or have an improvement? Please submit a PR or open an issue:
|
||||||
|
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
|
||||||
|
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)
|
||||||
299
examples/swanlab/custom_trainer_profiling.py
Normal file
299
examples/swanlab/custom_trainer_profiling.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
"""Example: Custom Trainer with SwanLab Profiling
|
||||||
|
|
||||||
|
This example demonstrates how to add SwanLab profiling to your custom trainer.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- @swanlab_profile decorator for automatic profiling
|
||||||
|
- swanlab_profiling_context for fine-grained profiling
|
||||||
|
- ProfilingConfig for advanced filtering and throttling
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
1. Create your custom trainer extending AxolotlTrainer
|
||||||
|
2. Add @swanlab_profile decorators to methods you want to profile
|
||||||
|
3. Use swanlab_profiling_context for fine-grained profiling within methods
|
||||||
|
4. Enable SwanLab in your config (use_swanlab: true)
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- examples/swanlab/lora-swanlab-profiling.yml for config
|
||||||
|
- src/axolotl/integrations/swanlab/profiling.py for implementation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.integrations.swanlab.profiling import (
|
||||||
|
ProfilingConfig,
|
||||||
|
swanlab_profile,
|
||||||
|
swanlab_profiling_context,
|
||||||
|
swanlab_profiling_context_advanced,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTrainerWithProfiling(AxolotlTrainer):
|
||||||
|
"""Custom trainer with SwanLab profiling enabled.
|
||||||
|
|
||||||
|
This trainer demonstrates three profiling patterns:
|
||||||
|
1. Decorator-based profiling (@swanlab_profile)
|
||||||
|
2. Context manager profiling (swanlab_profiling_context)
|
||||||
|
3. Advanced profiling with filtering (ProfilingConfig)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Create custom profiling config for high-frequency operations
|
||||||
|
self.fast_op_config = ProfilingConfig(
|
||||||
|
enabled=True,
|
||||||
|
min_duration_ms=0.5, # Only log if duration > 0.5ms
|
||||||
|
log_interval=50, # Log every 50th call
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# Pattern 1: Decorator-based Profiling
|
||||||
|
# ========================================================================
|
||||||
|
# Best for: Methods you always want to profile
|
||||||
|
# Overhead: ~2-5 microseconds per call (negligible)
|
||||||
|
|
||||||
|
@swanlab_profile
|
||||||
|
def training_step(self, model, inputs):
|
||||||
|
"""Main training step - always profile.
|
||||||
|
|
||||||
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
|
||||||
|
"""
|
||||||
|
return super().training_step(model, inputs)
|
||||||
|
|
||||||
|
@swanlab_profile
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
|
"""Loss computation - always profile.
|
||||||
|
|
||||||
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
|
||||||
|
"""
|
||||||
|
return super().compute_loss(model, inputs, return_outputs)
|
||||||
|
|
||||||
|
@swanlab_profile
|
||||||
|
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||||
|
"""Prediction step - always profile.
|
||||||
|
|
||||||
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
|
||||||
|
"""
|
||||||
|
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# Pattern 2: Fine-grained Context Manager Profiling
|
||||||
|
# ========================================================================
|
||||||
|
# Best for: Profiling specific code blocks within a method
|
||||||
|
# Use case: When you want to profile forward vs backward separately
|
||||||
|
|
||||||
|
def complex_training_step(self, model, inputs):
|
||||||
|
"""Training step with fine-grained profiling.
|
||||||
|
|
||||||
|
Profiling metrics:
|
||||||
|
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
||||||
|
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
||||||
|
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
|
||||||
|
"""
|
||||||
|
# Profile just the forward pass
|
||||||
|
with swanlab_profiling_context(self, "forward_pass"):
|
||||||
|
outputs = model(**inputs)
|
||||||
|
loss = outputs.loss
|
||||||
|
|
||||||
|
# Profile just the backward pass
|
||||||
|
with swanlab_profiling_context(self, "backward_pass"):
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Profile optimizer step
|
||||||
|
with swanlab_profiling_context(self, "optimizer_step"):
|
||||||
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# Pattern 3: Advanced Profiling with Filtering
|
||||||
|
# ========================================================================
|
||||||
|
# Best for: High-frequency operations where you want to throttle logging
|
||||||
|
# Use case: Methods called 100+ times per step
|
||||||
|
|
||||||
|
def _prepare_inputs(self, inputs):
|
||||||
|
"""Prepare inputs - throttled profiling.
|
||||||
|
|
||||||
|
This method is called frequently (once per batch), so we throttle
|
||||||
|
profiling to reduce overhead:
|
||||||
|
- Only log if duration > 0.5ms (skip very fast operations)
|
||||||
|
- Only log every 50th call (reduce logging frequency)
|
||||||
|
|
||||||
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
|
||||||
|
"""
|
||||||
|
with swanlab_profiling_context_advanced(
|
||||||
|
self, "prepare_inputs", config=self.fast_op_config
|
||||||
|
):
|
||||||
|
return super()._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
def _prepare_input_for_model(self, input_ids):
|
||||||
|
"""Another high-frequency operation - throttled profiling.
|
||||||
|
|
||||||
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
|
||||||
|
"""
|
||||||
|
with swanlab_profiling_context_advanced(
|
||||||
|
self, "prepare_input_for_model", config=self.fast_op_config
|
||||||
|
):
|
||||||
|
# Your custom input preparation logic
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# Pattern 4: Exception-safe Profiling
|
||||||
|
# ========================================================================
|
||||||
|
# Profiling is exception-safe: duration is logged even if method raises
|
||||||
|
|
||||||
|
@swanlab_profile
|
||||||
|
def potentially_failing_method(self):
|
||||||
|
"""This method may raise an exception.
|
||||||
|
|
||||||
|
SwanLab profiling will still log the duration before re-raising.
|
||||||
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
|
||||||
|
"""
|
||||||
|
# Do some work
|
||||||
|
result = self._do_risky_computation()
|
||||||
|
|
||||||
|
# If this raises, profiling duration is still logged
|
||||||
|
if result < 0:
|
||||||
|
raise ValueError("Invalid result")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _do_risky_computation(self):
|
||||||
|
"""Placeholder for risky computation."""
|
||||||
|
return 42
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Advanced Example: Custom ProfilingConfig Per Method
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedProfilingTrainer(AxolotlTrainer):
|
||||||
|
"""Trainer with method-specific profiling configurations."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Different profiling configs for different method types
|
||||||
|
self.critical_path_config = ProfilingConfig(
|
||||||
|
enabled=True,
|
||||||
|
min_duration_ms=0.0, # Log everything on critical path
|
||||||
|
log_interval=1, # Log every call
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fast_path_config = ProfilingConfig(
|
||||||
|
enabled=True,
|
||||||
|
min_duration_ms=1.0, # Only log if > 1ms
|
||||||
|
log_interval=100, # Log every 100th call
|
||||||
|
)
|
||||||
|
|
||||||
|
self.debug_config = ProfilingConfig(
|
||||||
|
enabled=True,
|
||||||
|
min_duration_ms=0.0, # Log everything
|
||||||
|
log_interval=1, # Log every call
|
||||||
|
)
|
||||||
|
|
||||||
|
def training_step(self, model, inputs):
|
||||||
|
"""Critical path - log everything."""
|
||||||
|
with swanlab_profiling_context_advanced(
|
||||||
|
self, "training_step", config=self.critical_path_config
|
||||||
|
):
|
||||||
|
return super().training_step(model, inputs)
|
||||||
|
|
||||||
|
def _prepare_inputs(self, inputs):
|
||||||
|
"""Fast path - throttle logging."""
|
||||||
|
with swanlab_profiling_context_advanced(
|
||||||
|
self, "prepare_inputs", config=self.fast_path_config
|
||||||
|
):
|
||||||
|
return super()._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
def _debug_method(self, data):
|
||||||
|
"""Debug-only method - verbose logging."""
|
||||||
|
with swanlab_profiling_context_advanced(
|
||||||
|
self, "debug_method", config=self.debug_config
|
||||||
|
):
|
||||||
|
# Your debug logic
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# How to Use This Custom Trainer
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
To use this custom trainer:
|
||||||
|
|
||||||
|
1. Save this file to your project (e.g., my_custom_trainer.py)
|
||||||
|
|
||||||
|
2. Create a config file that uses your custom trainer:
|
||||||
|
|
||||||
|
# config.yml
|
||||||
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
|
|
||||||
|
# ... other config ...
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||||
|
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_project: my-profiling-experiment
|
||||||
|
|
||||||
|
# Optional: Specify custom trainer
|
||||||
|
# (Or modify axolotl to use your custom trainer class)
|
||||||
|
|
||||||
|
3. Run training:
|
||||||
|
|
||||||
|
export SWANLAB_API_KEY=your-api-key
|
||||||
|
accelerate launch -m axolotl.cli.train config.yml
|
||||||
|
|
||||||
|
4. View profiling metrics in SwanLab dashboard:
|
||||||
|
- profiling/Time taken: CustomTrainerWithProfiling.training_step
|
||||||
|
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
||||||
|
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
||||||
|
- etc.
|
||||||
|
|
||||||
|
5. Compare profiling metrics across runs:
|
||||||
|
- Run baseline without optimizations
|
||||||
|
- Run with flash_attention enabled
|
||||||
|
- Run with gradient_checkpointing enabled
|
||||||
|
- Compare profiling metrics to see performance impact
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tips for Effective Profiling
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
1. Profile the critical path first:
|
||||||
|
- training_step, compute_loss, prediction_step
|
||||||
|
- These methods are called most frequently and have biggest impact
|
||||||
|
|
||||||
|
2. Use throttling for high-frequency operations:
|
||||||
|
- Methods called 100+ times per step
|
||||||
|
- Use log_interval=50 or log_interval=100
|
||||||
|
- Reduces profiling overhead and dashboard clutter
|
||||||
|
|
||||||
|
3. Filter noise with min_duration_ms:
|
||||||
|
- Set min_duration_ms=1.0 to skip very fast operations
|
||||||
|
- Focus on operations that actually take time
|
||||||
|
|
||||||
|
4. Compare across runs:
|
||||||
|
- Run same config multiple times to check consistency
|
||||||
|
- Compare different optimization strategies
|
||||||
|
- Track profiling trends over time
|
||||||
|
|
||||||
|
5. Monitor distributed training:
|
||||||
|
- Check for per-rank timing differences
|
||||||
|
- Look for stragglers (slower ranks)
|
||||||
|
- Identify synchronization bottlenecks
|
||||||
|
|
||||||
|
6. Disable profiling in production:
|
||||||
|
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
|
||||||
|
- DEFAULT_PROFILING_CONFIG.enabled = False
|
||||||
|
|
||||||
|
7. Exception handling:
|
||||||
|
- Profiling is exception-safe
|
||||||
|
- Duration logged even if method raises
|
||||||
|
- Useful for debugging methods that fail intermittently
|
||||||
|
"""
|
||||||
168
examples/swanlab/dpo-swanlab-completions.yml
Normal file
168
examples/swanlab/dpo-swanlab-completions.yml
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
# SwanLab DPO Training Example with Completion Logging
|
||||||
|
#
|
||||||
|
# This example demonstrates DPO (Direct Preference Optimization) training
|
||||||
|
# with SwanLab integration for experiment tracking and completion table logging.
|
||||||
|
#
|
||||||
|
# Features enabled:
|
||||||
|
# - SwanLab experiment tracking
|
||||||
|
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
|
||||||
|
# - Lark (Feishu) team notifications (optional)
|
||||||
|
#
|
||||||
|
# To run:
|
||||||
|
# export SWANLAB_API_KEY=your-api-key
|
||||||
|
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
||||||
|
|
||||||
|
# Model Configuration
|
||||||
|
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
eos_token: <|eot_id|>
|
||||||
|
|
||||||
|
# Quantization
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
|
||||||
|
# LoRA Configuration
|
||||||
|
adapter: lora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
# DPO Configuration
|
||||||
|
chat_template: llama3
|
||||||
|
rl: dpo
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||||
|
type: chat_template.default
|
||||||
|
field_messages: conversation
|
||||||
|
field_chosen: chosen
|
||||||
|
field_rejected: rejected
|
||||||
|
message_property_mappings:
|
||||||
|
role: role
|
||||||
|
content: content
|
||||||
|
roles:
|
||||||
|
system:
|
||||||
|
- system
|
||||||
|
user:
|
||||||
|
- user
|
||||||
|
assistant:
|
||||||
|
- assistant
|
||||||
|
|
||||||
|
# Dataset and Output
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/dpo-swanlab-out
|
||||||
|
|
||||||
|
# Training Configuration
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: false
|
||||||
|
micro_batch_size: 2
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
num_epochs: 4
|
||||||
|
|
||||||
|
# Optimization
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# Precision
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
# Performance
|
||||||
|
gradient_checkpointing: true
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
# Checkpointing and Logging
|
||||||
|
logging_steps: 1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SwanLab Integration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||||
|
|
||||||
|
# Basic SwanLab Configuration
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_project: dpo-training
|
||||||
|
swanlab_experiment_name: llama-3-dpo-completions-demo
|
||||||
|
swanlab_description: "DPO training with completion table logging"
|
||||||
|
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||||
|
|
||||||
|
# SwanLab Authentication
|
||||||
|
# Recommended: Set via environment variable
|
||||||
|
# export SWANLAB_API_KEY=your-api-key
|
||||||
|
# Or set in config (less secure):
|
||||||
|
# swanlab_api_key: your-api-key
|
||||||
|
|
||||||
|
# Optional: Team workspace
|
||||||
|
# swanlab_workspace: my-research-team
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# RLHF Completion Table Logging
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# Automatically logs model completions to SwanLab for qualitative analysis:
|
||||||
|
# - Prompts from your DPO dataset
|
||||||
|
# - Chosen responses (preferred)
|
||||||
|
# - Rejected responses (non-preferred)
|
||||||
|
# - Reward differences
|
||||||
|
#
|
||||||
|
# View the table in SwanLab dashboard under "rlhf_completions"
|
||||||
|
|
||||||
|
swanlab_log_completions: true
|
||||||
|
swanlab_completion_log_interval: 100 # Log every 100 training steps
|
||||||
|
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
|
||||||
|
|
||||||
|
# Memory Usage Notes:
|
||||||
|
# - Buffer size 128: ~64 KB (default, recommended)
|
||||||
|
# - Buffer size 512: ~256 KB (for more historical completions)
|
||||||
|
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
|
||||||
|
|
||||||
|
# Performance Notes:
|
||||||
|
# - Completion logging overhead: < 0.5% per training step
|
||||||
|
# - Only logs every N steps to minimize impact
|
||||||
|
# - Memory-bounded buffer prevents memory leaks
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Optional: Lark (Feishu) Team Notifications
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# Get real-time training notifications in your team chat
|
||||||
|
# Uncomment to enable:
|
||||||
|
|
||||||
|
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||||
|
# swanlab_lark_secret: your-webhook-secret # Recommended for production
|
||||||
|
|
||||||
|
# Notifications sent for:
|
||||||
|
# - Training start
|
||||||
|
# - Training completion
|
||||||
|
# - Training errors
|
||||||
|
# - Metric milestones (if configured)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Optional: Private SwanLab Deployment
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# For enterprise users with private SwanLab deployment:
|
||||||
|
|
||||||
|
# swanlab_web_host: https://swanlab.yourcompany.com
|
||||||
|
# swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Disable WandB if you're migrating from it
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# wandb_project:
|
||||||
|
# wandb_entity:
|
||||||
|
# use_wandb: false
|
||||||
329
examples/swanlab/dpo-swanlab-full-featured.yml
Normal file
329
examples/swanlab/dpo-swanlab-full-featured.yml
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
# SwanLab Full-Featured DPO Training Example
|
||||||
|
#
|
||||||
|
# This example demonstrates ALL SwanLab integration features:
|
||||||
|
# - Experiment tracking with cloud sync
|
||||||
|
# - RLHF completion table logging
|
||||||
|
# - Performance profiling
|
||||||
|
# - Lark (Feishu) team notifications
|
||||||
|
# - Team workspace collaboration
|
||||||
|
#
|
||||||
|
# Use this as a reference for production RLHF training setups.
|
||||||
|
#
|
||||||
|
# To run:
|
||||||
|
# export SWANLAB_API_KEY=your-api-key
|
||||||
|
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||||
|
# export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||||
|
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Model Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
eos_token: <|eot_id|>
|
||||||
|
|
||||||
|
# Quantization for efficient training
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# LoRA Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true # Target all linear layers
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# DPO (Direct Preference Optimization) Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
chat_template: llama3
|
||||||
|
rl: dpo # Enable DPO trainer
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||||
|
type: chat_template.default
|
||||||
|
field_messages: conversation
|
||||||
|
field_chosen: chosen
|
||||||
|
field_rejected: rejected
|
||||||
|
message_property_mappings:
|
||||||
|
role: role
|
||||||
|
content: content
|
||||||
|
roles:
|
||||||
|
system:
|
||||||
|
- system
|
||||||
|
user:
|
||||||
|
- user
|
||||||
|
assistant:
|
||||||
|
- assistant
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Dataset and Output Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/dpo-swanlab-full-featured-out
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Training Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
micro_batch_size: 2
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
num_epochs: 4
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Optimization
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Precision and Performance
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Checkpointing and Logging
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SwanLab Integration - Full Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Basic SwanLab Configuration
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_project: dpo-production
|
||||||
|
swanlab_experiment_name: llama-3-dpo-full-featured-v1
|
||||||
|
swanlab_description: |
|
||||||
|
Production DPO training with all SwanLab features enabled:
|
||||||
|
- Completion table logging for qualitative analysis
|
||||||
|
- Performance profiling for optimization
|
||||||
|
- Lark notifications for team collaboration
|
||||||
|
|
||||||
|
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Team Collaboration
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Workspace for team collaboration (shared experiments)
|
||||||
|
swanlab_workspace: ml-research-team
|
||||||
|
|
||||||
|
# Authentication (recommended: use environment variable)
|
||||||
|
# export SWANLAB_API_KEY=your-api-key
|
||||||
|
# Or set in config (less secure):
|
||||||
|
# swanlab_api_key: your-api-key
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# RLHF Completion Table Logging
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Automatically logs model completions for qualitative analysis:
|
||||||
|
# - Prompts from your DPO dataset
|
||||||
|
# - Chosen responses (preferred)
|
||||||
|
# - Rejected responses (non-preferred)
|
||||||
|
# - Reward differences
|
||||||
|
#
|
||||||
|
# View in SwanLab dashboard under "rlhf_completions" table
|
||||||
|
|
||||||
|
swanlab_log_completions: true
|
||||||
|
swanlab_completion_log_interval: 100 # Log every 100 steps
|
||||||
|
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
|
||||||
|
|
||||||
|
# Buffer size recommendations:
|
||||||
|
# - 128: Default, ~64 KB memory (recommended for most cases)
|
||||||
|
# - 256: ~128 KB memory (this config, good for longer training)
|
||||||
|
# - 512: ~256 KB memory (maximum for very long runs)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Lark (Feishu) Team Notifications
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Get real-time training notifications in your team chat
|
||||||
|
#
|
||||||
|
# Notifications sent for:
|
||||||
|
# - Training start
|
||||||
|
# - Training completion
|
||||||
|
# - Training errors
|
||||||
|
# - Metric milestones (if configured)
|
||||||
|
|
||||||
|
# Recommended: Set via environment variables
|
||||||
|
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||||
|
# export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||||
|
|
||||||
|
# Or set in config (less secure):
|
||||||
|
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||||
|
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
|
||||||
|
|
||||||
|
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
|
||||||
|
# unauthorized parties from sending fake notifications to your team chat.
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Performance Profiling
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Profiling is automatically enabled when SwanLab is enabled.
|
||||||
|
# Metrics logged to SwanLab under "profiling/" namespace:
|
||||||
|
# profiling/Time taken: AxolotlTrainer.training_step
|
||||||
|
# profiling/Time taken: AxolotlTrainer.compute_loss
|
||||||
|
# profiling/Time taken: AxolotlTrainer.prediction_step
|
||||||
|
#
|
||||||
|
# Use these metrics to:
|
||||||
|
# - Identify bottlenecks in training loop
|
||||||
|
# - Compare performance across different configurations
|
||||||
|
# - Monitor performance regressions over time
|
||||||
|
# - Debug unexpected slowdowns
|
||||||
|
|
||||||
|
# For custom profiling in your own trainer, see:
|
||||||
|
# examples/swanlab/custom_trainer_profiling.py
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Optional: Private SwanLab Deployment
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# For enterprise users with private SwanLab deployment:
|
||||||
|
|
||||||
|
# swanlab_web_host: https://swanlab.yourcompany.com
|
||||||
|
# swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Optional: Model Checkpointing to SwanLab
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Log model checkpoints to SwanLab (coming soon)
|
||||||
|
|
||||||
|
swanlab_log_model: false
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Disable Other Logging Tools (Recommended)
|
||||||
|
# ============================================================================
|
||||||
|
# Using multiple logging tools simultaneously can impact performance:
|
||||||
|
# - Expected overhead: ~1-2% per logger
|
||||||
|
# - Potential config/callback conflicts
|
||||||
|
#
|
||||||
|
# For production training, use ONLY SwanLab:
|
||||||
|
|
||||||
|
# wandb_project:
|
||||||
|
# use_wandb: false
|
||||||
|
#
|
||||||
|
# use_mlflow: false
|
||||||
|
#
|
||||||
|
# use_comet: false
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Expected Training Behavior
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# With this configuration, you should see:
|
||||||
|
#
|
||||||
|
# 1. SwanLab Initialization (rank 0 only):
|
||||||
|
# INFO: SwanLab initialized for project: dpo-production
|
||||||
|
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
|
||||||
|
# INFO: SwanLab mode: cloud
|
||||||
|
# INFO: SwanLab workspace: ml-research-team
|
||||||
|
#
|
||||||
|
# 2. Completion Logging (rank 0 only):
|
||||||
|
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
|
||||||
|
# (log_interval=100, max_buffer=256)
|
||||||
|
#
|
||||||
|
# 3. Lark Notifications (rank 0 only):
|
||||||
|
# INFO: Registered Lark notification callback with HMAC authentication
|
||||||
|
#
|
||||||
|
# 4. Distributed Training Detection (if multi-GPU):
|
||||||
|
# INFO: Distributed training detected (world_size=N)
|
||||||
|
# INFO: Only rank 0 will initialize SwanLab
|
||||||
|
# INFO: Other ranks will skip SwanLab to avoid conflicts
|
||||||
|
#
|
||||||
|
# 5. Training Start Notification (Lark):
|
||||||
|
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
|
||||||
|
#
|
||||||
|
# 6. Periodic Completion Logging:
|
||||||
|
# Every 100 steps, completion table is updated in SwanLab dashboard
|
||||||
|
#
|
||||||
|
# 7. Training Complete Notification (Lark):
|
||||||
|
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
|
||||||
|
# With link to SwanLab dashboard and final metrics
|
||||||
|
#
|
||||||
|
# 8. SwanLab Dashboard Shows:
|
||||||
|
# - Training metrics (loss, learning rate, etc.)
|
||||||
|
# - Completion table (rlhf_completions)
|
||||||
|
# - Profiling metrics (profiling/Time taken: ...)
|
||||||
|
# - Hyperparameters and configuration
|
||||||
|
# - System resource usage
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Production Checklist
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Before deploying to production, verify:
|
||||||
|
# ✅ SwanLab API key is set via environment variable (not in config)
|
||||||
|
# ✅ Lark webhook secret is set (required for HMAC authentication)
|
||||||
|
# ✅ Workspace is set to your team's workspace
|
||||||
|
# ✅ Experiment name is descriptive and unique
|
||||||
|
# ✅ Only SwanLab is enabled (other loggers disabled)
|
||||||
|
# ✅ Completion logging buffer size is appropriate for your training duration
|
||||||
|
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
|
||||||
|
# ✅ Test run completes successfully and shows up in SwanLab dashboard
|
||||||
|
# ✅ Lark notifications are received in team chat
|
||||||
|
# ✅ Profiling metrics are logged correctly
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Troubleshooting
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# If SwanLab initialization fails:
|
||||||
|
# 1. Check SWANLAB_API_KEY environment variable is set
|
||||||
|
# 2. Verify swanlab_project is set in config
|
||||||
|
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
|
||||||
|
# 4. Verify internet connectivity (for cloud mode)
|
||||||
|
|
||||||
|
# If Lark notifications not received:
|
||||||
|
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
|
||||||
|
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
|
||||||
|
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
|
||||||
|
# 4. Check training logs for "Registered Lark notification callback"
|
||||||
|
# 5. Verify bot is added to the target Lark group chat
|
||||||
|
|
||||||
|
# If completions not appearing in SwanLab:
|
||||||
|
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
|
||||||
|
# 2. Check swanlab_log_completions is true
|
||||||
|
# 3. Wait for log_interval steps (default: 100)
|
||||||
|
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
|
||||||
|
|
||||||
|
# If profiling metrics not appearing:
|
||||||
|
# 1. Verify use_swanlab is true
|
||||||
|
# 2. Check SwanLab is initialized (check logs)
|
||||||
|
# 3. Look under "profiling/" namespace in dashboard
|
||||||
|
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
|
||||||
|
|
||||||
|
# For more help:
|
||||||
|
# - SwanLab docs: https://docs.swanlab.cn
|
||||||
|
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
|
||||||
|
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues
|
||||||
178
examples/swanlab/lora-swanlab-profiling.yml
Normal file
178
examples/swanlab/lora-swanlab-profiling.yml
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
# SwanLab LoRA Training Example with Performance Profiling
|
||||||
|
#
|
||||||
|
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
|
||||||
|
# for performance profiling and optimization.
|
||||||
|
#
|
||||||
|
# Features enabled:
|
||||||
|
# - SwanLab experiment tracking
|
||||||
|
# - Performance profiling (training step, forward/backward pass timing)
|
||||||
|
# - Real-time metrics visualization
|
||||||
|
#
|
||||||
|
# To run:
|
||||||
|
# export SWANLAB_API_KEY=your-api-key
|
||||||
|
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
|
||||||
|
|
||||||
|
# Model Configuration
|
||||||
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
|
|
||||||
|
# Dataset Configuration
|
||||||
|
datasets:
|
||||||
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
|
type: alpaca
|
||||||
|
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/lora-swanlab-profiling-out
|
||||||
|
|
||||||
|
# LoRA Configuration
|
||||||
|
adapter: lora
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
# Training Configuration
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: true
|
||||||
|
|
||||||
|
micro_batch_size: 2
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
num_epochs: 1
|
||||||
|
|
||||||
|
# Optimization
|
||||||
|
optimizer: adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
# Precision
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
# Performance
|
||||||
|
gradient_checkpointing: true
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
# Checkpointing and Logging
|
||||||
|
logging_steps: 1
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Loss Monitoring
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SwanLab Integration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||||
|
|
||||||
|
# Basic SwanLab Configuration
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_project: lora-profiling
|
||||||
|
swanlab_experiment_name: llama-3.2-1b-profiling-demo
|
||||||
|
swanlab_description: "LoRA fine-tuning with performance profiling"
|
||||||
|
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||||
|
|
||||||
|
# SwanLab Authentication
|
||||||
|
# Recommended: Set via environment variable
|
||||||
|
# export SWANLAB_API_KEY=your-api-key
|
||||||
|
# Or set in config (less secure):
|
||||||
|
# swanlab_api_key: your-api-key
|
||||||
|
|
||||||
|
# Optional: Team workspace
|
||||||
|
# swanlab_workspace: my-ml-team
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Performance Profiling
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# SwanLab automatically profiles trainer methods when enabled.
|
||||||
|
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
|
||||||
|
#
|
||||||
|
# Built-in profiling:
|
||||||
|
# - Minimal overhead (< 0.1% per step)
|
||||||
|
# - High-precision timing (microsecond accuracy)
|
||||||
|
# - Exception-safe (logs duration even if method fails)
|
||||||
|
#
|
||||||
|
# View profiling metrics in SwanLab dashboard:
|
||||||
|
# profiling/Time taken: AxolotlTrainer.training_step
|
||||||
|
# profiling/Time taken: AxolotlTrainer.compute_loss
|
||||||
|
# profiling/Time taken: AxolotlTrainer.prediction_step
|
||||||
|
#
|
||||||
|
# For custom profiling in your own trainer, see:
|
||||||
|
# examples/swanlab/custom_trainer_profiling.py
|
||||||
|
|
||||||
|
# Completion logging is disabled for non-RLHF trainers
|
||||||
|
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Optional: Compare with Multiple Runs
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# To compare profiling metrics across different configurations:
|
||||||
|
#
|
||||||
|
# 1. Run baseline without flash attention:
|
||||||
|
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
|
||||||
|
# flash_attention: false
|
||||||
|
#
|
||||||
|
# 2. Run with gradient checkpointing:
|
||||||
|
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
|
||||||
|
# gradient_checkpointing: true
|
||||||
|
#
|
||||||
|
# 3. Run with both:
|
||||||
|
# swanlab_experiment_name: llama-3.2-1b-optimized
|
||||||
|
# flash_attention: true
|
||||||
|
# gradient_checkpointing: true
|
||||||
|
#
|
||||||
|
# Then compare profiling metrics in SwanLab dashboard to see performance impact
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Optional: Lark (Feishu) Team Notifications
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# Get notified when profiling experiments complete:
|
||||||
|
|
||||||
|
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||||
|
# swanlab_lark_secret: your-webhook-secret
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Profiling Best Practices
|
||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# 1. Run multiple epochs to see profiling trends over time
|
||||||
|
# 2. Ignore first ~10 steps (warmup period, slower)
|
||||||
|
# 3. Look for outliers (steps that take significantly longer)
|
||||||
|
# 4. Compare profiling metrics before/after optimization changes
|
||||||
|
# 5. Monitor per-rank profiling in distributed training
|
||||||
|
#
|
||||||
|
# Common bottlenecks to profile:
|
||||||
|
# - training_step: Overall step time (should be consistent)
|
||||||
|
# - compute_loss: Loss computation (scales with sequence length)
|
||||||
|
# - prediction_step: Evaluation time (can be slow for large val sets)
|
||||||
|
#
|
||||||
|
# If you see inconsistent timing:
|
||||||
|
# - Check for data loading bottlenecks
|
||||||
|
# - Monitor GPU utilization (may be CPU-bound)
|
||||||
|
# - Check for gradient accumulation effects
|
||||||
|
# - Verify CUDA kernel synchronization
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Disable WandB if you're migrating from it
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# wandb_project:
|
||||||
|
# use_wandb: false
|
||||||
@@ -29,6 +29,10 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
|
||||||
|
|
||||||
## Related Resources
|
## Related Resources
|
||||||
|
|
||||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
base_model: arcee-ai/Trinity-Nano-Preview
|
base_model: arcee-ai/Trinity-Nano-Preview
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
|
revision_of_model: 2ee94b0
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.48.2
|
bitsandbytes==0.49.1
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
@@ -14,20 +14,21 @@ huggingface_hub>=0.36.0
|
|||||||
peft>=0.18.0
|
peft>=0.18.0
|
||||||
tokenizers>=0.22.1
|
tokenizers>=0.22.1
|
||||||
transformers==4.57.1
|
transformers==4.57.1
|
||||||
accelerate==1.11.0
|
accelerate==1.12.0
|
||||||
datasets==4.4.1
|
datasets==4.4.2
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.18.3
|
||||||
trl==0.25.0
|
trl==0.25.1
|
||||||
hf_xet==1.2.0
|
hf_xet==1.2.0
|
||||||
kernels>=0.9.0
|
kernels==0.11.5
|
||||||
trackio
|
trackio>=0.13.0
|
||||||
|
typing-extensions>=4.15.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.49.1
|
gradio>=6.2.0,<7.0
|
||||||
|
|
||||||
modal==1.0.2
|
modal==1.3.0.post1
|
||||||
pydantic>=2.10.6
|
pydantic>=2.10.6
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
@@ -67,8 +68,7 @@ openenv-core==0.1.0
|
|||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.7
|
axolotl-contribs-lgpl==0.0.7
|
||||||
axolotl-contribs-mit==0.0.5
|
axolotl-contribs-mit==0.0.6
|
||||||
|
|
||||||
# telemetry
|
# telemetry
|
||||||
posthog==6.7.11
|
posthog==6.7.11
|
||||||
|
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
|
||||||
)
|
)
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -156,7 +156,7 @@ extras_require = {
|
|||||||
"came_pytorch==0.1.3",
|
"came_pytorch==0.1.3",
|
||||||
],
|
],
|
||||||
"ray": [
|
"ray": [
|
||||||
"ray[train]",
|
"ray[train]>=2.52.1",
|
||||||
],
|
],
|
||||||
"vllm": [
|
"vllm": [
|
||||||
"vllm==0.10.0",
|
"vllm==0.10.0",
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ if launcher_args:
|
|||||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||||
|
|
||||||
# 1. Define a base image for your training job
|
# 1. Define a base image for your training job
|
||||||
# must use torch 2.7.0 for vllm
|
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu128-2.9.1"
|
||||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
|
||||||
|
|
||||||
# 2. Define the Runtime Environment for the Training Job
|
# 2. Define the Runtime Environment for the Training Job
|
||||||
# This includes start commands and environment variables.a
|
# This includes start commands and environment variables.a
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def get_image(self):
|
def get_image(self):
|
||||||
docker_tag = "main-py3.11-cu126-2.7.1"
|
docker_tag = "main-py3.11-cu128-2.9.1"
|
||||||
if self.config.docker_tag:
|
if self.config.docker_tag:
|
||||||
docker_tag = self.config.docker_tag
|
docker_tag = self.config.docker_tag
|
||||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.tee import prepare_debug_log
|
from axolotl.utils.tee import prepare_debug_log
|
||||||
|
from axolotl.utils.trackio_ import setup_trackio_env_vars
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
@@ -227,6 +228,7 @@ def load_cfg(
|
|||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
"bf16": is_torch_bf16_gpu_available(),
|
"bf16": is_torch_bf16_gpu_available(),
|
||||||
|
"fp8": compute_supports_fp8(),
|
||||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
@@ -245,6 +247,7 @@ def load_cfg(
|
|||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
setup_mlflow_env_vars(cfg)
|
setup_mlflow_env_vars(cfg)
|
||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
|
setup_trackio_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||||
@@ -259,3 +262,11 @@ def load_cfg(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def compute_supports_fp8() -> bool:
|
||||||
|
try:
|
||||||
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
|
return compute_capability >= (9, 0)
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|||||||
@@ -288,8 +288,8 @@ def do_inference_gradio(
|
|||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.queue().launch(
|
demo.launch(
|
||||||
show_api=False,
|
footer_links=["gradio", "settings"],
|
||||||
share=cfg.get("gradio_share", True),
|
share=cfg.get("gradio_share", True),
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
|||||||
@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
|
|||||||
outputs=[masked_preview, html_out],
|
outputs=[masked_preview, html_out],
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.queue().launch(
|
demo.launch(
|
||||||
show_api=False,
|
footer_links=["gradio", "settings"],
|
||||||
share=cfg.get("gradio_share", True),
|
share=cfg.get("gradio_share", True),
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
"""
|
|
||||||
monkeypatch for flex + packing
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from typing import Callable, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.attention.flex_attention import BlockMask
|
|
||||||
from transformers import Cache, PretrainedConfig
|
|
||||||
from transformers.masking_utils import (
|
|
||||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
|
||||||
_preprocess_mask_arguments,
|
|
||||||
and_masks,
|
|
||||||
causal_mask_function,
|
|
||||||
or_masks,
|
|
||||||
)
|
|
||||||
from transformers.utils import is_torch_greater_or_equal
|
|
||||||
|
|
||||||
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
|
||||||
|
|
||||||
|
|
||||||
def create_causal_mask(
|
|
||||||
config: PretrainedConfig,
|
|
||||||
input_embeds: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
cache_position: torch.Tensor,
|
|
||||||
past_key_values: Optional[Cache],
|
|
||||||
or_mask_function: Optional[Callable] = None,
|
|
||||||
and_mask_function: Optional[Callable] = None,
|
|
||||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
|
||||||
"""
|
|
||||||
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
|
|
||||||
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
|
|
||||||
to what is needed in the `modeling_xxx.py` files).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (`PretrainedConfig`):
|
|
||||||
The model config.
|
|
||||||
input_embeds (`torch.Tensor`):
|
|
||||||
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
|
||||||
batch size, query length and dtype.
|
|
||||||
attention_mask (`torch.Tensor`, optional):
|
|
||||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
|
||||||
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
|
||||||
cache_position (`torch.Tensor`):
|
|
||||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
|
||||||
past_key_values (`Cache`, optional):
|
|
||||||
The past key values, if we use a cache.
|
|
||||||
or_mask_function (`Callable`, optional):
|
|
||||||
An optional mask function to combine with the causal mask function (by doing the union of both). This is
|
|
||||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
|
||||||
and_mask_function (`Callable`, optional):
|
|
||||||
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
|
|
||||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
|
||||||
"""
|
|
||||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
|
||||||
if (
|
|
||||||
past_key_values
|
|
||||||
and hasattr(past_key_values, "is_sliding")
|
|
||||||
and False in past_key_values.is_sliding
|
|
||||||
):
|
|
||||||
layer_idx = past_key_values.is_sliding.index(False)
|
|
||||||
else:
|
|
||||||
layer_idx = 0
|
|
||||||
|
|
||||||
original_attention_mask = (
|
|
||||||
None
|
|
||||||
if attention_mask is None
|
|
||||||
else attention_mask.clone().to(cache_position.device)
|
|
||||||
)
|
|
||||||
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
|
|
||||||
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
|
|
||||||
)
|
|
||||||
if early_exit:
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
batch_size, total_seq_len = cache_position.shape
|
|
||||||
key_length = total_seq_len
|
|
||||||
document_ids = torch.nn.functional.pad(
|
|
||||||
original_attention_mask, value=0, pad=(0, key_length)
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
|
|
||||||
if attention_mask is not None:
|
|
||||||
|
|
||||||
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
|
||||||
"""
|
|
||||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
|
||||||
and a block diagonal document mask.
|
|
||||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
|
||||||
for an illustration.
|
|
||||||
"""
|
|
||||||
causal_mask_ = q_idx >= kv_idx # not valid when decoding
|
|
||||||
document_mask = (
|
|
||||||
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
|
||||||
)
|
|
||||||
final_mask = causal_mask_ & document_mask
|
|
||||||
return final_mask
|
|
||||||
|
|
||||||
mask_factory_function = causal_doc_mask_mod
|
|
||||||
else:
|
|
||||||
mask_factory_function = causal_mask_function
|
|
||||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
|
||||||
|
|
||||||
# Do not allow skip if we are compiling (this is to match BC)
|
|
||||||
allow_is_causal_skip = (
|
|
||||||
not past_key_values.is_compileable if past_key_values is not None else True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Allow slight deviations from causal mask
|
|
||||||
if or_mask_function is not None:
|
|
||||||
if not _is_torch_greater_or_equal_than_2_6:
|
|
||||||
raise ValueError(
|
|
||||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
|
||||||
)
|
|
||||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
|
||||||
allow_is_causal_skip = False
|
|
||||||
if and_mask_function is not None:
|
|
||||||
if not _is_torch_greater_or_equal_than_2_6:
|
|
||||||
raise ValueError(
|
|
||||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
|
||||||
)
|
|
||||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
|
||||||
allow_is_causal_skip = False
|
|
||||||
|
|
||||||
# We now create the mask
|
|
||||||
causal_mask = mask_interface(
|
|
||||||
batch_size=batch_size,
|
|
||||||
cache_position=cache_position,
|
|
||||||
kv_length=kv_length,
|
|
||||||
kv_offset=kv_offset,
|
|
||||||
mask_function=mask_factory_function,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
|
||||||
dtype=dtype, # Additional kwarg for eager
|
|
||||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
|
||||||
)
|
|
||||||
return causal_mask
|
|
||||||
|
|
||||||
|
|
||||||
def patch_create_causal_mask(model_type):
|
|
||||||
import transformers.masking_utils
|
|
||||||
|
|
||||||
transformers.masking_utils.create_causal_mask = create_causal_mask
|
|
||||||
|
|
||||||
if model_type:
|
|
||||||
try:
|
|
||||||
# Dynamically import the module and attention class
|
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
|
||||||
module = __import__(module_path)
|
|
||||||
module.create_causal_mask = create_causal_mask
|
|
||||||
del sys.modules[module_path]
|
|
||||||
except (ImportError, AttributeError) as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not import attention class for model_type: {model_type}. "
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
) from e
|
|
||||||
@@ -35,6 +35,7 @@ from axolotl.utils import (
|
|||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
is_opentelemetry_available,
|
is_opentelemetry_available,
|
||||||
|
is_trackio_available,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GCCallback,
|
GCCallback,
|
||||||
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_trackio and is_trackio_available():
|
||||||
|
from axolotl.utils.callbacks.trackio_ import (
|
||||||
|
SaveAxolotlConfigtoTrackioCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||||
from axolotl.utils.callbacks.opentelemetry import (
|
from axolotl.utils.callbacks.opentelemetry import (
|
||||||
OpenTelemetryMetricsCallback,
|
OpenTelemetryMetricsCallback,
|
||||||
@@ -281,11 +290,22 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||||
|
|
||||||
if self.cfg.optimizer == "muon":
|
if self.cfg.optimizer == "muon":
|
||||||
from axolotl.contribs.mit.muon import (
|
_, device_mesh = build_parallelism_config(self.cfg)
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
if device_mesh is not None:
|
||||||
|
from axolotl.contribs.mit.muon.dist_muon import (
|
||||||
|
DistMuonOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = DistMuonOptimizerFactory
|
||||||
|
optimizer_kwargs["device_mesh"] = device_mesh
|
||||||
|
else:
|
||||||
|
from axolotl.contribs.mit.muon import (
|
||||||
|
MuonOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = MuonOptimizerFactory
|
||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "dion":
|
elif self.cfg.optimizer == "dion":
|
||||||
from axolotl.contribs.mit.dion import (
|
from axolotl.contribs.mit.dion import (
|
||||||
@@ -423,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
report_to.append("tensorboard")
|
report_to.append("tensorboard")
|
||||||
if self.cfg.use_comet:
|
if self.cfg.use_comet:
|
||||||
report_to.append("comet_ml")
|
report_to.append("comet_ml")
|
||||||
|
if self.cfg.use_trackio:
|
||||||
|
report_to.append("trackio")
|
||||||
|
|
||||||
training_args_kwargs["report_to"] = report_to
|
training_args_kwargs["report_to"] = report_to
|
||||||
|
|
||||||
@@ -430,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
elif self.cfg.use_mlflow:
|
elif self.cfg.use_mlflow:
|
||||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
|
elif self.cfg.use_trackio:
|
||||||
|
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["run_name"] = None
|
training_args_kwargs["run_name"] = None
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.include_tkps:
|
if self.cfg.include_tkps:
|
||||||
callbacks.append(
|
callbacks.append(
|
||||||
TokensPerSecondCallback(
|
TokensPerSecondCallback(
|
||||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
self.cfg.tensor_parallel_size,
|
||||||
|
self.cfg.context_parallel_size,
|
||||||
|
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return callbacks
|
return callbacks
|
||||||
@@ -371,6 +373,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||||
|
|
||||||
|
if self.cfg.use_dynamic_finetuning:
|
||||||
|
from axolotl.monkeypatch.loss.dft import dft_loss
|
||||||
|
|
||||||
|
trainer_kwargs["compute_loss_func"] = dft_loss
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
|
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
@@ -49,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
TOKENS_STATE_FILE = "tokens_state."
|
||||||
|
|
||||||
REDUCTION_FNS = {
|
REDUCTION_FNS = {
|
||||||
"mean": torch.mean,
|
"mean": torch.mean,
|
||||||
"min": torch.min,
|
"min": torch.min,
|
||||||
@@ -348,24 +352,33 @@ class AxolotlTrainer(
|
|||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
|
|
||||||
# track number of tokens for tokens per second calculation
|
# track number of tokens for tokens per second calculation
|
||||||
if self.args.include_tkps:
|
if self.args.include_tkps and model.training:
|
||||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||||
num_tokens = (inputs[inputs_key] != -100).sum()
|
trainable_tokens = (inputs[inputs_key] != -100).sum()
|
||||||
|
total_tokens = inputs[inputs_key].numel()
|
||||||
|
total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)
|
||||||
|
|
||||||
if is_distributed():
|
if is_distributed():
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
num_tokens, op=torch.distributed.ReduceOp.SUM
|
trainable_tokens, op=torch.distributed.ReduceOp.SUM
|
||||||
)
|
)
|
||||||
if hasattr(self.state, "num_tokens"):
|
torch.distributed.all_reduce(
|
||||||
self.state.num_tokens = (
|
total_tokens, op=torch.distributed.ReduceOp.SUM
|
||||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
|
||||||
|
|
||||||
if hasattr(self.state, "total_tokens"):
|
if not hasattr(self.state, "tokens"):
|
||||||
self.state.total_tokens += num_tokens
|
self.state.tokens = {
|
||||||
else:
|
"trainable": torch.zeros(1),
|
||||||
self.state.total_tokens = num_tokens
|
"total": torch.zeros(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
# trainable tokens for throughput and total token slots for summaries
|
||||||
|
self.state.tokens["trainable"] = (
|
||||||
|
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
|
||||||
|
)
|
||||||
|
self.state.tokens["total"] = self.state.tokens["total"] + total_tokens.cpu()
|
||||||
|
# Store per-step trainable tokens for throughput calculation
|
||||||
|
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
|
||||||
|
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
return self.orpo_compute_loss(
|
return self.orpo_compute_loss(
|
||||||
@@ -603,6 +616,7 @@ class AxolotlTrainer(
|
|||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
|
||||||
|
|
||||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||||
@@ -613,7 +627,18 @@ class AxolotlTrainer(
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Metric reduction must be one of [mean, min, max, sum]"
|
"Metric reduction must be one of [mean, min, max, sum]"
|
||||||
)
|
)
|
||||||
logs[key] = round(fn(values).item(), 4)
|
logs[key] = round(fn(values).item(), metric_ndigits)
|
||||||
|
|
||||||
|
if "loss" in logs:
|
||||||
|
try:
|
||||||
|
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
|
||||||
|
except OverflowError:
|
||||||
|
logs["ppl"] = float("inf")
|
||||||
|
if "eval_loss" in logs:
|
||||||
|
try:
|
||||||
|
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
|
||||||
|
except OverflowError:
|
||||||
|
logs["eval_ppl"] = float("inf")
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
# Add memory usage
|
# Add memory usage
|
||||||
@@ -625,17 +650,20 @@ class AxolotlTrainer(
|
|||||||
except (ValueError, TypeError, FileNotFoundError):
|
except (ValueError, TypeError, FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.args.include_tkps and train_eval == "train":
|
if (
|
||||||
|
self.args.include_tkps
|
||||||
|
and train_eval == "train"
|
||||||
|
and hasattr(self.state, "tokens")
|
||||||
|
):
|
||||||
# each rank will log its own tokens per second
|
# each rank will log its own tokens per second
|
||||||
# for logging_steps > 1 we obtain a moving average of this metric
|
# for logging_steps > 1 we obtain a moving average of this metric
|
||||||
logs["tokens_per_second_per_gpu"] = round(
|
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||||
)
|
)
|
||||||
if (
|
if "total" in self.state.tokens:
|
||||||
hasattr(self.state, "total_tokens")
|
logs["tokens/total"] = int(self.state.tokens["total"].item())
|
||||||
and self.state.total_tokens is not None
|
if "trainable" in self.state.tokens:
|
||||||
):
|
logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
|
||||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
|
||||||
|
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
|
|
||||||
@@ -670,6 +698,19 @@ class AxolotlTrainer(
|
|||||||
run_dir = self._get_output_dir(trial=trial)
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Save total_tokens state if tracking is enabled
|
||||||
|
if self.args.include_tkps and hasattr(self.state, "tokens"):
|
||||||
|
tokens_state = {
|
||||||
|
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
|
||||||
|
"trainable": int(
|
||||||
|
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
|
||||||
|
with open(tokens_state_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(tokens_state, f)
|
||||||
|
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||||
|
|||||||
@@ -36,4 +36,6 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||||
if cfg.dpo_use_logits_to_keep is not None:
|
if cfg.dpo_use_logits_to_keep is not None:
|
||||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||||
|
if cfg.dpo_use_liger_kernel is not None:
|
||||||
|
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -54,6 +54,8 @@ plugins:
|
|||||||
- granitemoehybrid
|
- granitemoehybrid
|
||||||
- hunyuan_v1_dense
|
- hunyuan_v1_dense
|
||||||
- hunyuan_v1_moe
|
- hunyuan_v1_moe
|
||||||
|
- internvl
|
||||||
|
- kimi_linear
|
||||||
- lfm2
|
- lfm2
|
||||||
- lfm2_moe
|
- lfm2_moe
|
||||||
- lfm2_vl
|
- lfm2_vl
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -96,7 +96,11 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The patch checks model_type internally
|
# The patch checks model_type internally
|
||||||
cce_patch(cfg.model_config_type)
|
|
||||||
|
cce_patch(
|
||||||
|
cfg.model_config_type,
|
||||||
|
remote_model_id=cfg.base_model if cfg.trust_remote_code else None,
|
||||||
|
)
|
||||||
|
|
||||||
def patch_llama_like(
|
def patch_llama_like(
|
||||||
self,
|
self,
|
||||||
@@ -107,7 +111,9 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
"""
|
"""
|
||||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||||
|
|
||||||
def patch_generic(maybe_model, patch_options, model_type: str):
|
def patch_generic(
|
||||||
|
maybe_model, patch_options, model_type: str, remote_model_id: str | None
|
||||||
|
):
|
||||||
import cut_cross_entropy.transformers.llama
|
import cut_cross_entropy.transformers.llama
|
||||||
from cut_cross_entropy.transformers.llama import cce_forward
|
from cut_cross_entropy.transformers.llama import cce_forward
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
|
|||||||
if cfg.dense_mixer:
|
if cfg.dense_mixer:
|
||||||
if not importlib.util.find_spec("densemixer"):
|
if not importlib.util.find_spec("densemixer"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
"DenseMixer is not installed. Install it with `pip install densemixer`"
|
||||||
)
|
)
|
||||||
|
|
||||||
from densemixer.patching import (
|
from densemixer.patching import (
|
||||||
|
|||||||
1284
src/axolotl/integrations/swanlab/README.md
Normal file
1284
src/axolotl/integrations/swanlab/README.md
Normal file
File diff suppressed because it is too large
Load Diff
6
src/axolotl/integrations/swanlab/__init__.py
Normal file
6
src/axolotl/integrations/swanlab/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""SwanLab integration plugin for Axolotl"""
|
||||||
|
|
||||||
|
from axolotl.integrations.swanlab.args import SwanLabConfig
|
||||||
|
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
|
||||||
|
|
||||||
|
__all__ = ["SwanLabConfig", "SwanLabPlugin"]
|
||||||
140
src/axolotl/integrations/swanlab/args.py
Normal file
140
src/axolotl/integrations/swanlab/args.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""SwanLab configuration arguments"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class SwanLabConfig(BaseModel):
|
||||||
|
"""SwanLab configuration subset"""
|
||||||
|
|
||||||
|
use_swanlab: bool | None = Field(
|
||||||
|
default=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable SwanLab experiment tracking and visualization"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_project: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Your SwanLab project name"},
|
||||||
|
)
|
||||||
|
swanlab_experiment_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Set the name of your SwanLab experiment"},
|
||||||
|
)
|
||||||
|
swanlab_description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Description for your SwanLab experiment"},
|
||||||
|
)
|
||||||
|
swanlab_mode: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": '"cloud" to sync to SwanLab cloud, "local" for local only, "offline" to save metadata locally, "disabled" to turn off SwanLab'
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_workspace: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "SwanLab workspace name (organization or username)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_api_key: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "SwanLab API key for authentication. Can also be set via SWANLAB_API_KEY environment variable"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_log_model: bool | None = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to log model checkpoints to SwanLab (feature coming soon)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_web_host: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Web address for SwanLab cloud environment (for private deployment)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_api_host: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "API address for SwanLab cloud environment (for private deployment)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_lark_webhook_url: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Lark (Feishu) webhook URL for sending training notifications to team chat"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_lark_secret: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Secret for Lark webhook HMAC signature authentication (optional)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_log_completions: bool | None = Field(
|
||||||
|
default=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_completion_log_interval: int | None = Field(
|
||||||
|
default=100,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of training steps between completion table logging to SwanLab"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
swanlab_completion_max_buffer: int | None = Field(
|
||||||
|
default=128,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum number of completions to buffer before logging (prevents memory leaks)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("swanlab_mode")
|
||||||
|
@classmethod
|
||||||
|
def validate_swanlab_mode(cls, v):
|
||||||
|
"""Validate swanlab_mode is one of the allowed values."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
|
||||||
|
valid_modes = ["cloud", "local", "offline", "disabled"]
|
||||||
|
if v not in valid_modes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid swanlab_mode: '{v}'.\n\n"
|
||||||
|
f"Valid options: {', '.join(valid_modes)}\n\n"
|
||||||
|
f"Examples:\n"
|
||||||
|
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
|
||||||
|
f" swanlab_mode: local # Local only, no cloud sync\n"
|
||||||
|
f" swanlab_mode: offline # Save metadata locally\n"
|
||||||
|
f" swanlab_mode: disabled # Turn off SwanLab\n"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("swanlab_project")
|
||||||
|
@classmethod
|
||||||
|
def validate_swanlab_project(cls, v):
|
||||||
|
"""Validate swanlab_project is non-empty when provided."""
|
||||||
|
if v is not None and isinstance(v, str) and len(v.strip()) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"swanlab_project cannot be an empty string.\n\n"
|
||||||
|
"Either:\n"
|
||||||
|
" 1. Provide a valid project name: swanlab_project: my-project\n"
|
||||||
|
" 2. Remove the swanlab_project field entirely\n"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_swanlab_enabled_requires_project(self):
|
||||||
|
"""Validate that if use_swanlab is True, swanlab_project must be set."""
|
||||||
|
if self.use_swanlab is True and not self.swanlab_project:
|
||||||
|
raise ValueError(
|
||||||
|
"SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\n\n"
|
||||||
|
"Solutions:\n"
|
||||||
|
" 1. Add 'swanlab_project: your-project-name' to your config\n"
|
||||||
|
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
|
||||||
|
"Example:\n"
|
||||||
|
" use_swanlab: true\n"
|
||||||
|
" swanlab_project: my-llm-training\n"
|
||||||
|
)
|
||||||
|
return self
|
||||||
179
src/axolotl/integrations/swanlab/callbacks.py
Normal file
179
src/axolotl/integrations/swanlab/callbacks.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""SwanLab callbacks for Axolotl trainers.
|
||||||
|
|
||||||
|
This module provides HuggingFace Trainer callbacks for logging
|
||||||
|
RLHF completions to SwanLab.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.integrations.swanlab.completion_logger import CompletionLogger
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SwanLabRLHFCompletionCallback(TrainerCallback):
|
||||||
|
"""Callback for logging RLHF completions to SwanLab.
|
||||||
|
|
||||||
|
This callback periodically logs model completions (prompts, chosen/rejected
|
||||||
|
responses, rewards) to SwanLab during RLHF training for qualitative analysis.
|
||||||
|
|
||||||
|
Supports DPO, KTO, ORPO, and GRPO trainers.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> callback = SwanLabRLHFCompletionCallback(
|
||||||
|
... log_interval=100, # Log every 100 steps
|
||||||
|
... max_completions=128, # Keep last 128 completions
|
||||||
|
... )
|
||||||
|
>>> trainer.add_callback(callback)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
logger: CompletionLogger instance
|
||||||
|
log_interval: Number of steps between SwanLab logging
|
||||||
|
trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
log_interval: int = 100,
|
||||||
|
max_completions: int = 128,
|
||||||
|
table_name: str = "rlhf_completions",
|
||||||
|
):
|
||||||
|
"""Initialize SwanLab RLHF completion callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_interval: Log to SwanLab every N steps. Default: 100
|
||||||
|
max_completions: Maximum completions to buffer. Default: 128
|
||||||
|
table_name: SwanLab table name. Default: "rlhf_completions"
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.logger = CompletionLogger(maxlen=max_completions)
|
||||||
|
self.log_interval = log_interval
|
||||||
|
self.table_name = table_name
|
||||||
|
self.trainer_type: str | None = None # Auto-detected
|
||||||
|
self._last_logged_step = 0
|
||||||
|
|
||||||
|
def on_init_end(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Detect trainer type on initialization."""
|
||||||
|
trainer = kwargs.get("trainer")
|
||||||
|
if trainer is not None:
|
||||||
|
trainer_name = trainer.__class__.__name__
|
||||||
|
if "DPO" in trainer_name:
|
||||||
|
self.trainer_type = "dpo"
|
||||||
|
elif "KTO" in trainer_name:
|
||||||
|
self.trainer_type = "kto"
|
||||||
|
elif "ORPO" in trainer_name:
|
||||||
|
self.trainer_type = "orpo"
|
||||||
|
elif "GRPO" in trainer_name:
|
||||||
|
self.trainer_type = "grpo"
|
||||||
|
else:
|
||||||
|
self.trainer_type = "unknown"
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"SwanLab RLHF completion logging enabled for {trainer_name} "
|
||||||
|
f"(type: {self.trainer_type})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_log(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
logs: dict | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Capture completions from logs and buffer them.
|
||||||
|
|
||||||
|
Different trainers log completions in different formats:
|
||||||
|
- DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff']
|
||||||
|
- KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward']
|
||||||
|
- ORPO: logs['orpo/chosen'], logs['orpo/rejected']
|
||||||
|
- GRPO: logs['grpo/completion'], logs['grpo/reward']
|
||||||
|
|
||||||
|
Note: This is a placeholder implementation. Actual log keys depend
|
||||||
|
on the TRL trainer implementation. You may need to patch the trainers
|
||||||
|
to expose completion data in logs.
|
||||||
|
"""
|
||||||
|
if logs is None or self.trainer_type is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
step = state.global_step
|
||||||
|
|
||||||
|
# DPO completions
|
||||||
|
if self.trainer_type == "dpo":
|
||||||
|
if all(key in logs for key in ["dpo/prompt", "dpo/chosen", "dpo/rejected"]):
|
||||||
|
self.logger.add_dpo_completion(
|
||||||
|
step=step,
|
||||||
|
prompt=logs.get("dpo/prompt", ""),
|
||||||
|
chosen=logs.get("dpo/chosen", ""),
|
||||||
|
rejected=logs.get("dpo/rejected", ""),
|
||||||
|
reward_diff=logs.get("dpo/reward_diff"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# KTO completions
|
||||||
|
elif self.trainer_type == "kto":
|
||||||
|
if all(key in logs for key in ["kto/prompt", "kto/completion"]):
|
||||||
|
self.logger.add_kto_completion(
|
||||||
|
step=step,
|
||||||
|
prompt=logs.get("kto/prompt", ""),
|
||||||
|
completion=logs.get("kto/completion", ""),
|
||||||
|
label=logs.get("kto/label", False),
|
||||||
|
reward=logs.get("kto/reward"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ORPO completions
|
||||||
|
elif self.trainer_type == "orpo":
|
||||||
|
if all(
|
||||||
|
key in logs for key in ["orpo/prompt", "orpo/chosen", "orpo/rejected"]
|
||||||
|
):
|
||||||
|
self.logger.add_orpo_completion(
|
||||||
|
step=step,
|
||||||
|
prompt=logs.get("orpo/prompt", ""),
|
||||||
|
chosen=logs.get("orpo/chosen", ""),
|
||||||
|
rejected=logs.get("orpo/rejected", ""),
|
||||||
|
log_odds_ratio=logs.get("orpo/log_odds_ratio"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# GRPO completions
|
||||||
|
elif self.trainer_type == "grpo":
|
||||||
|
if all(key in logs for key in ["grpo/prompt", "grpo/completion"]):
|
||||||
|
self.logger.add_grpo_completion(
|
||||||
|
step=step,
|
||||||
|
prompt=logs.get("grpo/prompt", ""),
|
||||||
|
completion=logs.get("grpo/completion", ""),
|
||||||
|
reward=logs.get("grpo/reward"),
|
||||||
|
advantage=logs.get("grpo/advantage"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Periodically log to SwanLab
|
||||||
|
if step - self._last_logged_step >= self.log_interval:
|
||||||
|
if len(self.logger) > 0:
|
||||||
|
self.logger.log_to_swanlab(table_name=self.table_name)
|
||||||
|
self.logger.clear()
|
||||||
|
self._last_logged_step = step
|
||||||
|
|
||||||
|
def on_train_end(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Log remaining completions at end of training."""
|
||||||
|
if len(self.logger) > 0:
|
||||||
|
LOG.info(
|
||||||
|
f"Training complete, logging final {len(self.logger)} completions to SwanLab"
|
||||||
|
)
|
||||||
|
self.logger.log_to_swanlab(table_name=self.table_name)
|
||||||
|
self._last_logged_step = state.global_step
|
||||||
228
src/axolotl/integrations/swanlab/completion_logger.py
Normal file
228
src/axolotl/integrations/swanlab/completion_logger.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training.
|
||||||
|
|
||||||
|
This module provides utilities for logging model completions during
|
||||||
|
preference training to SwanLab for qualitative analysis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionLogger:
|
||||||
|
"""Memory-bounded logger for RLHF completions.
|
||||||
|
|
||||||
|
Stores prompts, completions, and rewards in fixed-size deques to prevent
|
||||||
|
memory leaks during long training runs. Logs completion tables to SwanLab
|
||||||
|
for qualitative analysis of model outputs.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> logger = CompletionLogger(maxlen=128)
|
||||||
|
>>> logger.add_dpo_completion(
|
||||||
|
... step=0,
|
||||||
|
... prompt="What is AI?",
|
||||||
|
... chosen="Artificial Intelligence is...",
|
||||||
|
... rejected="AI means...",
|
||||||
|
... reward_diff=0.5
|
||||||
|
... )
|
||||||
|
>>> logger.log_to_swanlab()
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
maxlen: Maximum number of completions to store (older ones are dropped)
|
||||||
|
data: Deque storing completion dictionaries
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, maxlen: int = 128):
|
||||||
|
"""Initialize completion logger with bounded buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maxlen: Maximum number of completions to store. When the buffer
|
||||||
|
is full, oldest completions are automatically discarded.
|
||||||
|
Default: 128 (sufficient for most RLHF runs without memory issues)
|
||||||
|
"""
|
||||||
|
self.maxlen = maxlen
|
||||||
|
self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen)
|
||||||
|
|
||||||
|
def add_dpo_completion(
|
||||||
|
self,
|
||||||
|
step: int,
|
||||||
|
prompt: str,
|
||||||
|
chosen: str,
|
||||||
|
rejected: str,
|
||||||
|
reward_diff: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add a DPO completion to the buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: Training step number
|
||||||
|
prompt: Input prompt
|
||||||
|
chosen: Chosen (preferred) completion
|
||||||
|
rejected: Rejected (non-preferred) completion
|
||||||
|
reward_diff: Reward difference (chosen - rejected), if available
|
||||||
|
"""
|
||||||
|
entry = {
|
||||||
|
"step": step,
|
||||||
|
"prompt": prompt,
|
||||||
|
"chosen": chosen,
|
||||||
|
"rejected": rejected,
|
||||||
|
}
|
||||||
|
if reward_diff is not None:
|
||||||
|
entry["reward_diff"] = reward_diff
|
||||||
|
|
||||||
|
self.data.append(entry)
|
||||||
|
|
||||||
|
def add_kto_completion(
|
||||||
|
self,
|
||||||
|
step: int,
|
||||||
|
prompt: str,
|
||||||
|
completion: str,
|
||||||
|
label: bool,
|
||||||
|
reward: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add a KTO completion to the buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: Training step number
|
||||||
|
prompt: Input prompt
|
||||||
|
completion: Model-generated completion
|
||||||
|
label: True if desirable, False if undesirable
|
||||||
|
reward: Reward score, if available
|
||||||
|
"""
|
||||||
|
entry = {
|
||||||
|
"step": step,
|
||||||
|
"prompt": prompt,
|
||||||
|
"completion": completion,
|
||||||
|
"label": "desirable" if label else "undesirable",
|
||||||
|
}
|
||||||
|
if reward is not None:
|
||||||
|
entry["reward"] = reward
|
||||||
|
|
||||||
|
self.data.append(entry)
|
||||||
|
|
||||||
|
def add_orpo_completion(
|
||||||
|
self,
|
||||||
|
step: int,
|
||||||
|
prompt: str,
|
||||||
|
chosen: str,
|
||||||
|
rejected: str,
|
||||||
|
log_odds_ratio: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add an ORPO completion to the buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: Training step number
|
||||||
|
prompt: Input prompt
|
||||||
|
chosen: Chosen (preferred) completion
|
||||||
|
rejected: Rejected (non-preferred) completion
|
||||||
|
log_odds_ratio: Log odds ratio between chosen and rejected
|
||||||
|
"""
|
||||||
|
entry = {
|
||||||
|
"step": step,
|
||||||
|
"prompt": prompt,
|
||||||
|
"chosen": chosen,
|
||||||
|
"rejected": rejected,
|
||||||
|
}
|
||||||
|
if log_odds_ratio is not None:
|
||||||
|
entry["log_odds_ratio"] = log_odds_ratio
|
||||||
|
|
||||||
|
self.data.append(entry)
|
||||||
|
|
||||||
|
def add_grpo_completion(
|
||||||
|
self,
|
||||||
|
step: int,
|
||||||
|
prompt: str,
|
||||||
|
completion: str,
|
||||||
|
reward: float | None = None,
|
||||||
|
advantage: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add a GRPO completion to the buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: Training step number
|
||||||
|
prompt: Input prompt
|
||||||
|
completion: Model-generated completion
|
||||||
|
reward: Reward score from reward model
|
||||||
|
advantage: Advantage estimate (reward - baseline)
|
||||||
|
"""
|
||||||
|
entry = {
|
||||||
|
"step": step,
|
||||||
|
"prompt": prompt,
|
||||||
|
"completion": completion,
|
||||||
|
}
|
||||||
|
if reward is not None:
|
||||||
|
entry["reward"] = reward
|
||||||
|
if advantage is not None:
|
||||||
|
entry["advantage"] = advantage
|
||||||
|
|
||||||
|
self.data.append(entry)
|
||||||
|
|
||||||
|
def log_to_swanlab(self, table_name: str = "completions") -> bool:
|
||||||
|
"""Log buffered completions to SwanLab as a table.
|
||||||
|
|
||||||
|
Creates a SwanLab echarts Table with all buffered completions.
|
||||||
|
Only logs if SwanLab is initialized and data is available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: Name of the table in SwanLab dashboard.
|
||||||
|
Default: "completions"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if logging succeeded, False otherwise
|
||||||
|
"""
|
||||||
|
if not self.data:
|
||||||
|
LOG.debug("No completions to log to SwanLab")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
if swanlab.get_run() is None:
|
||||||
|
LOG.debug("SwanLab not initialized, skipping completion logging")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Convert deque to list of dicts
|
||||||
|
completions = list(self.data)
|
||||||
|
|
||||||
|
# Extract headers from first entry (all entries should have same structure)
|
||||||
|
headers = list(completions[0].keys())
|
||||||
|
|
||||||
|
# Build rows: each completion becomes one row
|
||||||
|
rows = []
|
||||||
|
for completion in completions:
|
||||||
|
row = [completion.get(header, "") for header in headers]
|
||||||
|
rows.append(row)
|
||||||
|
|
||||||
|
# Log to SwanLab as echarts Table
|
||||||
|
swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)})
|
||||||
|
|
||||||
|
LOG.info(f"Logged {len(rows)} completions to SwanLab table '{table_name}'")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning(
|
||||||
|
"SwanLab not installed, cannot log completions. "
|
||||||
|
"Install with: pip install swanlab"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
LOG.exception("Failed to log completions to SwanLab: %s", err)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear all buffered completions."""
|
||||||
|
self.data.clear()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return number of buffered completions."""
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""String representation showing buffer status."""
|
||||||
|
return (
|
||||||
|
f"CompletionLogger(maxlen={self.maxlen}, "
|
||||||
|
f"buffered={len(self.data)}/{self.maxlen})"
|
||||||
|
)
|
||||||
554
src/axolotl/integrations/swanlab/plugins.py
Normal file
554
src/axolotl/integrations/swanlab/plugins.py
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
"""SwanLab Plugin for Axolotl"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SwanLabPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
SwanLab integration plugin for Axolotl.
|
||||||
|
|
||||||
|
Provides experiment tracking, visualization, and logging capabilities
|
||||||
|
using SwanLab (https://swanlab.cn).
|
||||||
|
|
||||||
|
Usage in config.yaml:
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||||
|
|
||||||
|
use_swanlab: true
|
||||||
|
swanlab_project: my-project
|
||||||
|
swanlab_experiment_name: my-experiment
|
||||||
|
swanlab_mode: cloud # or 'local', 'offline', 'disabled'
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.swanlab_initialized = False
|
||||||
|
LOG.info("SwanLab plugin initialized")
|
||||||
|
|
||||||
|
def get_input_args(self) -> str:
|
||||||
|
"""Returns the configuration model for SwanLab integration."""
|
||||||
|
return "axolotl.integrations.swanlab.SwanLabConfig"
|
||||||
|
|
||||||
|
def register(self, cfg: dict):
|
||||||
|
"""Register SwanLab plugin with configuration and conflict detection."""
|
||||||
|
LOG.info("Registering SwanLab plugin")
|
||||||
|
|
||||||
|
# === Conflict Detection: Required Fields ===
|
||||||
|
|
||||||
|
# Check if SwanLab is enabled
|
||||||
|
if cfg.get("use_swanlab"):
|
||||||
|
# 1. Validate project name is set
|
||||||
|
if not cfg.get("swanlab_project"):
|
||||||
|
raise ValueError(
|
||||||
|
"SwanLab enabled but 'swanlab_project' is not set.\n\n"
|
||||||
|
"Solutions:\n"
|
||||||
|
" 1. Add 'swanlab_project: your-project-name' to your config\n"
|
||||||
|
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
|
||||||
|
"See: src/axolotl/integrations/swanlab/README.md for examples"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Validate swanlab_mode value
|
||||||
|
valid_modes = ["cloud", "local", "offline", "disabled"]
|
||||||
|
mode = cfg.get("swanlab_mode")
|
||||||
|
if mode and mode not in valid_modes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid swanlab_mode: '{mode}'.\n\n"
|
||||||
|
f"Valid options: {', '.join(valid_modes)}\n\n"
|
||||||
|
f"Example:\n"
|
||||||
|
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
|
||||||
|
f" swanlab_mode: local # Local only, no cloud sync\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Check API key for cloud mode
|
||||||
|
import os
|
||||||
|
|
||||||
|
mode = cfg.get("swanlab_mode", "cloud") # Default is cloud
|
||||||
|
if mode == "cloud":
|
||||||
|
api_key = cfg.get("swanlab_api_key") or os.environ.get(
|
||||||
|
"SWANLAB_API_KEY"
|
||||||
|
)
|
||||||
|
if not api_key:
|
||||||
|
LOG.warning(
|
||||||
|
"SwanLab cloud mode enabled but no API key found.\n"
|
||||||
|
"SwanLab may fail to initialize during training.\n\n"
|
||||||
|
"Solutions:\n"
|
||||||
|
" 1. Set SWANLAB_API_KEY environment variable:\n"
|
||||||
|
" export SWANLAB_API_KEY=your-api-key\n"
|
||||||
|
" 2. Add 'swanlab_api_key: your-api-key' to config (less secure)\n"
|
||||||
|
" 3. Run 'swanlab login' before training\n"
|
||||||
|
" 4. Use 'swanlab_mode: local' for offline tracking\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# === Conflict Detection: Multi-Logger Performance Warning ===
|
||||||
|
|
||||||
|
# Detect all active logging tools
|
||||||
|
active_loggers = []
|
||||||
|
if cfg.get("use_wandb"):
|
||||||
|
active_loggers.append("WandB")
|
||||||
|
if cfg.get("use_mlflow"):
|
||||||
|
active_loggers.append("MLflow")
|
||||||
|
if cfg.get("comet_api_key") or cfg.get("comet_project_name"):
|
||||||
|
active_loggers.append("Comet")
|
||||||
|
if cfg.get("use_swanlab"):
|
||||||
|
active_loggers.append("SwanLab")
|
||||||
|
|
||||||
|
if len(active_loggers) > 1:
|
||||||
|
LOG.warning(
|
||||||
|
f"\n{'=' * 70}\n"
|
||||||
|
f"Multiple logging tools enabled: {', '.join(active_loggers)}\n"
|
||||||
|
f"{'=' * 70}\n"
|
||||||
|
f"This may cause:\n"
|
||||||
|
f" - Performance overhead (~1-2% per logger, cumulative)\n"
|
||||||
|
f" - Increased memory usage\n"
|
||||||
|
f" - Longer training time per step\n"
|
||||||
|
f" - Potential config/callback conflicts\n\n"
|
||||||
|
f"Recommendations:\n"
|
||||||
|
f" - Choose ONE primary logging tool for production training\n"
|
||||||
|
f" - Use multiple loggers only for:\n"
|
||||||
|
f" * Migration period (transitioning between tools)\n"
|
||||||
|
f" * Short comparison runs\n"
|
||||||
|
f" * Debugging specific tool issues\n"
|
||||||
|
f" - Monitor system resources (CPU, memory) during training\n"
|
||||||
|
f"{'=' * 70}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(active_loggers) >= 3:
|
||||||
|
LOG.error(
|
||||||
|
f"\n{'!' * 70}\n"
|
||||||
|
f"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\n"
|
||||||
|
f"{'!' * 70}\n"
|
||||||
|
f"This is likely unintentional and WILL significantly impact performance.\n"
|
||||||
|
f"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\n\n"
|
||||||
|
f"STRONGLY RECOMMEND:\n"
|
||||||
|
f" - Disable all but ONE logging tool\n"
|
||||||
|
f" - Use config inheritance to manage multiple configs\n"
|
||||||
|
f"{'!' * 70}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# === Auto-Enable Logic ===
|
||||||
|
|
||||||
|
# Enable SwanLab if project is specified
|
||||||
|
if cfg.get("swanlab_project") and not cfg.get("use_swanlab"):
|
||||||
|
cfg["use_swanlab"] = True
|
||||||
|
LOG.info("Automatically enabled use_swanlab because swanlab_project is set")
|
||||||
|
|
||||||
|
def pre_model_load(self, cfg: DictDefault):
|
||||||
|
"""Initialize SwanLab before model loading with runtime checks."""
|
||||||
|
if not cfg.use_swanlab:
|
||||||
|
return
|
||||||
|
|
||||||
|
# === Runtime Check: Import Availability ===
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
except ImportError as err:
|
||||||
|
raise ImportError(
|
||||||
|
"SwanLab is not installed.\n\n"
|
||||||
|
"Install with:\n"
|
||||||
|
" pip install swanlab\n\n"
|
||||||
|
"Or add to requirements:\n"
|
||||||
|
" swanlab>=0.3.0\n\n"
|
||||||
|
f"Original error: {err}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
# Log SwanLab version
|
||||||
|
try:
|
||||||
|
swanlab_version = swanlab.__version__
|
||||||
|
LOG.info(f"SwanLab version: {swanlab_version}")
|
||||||
|
except AttributeError:
|
||||||
|
LOG.warning("Could not determine SwanLab version")
|
||||||
|
|
||||||
|
# === Runtime Check: Distributed Training Setup ===
|
||||||
|
from axolotl.utils.distributed import get_world_size, is_main_process
|
||||||
|
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size > 1:
|
||||||
|
mode = getattr(cfg, "swanlab_mode", "cloud")
|
||||||
|
LOG.info(
|
||||||
|
f"\n{'=' * 70}\n"
|
||||||
|
f"Distributed training detected (world_size={world_size})\n"
|
||||||
|
f"SwanLab mode: {mode}\n"
|
||||||
|
f"{'=' * 70}\n"
|
||||||
|
f"Behavior:\n"
|
||||||
|
f" - Only rank 0 will initialize SwanLab\n"
|
||||||
|
f" - Other ranks will skip SwanLab to avoid conflicts\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if mode == "cloud":
|
||||||
|
LOG.info(
|
||||||
|
f" - Only rank 0 will upload to SwanLab cloud\n"
|
||||||
|
f" - Other ranks run without SwanLab overhead\n"
|
||||||
|
f"{'=' * 70}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only initialize SwanLab on the main process (rank 0)
|
||||||
|
# to avoid creating multiple runs in distributed training
|
||||||
|
if not is_main_process():
|
||||||
|
LOG.debug("Skipping SwanLab initialization on non-main process")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize SwanLab run (passing all params directly to init)
|
||||||
|
try:
|
||||||
|
init_kwargs = self._get_swanlab_init_kwargs(cfg)
|
||||||
|
swanlab.init(**init_kwargs)
|
||||||
|
self.swanlab_initialized = True
|
||||||
|
LOG.info(f"SwanLab initialized with project: {cfg.swanlab_project}")
|
||||||
|
|
||||||
|
# Register Lark notification callback (if configured)
|
||||||
|
self._register_lark_callback(cfg)
|
||||||
|
|
||||||
|
# Log configuration (with error handling)
|
||||||
|
try:
|
||||||
|
config_dict = self._prepare_config_for_logging(cfg)
|
||||||
|
swanlab.config.update(config_dict)
|
||||||
|
LOG.debug("Successfully logged config to SwanLab")
|
||||||
|
except Exception as config_err: # pylint: disable=broad-except
|
||||||
|
LOG.warning(
|
||||||
|
f"Failed to log config to SwanLab: {config_err}. Continuing anyway."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
LOG.exception("Failed to initialize SwanLab: %s", err)
|
||||||
|
self.swanlab_initialized = False
|
||||||
|
|
||||||
|
def add_callbacks_pre_trainer(self, cfg: DictDefault, model):
|
||||||
|
"""Add SwanLab callbacks before trainer creation."""
|
||||||
|
callbacks: list[TrainerCallback] = []
|
||||||
|
|
||||||
|
if not cfg.use_swanlab:
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
if not self.swanlab_initialized:
|
||||||
|
LOG.warning("SwanLab not initialized, skipping callback registration")
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
try:
|
||||||
|
from axolotl.utils.callbacks.swanlab import (
|
||||||
|
CustomSwanLabCallback,
|
||||||
|
SaveAxolotlConfigtoSwanLabCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add our custom lightweight SwanLabCallback
|
||||||
|
# (avoids omegaconf/antlr4 version conflicts)
|
||||||
|
swanlab_callback = CustomSwanLabCallback()
|
||||||
|
callbacks.append(swanlab_callback)
|
||||||
|
LOG.info("Added CustomSwanLabCallback for metrics logging")
|
||||||
|
|
||||||
|
# Add Axolotl config logging callback
|
||||||
|
if cfg.axolotl_config_path:
|
||||||
|
config_callback = SaveAxolotlConfigtoSwanLabCallback(
|
||||||
|
cfg.axolotl_config_path
|
||||||
|
)
|
||||||
|
callbacks.append(config_callback)
|
||||||
|
LOG.info("Added SaveAxolotlConfigtoSwanLabCallback")
|
||||||
|
|
||||||
|
except ImportError as err:
|
||||||
|
LOG.exception("Failed to import SwanLab callbacks: %s", err)
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def post_trainer_create(self, cfg: DictDefault, trainer):
|
||||||
|
"""Post-trainer creation hook."""
|
||||||
|
if cfg.use_swanlab and self.swanlab_initialized:
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
# Log additional trainer information (with safe conversion)
|
||||||
|
trainer_config = {
|
||||||
|
"total_steps": int(trainer.state.max_steps)
|
||||||
|
if trainer.state.max_steps
|
||||||
|
else None,
|
||||||
|
"num_train_epochs": float(trainer.args.num_train_epochs)
|
||||||
|
if trainer.args.num_train_epochs
|
||||||
|
else None,
|
||||||
|
"train_batch_size": int(trainer.args.train_batch_size)
|
||||||
|
if hasattr(trainer.args, "train_batch_size")
|
||||||
|
else None,
|
||||||
|
"gradient_accumulation_steps": int(
|
||||||
|
trainer.args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
if trainer.args.gradient_accumulation_steps
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
# Remove None values
|
||||||
|
trainer_config = {
|
||||||
|
k: v for k, v in trainer_config.items() if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
if trainer_config:
|
||||||
|
swanlab.config.update(trainer_config)
|
||||||
|
LOG.info("Logged trainer configuration to SwanLab")
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
LOG.debug(f"Failed to log trainer config to SwanLab: {err}")
|
||||||
|
|
||||||
|
# Register RLHF completion logging callback if enabled
|
||||||
|
self._register_completion_callback(cfg, trainer)
|
||||||
|
|
||||||
|
def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:
|
||||||
|
"""Prepare kwargs for swanlab.init().
|
||||||
|
|
||||||
|
Passes all configuration parameters directly to swanlab.init()
|
||||||
|
instead of using environment variables as an intermediate layer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Keyword arguments for swanlab.init()
|
||||||
|
"""
|
||||||
|
init_kwargs = {}
|
||||||
|
|
||||||
|
# Project name (required)
|
||||||
|
if cfg.swanlab_project:
|
||||||
|
init_kwargs["project"] = cfg.swanlab_project
|
||||||
|
|
||||||
|
# Experiment name
|
||||||
|
if cfg.swanlab_experiment_name:
|
||||||
|
init_kwargs["experiment_name"] = cfg.swanlab_experiment_name
|
||||||
|
|
||||||
|
# Description
|
||||||
|
if cfg.swanlab_description:
|
||||||
|
init_kwargs["description"] = cfg.swanlab_description
|
||||||
|
|
||||||
|
# Workspace (organization)
|
||||||
|
if cfg.swanlab_workspace:
|
||||||
|
init_kwargs["workspace"] = cfg.swanlab_workspace
|
||||||
|
|
||||||
|
# Mode: cloud, local, offline, disabled
|
||||||
|
if cfg.swanlab_mode:
|
||||||
|
init_kwargs["mode"] = cfg.swanlab_mode
|
||||||
|
|
||||||
|
# API key (pass directly instead of via env var)
|
||||||
|
if cfg.swanlab_api_key:
|
||||||
|
init_kwargs["api_key"] = cfg.swanlab_api_key
|
||||||
|
|
||||||
|
# Private deployment hosts (pass directly instead of via env var)
|
||||||
|
if cfg.swanlab_web_host:
|
||||||
|
init_kwargs["web_host"] = cfg.swanlab_web_host
|
||||||
|
|
||||||
|
if cfg.swanlab_api_host:
|
||||||
|
init_kwargs["api_host"] = cfg.swanlab_api_host
|
||||||
|
|
||||||
|
# Log model checkpoints (coming soon in SwanLab)
|
||||||
|
if cfg.swanlab_log_model:
|
||||||
|
init_kwargs["log_model"] = cfg.swanlab_log_model
|
||||||
|
|
||||||
|
# Custom branding - adds Axolotl identifier to SwanLab UI
|
||||||
|
# This helps identify runs from Axolotl vs other frameworks
|
||||||
|
init_kwargs["config"] = {"UPPERFRAME": "🦎 Axolotl"}
|
||||||
|
|
||||||
|
return init_kwargs
|
||||||
|
|
||||||
|
def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:
|
||||||
|
"""Prepare configuration dict for logging to SwanLab."""
|
||||||
|
|
||||||
|
def safe_convert(value):
|
||||||
|
"""Convert value to JSON-serializable type."""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, (int, float, bool)):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
# Convert everything else to string
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract important training parameters with safe conversion
|
||||||
|
config_dict = {
|
||||||
|
"base_model": safe_convert(getattr(cfg, "base_model", "")),
|
||||||
|
"model_type": safe_convert(getattr(cfg, "model_type", "")),
|
||||||
|
"sequence_len": safe_convert(getattr(cfg, "sequence_len", None)),
|
||||||
|
"micro_batch_size": safe_convert(
|
||||||
|
getattr(cfg, "micro_batch_size", None)
|
||||||
|
),
|
||||||
|
"gradient_accumulation_steps": safe_convert(
|
||||||
|
getattr(cfg, "gradient_accumulation_steps", None)
|
||||||
|
),
|
||||||
|
"num_epochs": safe_convert(getattr(cfg, "num_epochs", None)),
|
||||||
|
"max_steps": safe_convert(getattr(cfg, "max_steps", None)),
|
||||||
|
"learning_rate": safe_convert(getattr(cfg, "learning_rate", None)),
|
||||||
|
"lr_scheduler": safe_convert(getattr(cfg, "lr_scheduler", "")),
|
||||||
|
"optimizer": safe_convert(getattr(cfg, "optimizer", "")),
|
||||||
|
"warmup_ratio": safe_convert(getattr(cfg, "warmup_ratio", None)),
|
||||||
|
"weight_decay": safe_convert(getattr(cfg, "weight_decay", None)),
|
||||||
|
"seed": safe_convert(getattr(cfg, "seed", None)),
|
||||||
|
"bf16": safe_convert(getattr(cfg, "bf16", None)),
|
||||||
|
"tf32": safe_convert(getattr(cfg, "tf32", None)),
|
||||||
|
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
|
||||||
|
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add FSDP/parallel config - only boolean flags
|
||||||
|
if hasattr(cfg, "fsdp_config") and cfg.fsdp_config:
|
||||||
|
config_dict["fsdp_enabled"] = True
|
||||||
|
config_dict["fsdp_version"] = safe_convert(
|
||||||
|
getattr(cfg, "fsdp_version", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(cfg, "deepspeed") and cfg.deepspeed:
|
||||||
|
config_dict["deepspeed_enabled"] = True
|
||||||
|
|
||||||
|
# Add context parallel info
|
||||||
|
if hasattr(cfg, "context_parallel_size"):
|
||||||
|
config_dict["context_parallel_size"] = safe_convert(
|
||||||
|
getattr(cfg, "context_parallel_size", None)
|
||||||
|
)
|
||||||
|
if hasattr(cfg, "tensor_parallel_size"):
|
||||||
|
config_dict["tensor_parallel_size"] = safe_convert(
|
||||||
|
getattr(cfg, "tensor_parallel_size", None)
|
||||||
|
)
|
||||||
|
if hasattr(cfg, "dp_shard_size"):
|
||||||
|
config_dict["dp_shard_size"] = safe_convert(
|
||||||
|
getattr(cfg, "dp_shard_size", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove None values and empty strings
|
||||||
|
config_dict = {
|
||||||
|
k: v
|
||||||
|
for k, v in config_dict.items()
|
||||||
|
if v is not None and v != "" and v != "None"
|
||||||
|
}
|
||||||
|
|
||||||
|
return config_dict
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
LOG.warning(f"Failed to prepare config for logging: {err}")
|
||||||
|
# Return minimal config
|
||||||
|
try:
|
||||||
|
lr = getattr(cfg, "learning_rate", None)
|
||||||
|
lr_value = float(lr) if lr is not None else None
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
lr_value = None
|
||||||
|
return {
|
||||||
|
"base_model": str(getattr(cfg, "base_model", "unknown")),
|
||||||
|
"learning_rate": lr_value,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _register_lark_callback(self, cfg: DictDefault):
|
||||||
|
"""Register Lark (Feishu) notification callback if configured.
|
||||||
|
|
||||||
|
Lark notifications enable sending training updates to team chat channels,
|
||||||
|
useful for production monitoring and team collaboration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Configuration object with Lark webhook settings
|
||||||
|
"""
|
||||||
|
# Check if Lark webhook URL is configured
|
||||||
|
lark_webhook_url = getattr(cfg, "swanlab_lark_webhook_url", None)
|
||||||
|
if not lark_webhook_url:
|
||||||
|
return # Lark not configured, skip
|
||||||
|
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
from swanlab.plugin.notification import LarkCallback
|
||||||
|
|
||||||
|
# Get optional secret for HMAC signature authentication
|
||||||
|
lark_secret = getattr(cfg, "swanlab_lark_secret", None)
|
||||||
|
|
||||||
|
# Create Lark callback with webhook URL and optional secret
|
||||||
|
lark_callback = LarkCallback(
|
||||||
|
webhook_url=lark_webhook_url,
|
||||||
|
secret=lark_secret,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register callback with SwanLab
|
||||||
|
swanlab.register_callbacks([lark_callback])
|
||||||
|
|
||||||
|
if lark_secret:
|
||||||
|
LOG.info(
|
||||||
|
"Registered Lark notification callback with HMAC authentication"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.info("Registered Lark notification callback (no HMAC secret)")
|
||||||
|
LOG.warning(
|
||||||
|
"Lark webhook has no secret configured. "
|
||||||
|
"For production use, set 'swanlab_lark_secret' to enable HMAC signature verification."
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError as err:
|
||||||
|
LOG.warning(
|
||||||
|
f"Failed to import SwanLab Lark plugin: {err}\n\n"
|
||||||
|
"Lark notifications require SwanLab >= 0.3.0 with plugin support.\n"
|
||||||
|
"Install with: pip install 'swanlab>=0.3.0'\n\n"
|
||||||
|
"Continuing without Lark notifications..."
|
||||||
|
)
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
LOG.exception(
|
||||||
|
"Failed to register Lark callback: %s\n\n"
|
||||||
|
"Check your Lark webhook URL and secret configuration.\n"
|
||||||
|
"Continuing without Lark notifications...",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _register_completion_callback(self, cfg: DictDefault, trainer):
|
||||||
|
"""Register RLHF completion logging callback if enabled and applicable.
|
||||||
|
|
||||||
|
This callback logs model completions (prompts, chosen/rejected responses,
|
||||||
|
rewards) to SwanLab during RLHF training for qualitative analysis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Configuration object with completion logging settings
|
||||||
|
trainer: The trainer instance to add callback to
|
||||||
|
"""
|
||||||
|
# Check if completion logging is enabled
|
||||||
|
log_completions = getattr(cfg, "swanlab_log_completions", True)
|
||||||
|
if not log_completions:
|
||||||
|
LOG.debug("SwanLab completion logging disabled by config")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if trainer is an RLHF trainer
|
||||||
|
trainer_name = trainer.__class__.__name__
|
||||||
|
rlhf_trainers = ["DPO", "KTO", "ORPO", "GRPO", "CPO"]
|
||||||
|
is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)
|
||||||
|
|
||||||
|
if not is_rlhf_trainer:
|
||||||
|
LOG.debug(
|
||||||
|
f"Trainer {trainer_name} is not an RLHF trainer, "
|
||||||
|
"skipping completion logging callback"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from axolotl.integrations.swanlab.callbacks import (
|
||||||
|
SwanLabRLHFCompletionCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get configuration parameters
|
||||||
|
log_interval = getattr(cfg, "swanlab_completion_log_interval", 100)
|
||||||
|
max_buffer = getattr(cfg, "swanlab_completion_max_buffer", 128)
|
||||||
|
|
||||||
|
# Create and register callback
|
||||||
|
completion_callback = SwanLabRLHFCompletionCallback(
|
||||||
|
log_interval=log_interval,
|
||||||
|
max_completions=max_buffer,
|
||||||
|
table_name="rlhf_completions",
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.add_callback(completion_callback)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"Registered SwanLab RLHF completion logging callback for {trainer_name} "
|
||||||
|
f"(log_interval={log_interval}, max_buffer={max_buffer})"
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError as err:
|
||||||
|
LOG.warning(
|
||||||
|
f"Failed to import SwanLab completion callback: {err}\n\n"
|
||||||
|
"This is a bug - the callback should be available.\n"
|
||||||
|
"Please report this issue.\n\n"
|
||||||
|
"Continuing without completion logging..."
|
||||||
|
)
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
LOG.exception(
|
||||||
|
"Failed to register SwanLab completion callback: %s\n\n"
|
||||||
|
"Continuing without completion logging...",
|
||||||
|
err,
|
||||||
|
)
|
||||||
203
src/axolotl/integrations/swanlab/profiling.py
Normal file
203
src/axolotl/integrations/swanlab/profiling.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""SwanLab profiling utilities for Axolotl trainers.
|
||||||
|
|
||||||
|
This module provides decorators and context managers for profiling
|
||||||
|
trainer methods and logging execution times to SwanLab.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def swanlab_profiling_context(trainer: Any, func_name: str):
|
||||||
|
"""Context manager for profiling trainer methods.
|
||||||
|
|
||||||
|
Measures execution time and logs to SwanLab if enabled.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> with swanlab_profiling_context(self, "training_step"):
|
||||||
|
... result = do_expensive_computation()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainer: Trainer instance (must have cfg attribute with use_swanlab flag)
|
||||||
|
func_name: Name of the function being profiled
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
duration = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
# Check if SwanLab is enabled and initialized
|
||||||
|
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
|
||||||
|
if use_swanlab:
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
if swanlab.get_run() is not None:
|
||||||
|
# Log profiling metric
|
||||||
|
trainer_class = trainer.__class__.__name__
|
||||||
|
metric_name = f"profiling/Time taken: {trainer_class}.{func_name}"
|
||||||
|
|
||||||
|
swanlab.log({metric_name: duration})
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# SwanLab not installed, silently skip
|
||||||
|
pass
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
# Log error but don't fail training
|
||||||
|
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
|
||||||
|
|
||||||
|
|
||||||
|
def swanlab_profile(func: Callable) -> Callable:
|
||||||
|
"""Decorator to profile and log function execution time to SwanLab.
|
||||||
|
|
||||||
|
Automatically measures execution time of trainer methods and logs
|
||||||
|
to SwanLab as profiling metrics.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> class MyTrainer:
|
||||||
|
... @swanlab_profile
|
||||||
|
... def training_step(self, model, inputs):
|
||||||
|
... return super().training_step(model, inputs)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to profile (must be a method of a trainer instance)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped function with profiling
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
with swanlab_profiling_context(self, func.__name__):
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class ProfilingConfig:
|
||||||
|
"""Configuration for SwanLab profiling.
|
||||||
|
|
||||||
|
This class provides a centralized way to control profiling behavior.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled: Whether profiling is enabled globally
|
||||||
|
min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops)
|
||||||
|
log_interval: Log every N function calls (to reduce overhead)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enabled: bool = True,
|
||||||
|
min_duration_ms: float = 0.1,
|
||||||
|
log_interval: int = 1,
|
||||||
|
):
|
||||||
|
"""Initialize profiling configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled: Enable profiling. Default: True
|
||||||
|
min_duration_ms: Minimum duration to log (ms). Default: 0.1
|
||||||
|
log_interval: Log every N calls. Default: 1 (log all)
|
||||||
|
"""
|
||||||
|
self.enabled = enabled
|
||||||
|
self.min_duration_ms = min_duration_ms
|
||||||
|
self.log_interval = log_interval
|
||||||
|
self._call_counts: dict[str, int] = {}
|
||||||
|
|
||||||
|
def should_log(self, func_name: str, duration_seconds: float) -> bool:
|
||||||
|
"""Check if a profiling measurement should be logged.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func_name: Name of the profiled function
|
||||||
|
duration_seconds: Execution duration in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if should log, False otherwise
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check minimum duration threshold
|
||||||
|
duration_ms = duration_seconds * 1000
|
||||||
|
if duration_ms < self.min_duration_ms:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check log interval
|
||||||
|
self._call_counts.setdefault(func_name, 0)
|
||||||
|
self._call_counts[func_name] += 1
|
||||||
|
|
||||||
|
# Always log on first call OR at intervals
|
||||||
|
count = self._call_counts[func_name]
|
||||||
|
if count == 1 or count % self.log_interval == 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Global profiling config (can be modified by users)
|
||||||
|
DEFAULT_PROFILING_CONFIG = ProfilingConfig()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def swanlab_profiling_context_advanced(
|
||||||
|
trainer: Any,
|
||||||
|
func_name: str,
|
||||||
|
config: ProfilingConfig | None = None,
|
||||||
|
):
|
||||||
|
"""Advanced profiling context with configurable behavior.
|
||||||
|
|
||||||
|
Similar to swanlab_profiling_context but with additional configuration
|
||||||
|
options for filtering and throttling profiling logs.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10)
|
||||||
|
>>> with swanlab_profiling_context_advanced(self, "forward", config):
|
||||||
|
... output = model(inputs)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainer: Trainer instance
|
||||||
|
func_name: Function name
|
||||||
|
config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = DEFAULT_PROFILING_CONFIG
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
duration = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
# Check if should log based on config
|
||||||
|
if config.should_log(func_name, duration):
|
||||||
|
# Check if SwanLab is enabled
|
||||||
|
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
|
||||||
|
if use_swanlab:
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
if swanlab.get_run() is not None:
|
||||||
|
trainer_class = trainer.__class__.__name__
|
||||||
|
metric_name = (
|
||||||
|
f"profiling/Time taken: {trainer_class}.{func_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
swanlab.log({metric_name: duration})
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
|
||||||
@@ -26,6 +26,48 @@ PLUGIN_MANAGER = PluginManager.get_instance()
|
|||||||
class PatchManager:
|
class PatchManager:
|
||||||
"""Manages the application of patches during the model loading process."""
|
"""Manages the application of patches during the model loading process."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_pre_config_load_patches(cfg: DictDefault):
|
||||||
|
"""
|
||||||
|
Apply patches that must be set up before config loading.
|
||||||
|
This is for patches that intercept remote code loading from HuggingFace,
|
||||||
|
which needs to be in place before AutoConfig.from_pretrained() is called.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Configuration dictionary with model and training settings.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
hasattr(cfg, "base_model_config")
|
||||||
|
and cfg.base_model_config
|
||||||
|
and "kimi-linear" in cfg.base_model_config.lower()
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||||
|
patch_kimi_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_kimi_config()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
|
||||||
|
"""
|
||||||
|
Apply patches that must be set up before tokenizer loading.
|
||||||
|
This is for patches that intercept remote code loading from HuggingFace,
|
||||||
|
which needs to be in place before AutoTokenizer.from_pretrained() is called.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Configuration dictionary with model and training settings.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
hasattr(cfg, "tokenizer_config")
|
||||||
|
and cfg.tokenizer_config
|
||||||
|
and "kimi-linear" in cfg.tokenizer_config.lower()
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||||
|
patch_kimi_tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_kimi_tokenizer()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -96,6 +138,7 @@ class PatchManager:
|
|||||||
self._apply_llama_flash_attn_patches(model)
|
self._apply_llama_flash_attn_patches(model)
|
||||||
self._apply_unsloth_patches(model)
|
self._apply_unsloth_patches(model)
|
||||||
self._apply_lora_kernel_patch(model)
|
self._apply_lora_kernel_patch(model)
|
||||||
|
self._apply_scaling_softmax_patch(model)
|
||||||
|
|
||||||
def _apply_flash_attention_patches(self):
|
def _apply_flash_attention_patches(self):
|
||||||
"""Apply patches related to Flash Attention."""
|
"""Apply patches related to Flash Attention."""
|
||||||
@@ -157,12 +200,6 @@ class PatchManager:
|
|||||||
|
|
||||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||||
if self.cfg.sample_packing:
|
|
||||||
from axolotl.core.attention.flex_block_mask import (
|
|
||||||
patch_create_causal_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_create_causal_mask(self.cfg.model_config_type)
|
|
||||||
|
|
||||||
def _apply_model_specific_patches(self):
|
def _apply_model_specific_patches(self):
|
||||||
"""Apply patches specific to model architectures."""
|
"""Apply patches specific to model architectures."""
|
||||||
@@ -190,6 +227,13 @@ class PatchManager:
|
|||||||
|
|
||||||
apply_mistral_tokenizer_image_patch()
|
apply_mistral_tokenizer_image_patch()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "kimi_linear":
|
||||||
|
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||||
|
patch_kimi_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_kimi_model()
|
||||||
|
|
||||||
def _apply_fp8_patches(self):
|
def _apply_fp8_patches(self):
|
||||||
"""Apply patches for FP8 support."""
|
"""Apply patches for FP8 support."""
|
||||||
if self.cfg.fp8:
|
if self.cfg.fp8:
|
||||||
@@ -517,3 +561,16 @@ class PatchManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
patch_apertus_xielu_activation()
|
patch_apertus_xielu_activation()
|
||||||
|
|
||||||
|
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
|
||||||
|
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
|
||||||
|
if self.cfg.scaling_softmax:
|
||||||
|
from axolotl.monkeypatch.scaled_softmax_attn import (
|
||||||
|
patch_scaled_softmax_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_scaled_softmax_attention(
|
||||||
|
scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,
|
||||||
|
bias=self.cfg.scaling_softmax_bias or 0.0,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|||||||
@@ -124,6 +124,11 @@ def modify_tokenizer_files(
|
|||||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||||
"""Load and configure the tokenizer based on the provided config."""
|
"""Load and configure the tokenizer based on the provided config."""
|
||||||
|
|
||||||
|
# Apply patches that need to be in place before tokenizer loading
|
||||||
|
from axolotl.loaders.patch_manager import PatchManager
|
||||||
|
|
||||||
|
PatchManager.apply_pre_tokenizer_load_patches(cfg)
|
||||||
|
|
||||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||||
"""Load mistral-common tokenizer"""
|
"""Load mistral-common tokenizer"""
|
||||||
from axolotl.utils.mistral import HFMistralTokenizer
|
from axolotl.utils.mistral import HFMistralTokenizer
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Type
|
|||||||
|
|
||||||
import addict
|
import addict
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -79,7 +80,11 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
|||||||
and hasattr(model_config, "vision_config")
|
and hasattr(model_config, "vision_config")
|
||||||
and hasattr(model_config.vision_config, "image_size")
|
and hasattr(model_config.vision_config, "image_size")
|
||||||
):
|
):
|
||||||
cfg.image_size = model_config.vision_config.image_size
|
image_size = model_config.vision_config.image_size
|
||||||
|
if isinstance(image_size, list):
|
||||||
|
cfg.image_size = tuple(image_size)
|
||||||
|
else:
|
||||||
|
cfg.image_size = image_size
|
||||||
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
||||||
|
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
@@ -149,6 +154,9 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
|||||||
This function determines the appropriate model config source, loads it, applies any
|
This function determines the appropriate model config source, loads it, applies any
|
||||||
necessary overrides, and validates it for compatibility with the `axolotl` config.
|
necessary overrides, and validates it for compatibility with the `axolotl` config.
|
||||||
|
|
||||||
|
If `cfg.cls_model_config` is set, a custom config class from transformers will be
|
||||||
|
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|
||||||
@@ -170,8 +178,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
|||||||
if cfg.num_labels:
|
if cfg.num_labels:
|
||||||
# num_labels is used to initialize classifier models
|
# num_labels is used to initialize classifier models
|
||||||
config_kwargs["num_labels"] = cfg.num_labels
|
config_kwargs["num_labels"] = cfg.num_labels
|
||||||
|
|
||||||
|
config_cls = AutoConfig
|
||||||
|
if cfg.cls_model_config:
|
||||||
|
config_cls = getattr(transformers, cfg.cls_model_config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_config = AutoConfig.from_pretrained(
|
model_config = config_cls.from_pretrained(
|
||||||
model_config_name,
|
model_config_name,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
**config_kwargs,
|
**config_kwargs,
|
||||||
|
|||||||
@@ -75,3 +75,33 @@ def patch_parallelism_config():
|
|||||||
|
|
||||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_prepare_cp():
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
def patched_prepare_cp(self, *args):
|
||||||
|
if self.parallelism_config.cp_backend == "deepspeed":
|
||||||
|
return args
|
||||||
|
|
||||||
|
from accelerate.big_modeling import _attach_context_parallel_hooks
|
||||||
|
from torch.distributed.tensor.experimental import context_parallel
|
||||||
|
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||||
|
|
||||||
|
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
|
||||||
|
set_rotate_method(cp_comm_strategy)
|
||||||
|
|
||||||
|
self._cp_context = functools.partial(
|
||||||
|
context_parallel, mesh=self.torch_device_mesh["cp"]
|
||||||
|
)
|
||||||
|
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, torch.nn.Module):
|
||||||
|
_attach_context_parallel_hooks(arg)
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
Accelerator._prepare_cp = patched_prepare_cp
|
||||||
|
|||||||
98
src/axolotl/monkeypatch/loss/dft.py
Normal file
98
src/axolotl/monkeypatch/loss/dft.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Dynamic Fine-Tuning (DFT) loss implementation"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def selective_log_softmax(logits, index):
|
||||||
|
"""Memory-efficient log_softmax -> gather"""
|
||||||
|
if logits.dtype in [torch.float32, torch.float64]:
|
||||||
|
selected_logits = torch.gather(
|
||||||
|
logits, dim=-1, index=index.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||||
|
per_token_logps = selected_logits - logsumexp_values
|
||||||
|
else:
|
||||||
|
per_token_logps = []
|
||||||
|
for row_logits, row_labels in zip(logits, index, strict=True):
|
||||||
|
row_logps = F.log_softmax(row_logits, dim=-1)
|
||||||
|
row_per_token_logps = row_logps.gather(
|
||||||
|
dim=-1, index=row_labels.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
per_token_logps.append(row_per_token_logps)
|
||||||
|
per_token_logps = torch.stack(per_token_logps)
|
||||||
|
return per_token_logps
|
||||||
|
|
||||||
|
|
||||||
|
def get_dft_loss(ignore_index: int = -100):
|
||||||
|
"""Creates DFT loss function"""
|
||||||
|
|
||||||
|
def for_causal_lm_dft_loss(
|
||||||
|
logits,
|
||||||
|
labels,
|
||||||
|
vocab_size: int = None,
|
||||||
|
num_items_in_batch: Optional[int] = None,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
shift_labels: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""DFT loss: -exp(logprobs).detach() * logprobs"""
|
||||||
|
if shift_labels is None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
labels = F.pad(labels, (0, 1), value=ignore_index)
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
|
||||||
|
shift_labels = shift_labels.to(logits.device)
|
||||||
|
|
||||||
|
# Create loss mask
|
||||||
|
loss_mask = shift_labels != ignore_index
|
||||||
|
shift_labels_masked = shift_labels.clone()
|
||||||
|
shift_labels_masked[~loss_mask] = 0
|
||||||
|
|
||||||
|
# Compute log probabilities
|
||||||
|
logprobs = selective_log_softmax(logits, shift_labels_masked)
|
||||||
|
|
||||||
|
# DFT loss: -exp(logprobs).detach() * logprobs
|
||||||
|
per_token_loss = -logprobs.exp().detach() * logprobs
|
||||||
|
|
||||||
|
# Sum over valid tokens and normalize
|
||||||
|
if num_items_in_batch is None:
|
||||||
|
num_items_in_batch = loss_mask.sum()
|
||||||
|
|
||||||
|
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
|
||||||
|
return loss
|
||||||
|
|
||||||
|
return for_causal_lm_dft_loss
|
||||||
|
|
||||||
|
|
||||||
|
def dft_loss(outputs, labels, num_items_in_batch=None):
|
||||||
|
"""DFT loss compatible with Trainer.compute_loss_func signature.
|
||||||
|
|
||||||
|
This function is designed to be passed to Trainer's compute_loss_func parameter.
|
||||||
|
"""
|
||||||
|
ignore_index = -100
|
||||||
|
|
||||||
|
# Shift labels for causal LM
|
||||||
|
labels = F.pad(labels, (0, 1), value=ignore_index)
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
shift_labels = shift_labels.to(outputs.logits.device)
|
||||||
|
|
||||||
|
# Create loss mask
|
||||||
|
loss_mask = shift_labels != ignore_index
|
||||||
|
shift_labels_masked = shift_labels.clone()
|
||||||
|
shift_labels_masked[~loss_mask] = 0
|
||||||
|
|
||||||
|
# Compute log probabilities
|
||||||
|
logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
|
||||||
|
|
||||||
|
# DFT loss: -exp(logprobs).detach() * logprobs
|
||||||
|
per_token_loss = -logprobs.exp().detach() * logprobs
|
||||||
|
|
||||||
|
# Sum over valid tokens and normalize
|
||||||
|
if num_items_in_batch is None:
|
||||||
|
num_items_in_batch = loss_mask.sum()
|
||||||
|
|
||||||
|
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
|
||||||
|
return loss
|
||||||
148
src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py
Normal file
148
src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
Kimi-Linear configuration.
|
||||||
|
|
||||||
|
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/configuration_kimi.py
|
||||||
|
Revision: 6e163f3
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class KimiLinearConfig(PretrainedConfig):
|
||||||
|
model_type = "kimi_linear"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_type="kimi_linear",
|
||||||
|
vocab_size=163840,
|
||||||
|
hidden_size=4096,
|
||||||
|
head_dim=None,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
hidden_act="silu",
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
moe_intermediate_size: Optional[int] = None,
|
||||||
|
moe_renormalize: bool = True,
|
||||||
|
moe_router_activation_func: str = "sigmoid",
|
||||||
|
num_experts: Optional[int] = None,
|
||||||
|
num_experts_per_token: Optional[int] = None,
|
||||||
|
num_shared_experts: int = 0,
|
||||||
|
routed_scaling_factor: float = 1.0,
|
||||||
|
first_k_dense_replace: int = 0,
|
||||||
|
moe_layer_freq: int = 1,
|
||||||
|
use_grouped_topk: bool = True,
|
||||||
|
num_expert_group: int = 1,
|
||||||
|
topk_group: int = 1,
|
||||||
|
q_lora_rank: Optional[int] = None,
|
||||||
|
kv_lora_rank: Optional[int] = None,
|
||||||
|
qk_nope_head_dim: Optional[int] = None,
|
||||||
|
qk_rope_head_dim: Optional[int] = None,
|
||||||
|
v_head_dim: Optional[int] = None,
|
||||||
|
mla_use_nope: Optional[bool] = False,
|
||||||
|
num_nextn_predict_layers: int = 0,
|
||||||
|
linear_attn_config: Optional[dict] = None,
|
||||||
|
router_aux_loss_coef: float = 0.01,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.model_type = model_type
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_dim = (
|
||||||
|
head_dim if head_dim is not None else hidden_size // num_attention_heads
|
||||||
|
)
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.mla_use_nope = mla_use_nope
|
||||||
|
# moe config
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.num_experts_per_token = num_experts_per_token
|
||||||
|
self.moe_renormalize = moe_renormalize
|
||||||
|
self.num_shared_experts = num_shared_experts
|
||||||
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.moe_router_activation_func = moe_router_activation_func
|
||||||
|
assert self.moe_router_activation_func in ("softmax", "sigmoid")
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.moe_layer_freq = moe_layer_freq
|
||||||
|
self.use_grouped_topk = use_grouped_topk
|
||||||
|
self.num_expert_group = num_expert_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||||
|
self.router_aux_loss_coef = router_aux_loss_coef
|
||||||
|
|
||||||
|
if linear_attn_config is not None:
|
||||||
|
assert linear_attn_config["kda_layers"] is not None
|
||||||
|
assert linear_attn_config["full_attn_layers"] is not None
|
||||||
|
self.linear_attn_config = linear_attn_config
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_mla(self):
|
||||||
|
return (
|
||||||
|
self.q_lora_rank is not None
|
||||||
|
or self.kv_lora_rank is not None
|
||||||
|
or self.qk_nope_head_dim is not None
|
||||||
|
or self.qk_rope_head_dim is not None
|
||||||
|
or self.v_head_dim is not None
|
||||||
|
or self.mla_use_nope is True
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_moe(self):
|
||||||
|
return self.num_experts is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_linear_attn(self) -> bool:
|
||||||
|
return not (
|
||||||
|
self.linear_attn_config is None
|
||||||
|
or (
|
||||||
|
isinstance(self.linear_attn_config, dict)
|
||||||
|
and self.linear_attn_config["kda_layers"] is not None
|
||||||
|
and len(self.linear_attn_config["kda_layers"]) == 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_kda_layer(self, layer_idx: int):
|
||||||
|
return (
|
||||||
|
self.linear_attn_config is not None
|
||||||
|
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
|
||||||
|
)
|
||||||
1361
src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
Normal file
1361
src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,85 @@
|
|||||||
|
import importlib.resources
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
|
||||||
|
|
||||||
|
|
||||||
|
def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
|
||||||
|
"""
|
||||||
|
Gets the absolute path to a patch file using importlib.resources.files.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return importlib.resources.files(package_dot_path) / filename
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_local_module(module_name: str, filename: str):
|
||||||
|
"""Helper to load a local module if not already loaded."""
|
||||||
|
if module_name in sys.modules:
|
||||||
|
return sys.modules[module_name]
|
||||||
|
|
||||||
|
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
|
||||||
|
if patch_path and patch_path.exists():
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, patch_path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_get_class_in_module():
|
||||||
|
"""
|
||||||
|
Core patch function that hijacks Transformers' dynamic module loading.
|
||||||
|
"""
|
||||||
|
from transformers.dynamic_module_utils import get_class_in_module
|
||||||
|
|
||||||
|
if hasattr(get_class_in_module, "_axolotl_patched"):
|
||||||
|
return
|
||||||
|
|
||||||
|
original_get_class_in_module = get_class_in_module
|
||||||
|
|
||||||
|
# Mapping of module path patterns to (module_name, filename)
|
||||||
|
KIMI_MODULE_MAP = {
|
||||||
|
"configuration_kimi": ("configuration_kimi", "configuration_kimi.py"),
|
||||||
|
"modeling_kimi": ("modeling_kimi", "modeling_kimi.py"),
|
||||||
|
"tokenization_kimi": ("tokenization_kimi", "tokenization_kimi.py"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def patched_get_class_in_module(class_name, module_path, **kwargs):
|
||||||
|
"""Patched version that returns our local modules instead of remote ones."""
|
||||||
|
for pattern, (module_name, filename) in KIMI_MODULE_MAP.items():
|
||||||
|
if pattern in module_path:
|
||||||
|
module = _load_local_module(module_name, filename)
|
||||||
|
if module:
|
||||||
|
return getattr(module, class_name)
|
||||||
|
break # Pattern matched but file not found, fall through
|
||||||
|
|
||||||
|
return original_get_class_in_module(class_name, module_path, **kwargs)
|
||||||
|
|
||||||
|
import transformers.dynamic_module_utils
|
||||||
|
|
||||||
|
transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
|
||||||
|
patched_get_class_in_module._axolotl_patched = True
|
||||||
|
|
||||||
|
|
||||||
|
def patch_kimi():
|
||||||
|
"""
|
||||||
|
Apply all Kimi patches.
|
||||||
|
Must be called BEFORE loading config/tokenizer/model.
|
||||||
|
"""
|
||||||
|
_patch_get_class_in_module()
|
||||||
|
LOG.info("Kimi patches applied successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
# Keep these for backward compatibility if needed
|
||||||
|
patch_kimi_config = patch_kimi
|
||||||
|
patch_kimi_tokenizer = patch_kimi
|
||||||
|
patch_kimi_model = patch_kimi
|
||||||
357
src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py
Normal file
357
src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
"""
|
||||||
|
Adapted Kimi-Linear tokenizer to use proper template defaults and misc fixes.
|
||||||
|
|
||||||
|
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/tokenization_kimi.py
|
||||||
|
Revision: 919416f
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from tiktoken.load import load_tiktoken_bpe
|
||||||
|
from tokenizers import AddedToken
|
||||||
|
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
|
||||||
|
|
||||||
|
|
||||||
|
class TikTokenTokenizer(PreTrainedTokenizer):
|
||||||
|
"""
|
||||||
|
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
|
||||||
|
|
||||||
|
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||||
|
this superclass for more information regarding those methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (`str`):
|
||||||
|
The path to the Tiktoken model file.
|
||||||
|
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
|
||||||
|
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||||
|
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
|
||||||
|
The end of sequence token.
|
||||||
|
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead. The second to last item in special_tokens.
|
||||||
|
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
additional_special_tokens (list of `str`, *optional*):
|
||||||
|
A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
|
||||||
|
skipped when decoding if `skip_special_tokens` is set to `True`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
|
||||||
|
special_tokens: Dict[str, int]
|
||||||
|
|
||||||
|
num_reserved_special_tokens = 256
|
||||||
|
|
||||||
|
pat_str = "|".join(
|
||||||
|
[
|
||||||
|
r"""[\p{Han}]+""",
|
||||||
|
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||||
|
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||||
|
r"""\p{N}{1,3}""",
|
||||||
|
r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
|
||||||
|
r"""\s*[\r\n]+""",
|
||||||
|
r"""\s+(?!\S)""",
|
||||||
|
r"""\s+""",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
bos_token: Union[str, AddedToken] = "[BOS]", # nosec: B107
|
||||||
|
eos_token: Union[str, AddedToken] = "[EOS]", # nosec: B107
|
||||||
|
unk_token: Union[str, AddedToken, None] = None,
|
||||||
|
pad_token: Union[str, AddedToken, None] = None,
|
||||||
|
additional_special_tokens: List[str] = None,
|
||||||
|
added_tokens_decoder: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert os.path.isfile(vocab_file), vocab_file
|
||||||
|
|
||||||
|
if additional_special_tokens is None:
|
||||||
|
additional_special_tokens = [
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|im_user|>",
|
||||||
|
"<|im_assistant|>",
|
||||||
|
"<|start_header_id|>",
|
||||||
|
"<|end_header_id|>",
|
||||||
|
"[EOT]",
|
||||||
|
"<|im_system|>",
|
||||||
|
"<|im_middle|>",
|
||||||
|
]
|
||||||
|
|
||||||
|
special_tokens_mapping = {
|
||||||
|
i: added_tokens_decoder[i].content for i in added_tokens_decoder
|
||||||
|
}
|
||||||
|
|
||||||
|
self.vocab_file = vocab_file
|
||||||
|
mergeable_ranks = load_tiktoken_bpe(vocab_file)
|
||||||
|
num_base_tokens = len(mergeable_ranks)
|
||||||
|
self.special_tokens = {
|
||||||
|
special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
|
||||||
|
for i in range(
|
||||||
|
num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
self.model = tiktoken.Encoding(
|
||||||
|
name=Path(vocab_file).name,
|
||||||
|
pat_str=self.pat_str,
|
||||||
|
mergeable_ranks=mergeable_ranks,
|
||||||
|
special_tokens=self.special_tokens,
|
||||||
|
)
|
||||||
|
logger.info(f"Reloaded tiktoken model from {vocab_file}")
|
||||||
|
|
||||||
|
self.n_words: int = self.model.n_vocab
|
||||||
|
# BOS / EOS token IDs
|
||||||
|
self.bos_id: int = self.special_tokens[str(bos_token)]
|
||||||
|
self.eos_id: int = self.special_tokens[str(eos_token)]
|
||||||
|
logger.info(
|
||||||
|
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pad_id: int = self.special_tokens[str(pad_token)]
|
||||||
|
self.unk_id: int = self.special_tokens[str(unk_token)]
|
||||||
|
|
||||||
|
self.byte_encoder = bytes_to_unicode()
|
||||||
|
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||||
|
|
||||||
|
self.decoder = {}
|
||||||
|
for i in range(self.n_words):
|
||||||
|
# Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
|
||||||
|
decoding = "".join(
|
||||||
|
[
|
||||||
|
self.byte_encoder[ord(char)]
|
||||||
|
for char in self.model.decode_single_token_bytes(i).decode(
|
||||||
|
"latin-1"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.decoder[i] = decoding
|
||||||
|
|
||||||
|
self.encoder = {}
|
||||||
|
for i in range(self.n_words):
|
||||||
|
if i in self.decoder:
|
||||||
|
self.encoder[self.decoder[i]] = i
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
unk_token=unk_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
additional_special_tokens=additional_special_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.all_special_ids_set = set(self.all_special_ids)
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, text: str, allow_special_tokens: bool = True, **kwargs
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Encodes a string into a list of token IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The input string to be encoded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[int]: A list of token IDs.
|
||||||
|
"""
|
||||||
|
# If there are other args, we should call super().encode because there are a lot of code
|
||||||
|
# to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
|
||||||
|
# NOTE: our encode method is not compatible with the super().encode method,
|
||||||
|
# e.g. split_special_tokens' default is True in our encode method.
|
||||||
|
if len(kwargs) > 0:
|
||||||
|
# logger.warning(f"Calling super().encode with {kwargs}")
|
||||||
|
return super().encode(text, **kwargs)
|
||||||
|
|
||||||
|
assert type(text) is str
|
||||||
|
|
||||||
|
# The tiktoken tokenizer can handle <=400k chars without
|
||||||
|
# pyo3_runtime.PanicException.
|
||||||
|
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||||
|
|
||||||
|
# https://github.com/openai/tiktoken/issues/195
|
||||||
|
# Here we iterate over subsequences and split if we exceed the limit
|
||||||
|
# of max consecutive non-whitespace or whitespace characters.
|
||||||
|
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||||
|
|
||||||
|
texts = self.pre_tokenizer_process(text)
|
||||||
|
|
||||||
|
all_substrs = []
|
||||||
|
for text in texts:
|
||||||
|
substrs = (
|
||||||
|
substr
|
||||||
|
for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||||
|
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||||
|
text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
all_substrs.extend(substrs)
|
||||||
|
|
||||||
|
t: List[int] = []
|
||||||
|
for substr in all_substrs:
|
||||||
|
if allow_special_tokens:
|
||||||
|
t.extend(
|
||||||
|
# we should consider special token as a common token
|
||||||
|
self.model.encode(
|
||||||
|
substr,
|
||||||
|
allowed_special="all",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
t.extend(
|
||||||
|
# we should consider special token as a common token
|
||||||
|
self.model.encode(
|
||||||
|
substr,
|
||||||
|
disallowed_special=(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return t
|
||||||
|
|
||||||
|
def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Decodes a list of token IDs into a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids (List[int]): The list of token IDs to be decoded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The decoded string.
|
||||||
|
"""
|
||||||
|
# If there are other args, we should call super().decode because there are a lot of code
|
||||||
|
# to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
|
||||||
|
if len(kwargs) > 0:
|
||||||
|
return super().decode(token_ids, **kwargs)
|
||||||
|
|
||||||
|
if type(token_ids) is int:
|
||||||
|
token_ids = [token_ids]
|
||||||
|
|
||||||
|
return self.model.decode(cast(List[int], token_ids))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_whitespaces_or_nonwhitespaces(
|
||||||
|
s: str, max_consecutive_slice_len: int
|
||||||
|
) -> Iterator[str]:
|
||||||
|
"""
|
||||||
|
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||||
|
consecutive whitespaces or consecutive non-whitespaces.
|
||||||
|
"""
|
||||||
|
current_slice_len = 0
|
||||||
|
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||||
|
slice_start = 0
|
||||||
|
|
||||||
|
for i in range(len(s)):
|
||||||
|
is_now_space = s[i].isspace()
|
||||||
|
|
||||||
|
if current_slice_is_space ^ is_now_space:
|
||||||
|
current_slice_len = 1
|
||||||
|
current_slice_is_space = is_now_space
|
||||||
|
else:
|
||||||
|
current_slice_len += 1
|
||||||
|
if current_slice_len > max_consecutive_slice_len:
|
||||||
|
yield s[slice_start:i]
|
||||||
|
slice_start = i
|
||||||
|
current_slice_len = 1
|
||||||
|
yield s[slice_start:]
|
||||||
|
|
||||||
|
def pre_tokenizer_process(self, text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
pre-tokenizes the input text into a list of tokens.
|
||||||
|
This method is used to split the input text into smaller chunks for internal processing.
|
||||||
|
"""
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
""" ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
return self.n_words
|
||||||
|
|
||||||
|
def get_vocab(self) -> Dict[str, int]:
|
||||||
|
return self.encoder
|
||||||
|
|
||||||
|
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
||||||
|
return [self.decoder[t] for t in self.encode(text)]
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token: str) -> int:
|
||||||
|
return self.encoder.get(token, self.unk_id)
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index: int) -> str:
|
||||||
|
return self.decoder.get(index)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def clean_up_tokenization(out_string: str) -> str:
|
||||||
|
return out_string
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||||
|
text = "".join(tokens)
|
||||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||||
|
"utf-8", "replace"
|
||||||
|
)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def save_vocabulary(
|
||||||
|
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||||
|
) -> Tuple[str]:
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
raise ValueError(
|
||||||
|
f"vocabulary path ({save_directory}) should be a directory"
|
||||||
|
)
|
||||||
|
out_vocab_file = os.path.join(
|
||||||
|
save_directory,
|
||||||
|
(filename_prefix + "-" if filename_prefix else "")
|
||||||
|
+ VOCAB_FILES_NAMES["vocab_file"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.abspath(self.vocab_file) != os.path.abspath(
|
||||||
|
out_vocab_file
|
||||||
|
) and os.path.isfile(self.vocab_file):
|
||||||
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
|
return (out_vocab_file,)
|
||||||
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
conversation,
|
||||||
|
tools: Optional[list[dict]] = None,
|
||||||
|
tokenize: bool = True,
|
||||||
|
add_generation_prompt: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
tools = deep_sort_dict(tools)
|
||||||
|
return super().apply_chat_template(
|
||||||
|
conversation,
|
||||||
|
tools=tools,
|
||||||
|
tokenize=tokenize,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def deep_sort_dict(obj: Any) -> Any:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [deep_sort_dict(item) for item in obj]
|
||||||
|
return obj
|
||||||
141
src/axolotl/monkeypatch/scaled_softmax_attn.py
Normal file
141
src/axolotl/monkeypatch/scaled_softmax_attn.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Scaled Softmax (SSMax) attention patch using FlexAttention.
|
||||||
|
SSMax: softmax(scores * s * log(n) + b) where n is the position index
|
||||||
|
Ref: https://arxiv.org/abs/2501.19399
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
|
from transformers.integrations.flex_attention import (
|
||||||
|
compile_friendly_flex_attention,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
FLEX_ATTENTION_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
FLEX_ATTENTION_AVAILABLE = False
|
||||||
|
BlockMask = None
|
||||||
|
|
||||||
|
_ssmax_config = {}
|
||||||
|
|
||||||
|
|
||||||
|
def patch_scaled_softmax_attention(
|
||||||
|
scaling_factor_init: float = 0.43, bias: float = 0.0, model: PreTrainedModel = None
|
||||||
|
):
|
||||||
|
"""Patch attention to apply SSMax via FlexAttention score_mod."""
|
||||||
|
global _ssmax_config
|
||||||
|
|
||||||
|
if not FLEX_ATTENTION_AVAILABLE:
|
||||||
|
raise RuntimeError("SSMax requires FlexAttention.")
|
||||||
|
|
||||||
|
_ssmax_config["ssmax_s"] = scaling_factor_init
|
||||||
|
_ssmax_config["ssmax_b"] = bias
|
||||||
|
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
if "flex_attention" in ALL_ATTENTION_FUNCTIONS:
|
||||||
|
_ssmax_config["original_flex_fn"] = ALL_ATTENTION_FUNCTIONS["flex_attention"]
|
||||||
|
ALL_ATTENTION_FUNCTIONS["flex_attention"] = ssmax_flex_attention_forward
|
||||||
|
LOG.info(
|
||||||
|
f"Patched flex_attention with SSMax (s={scaling_factor_init}, b={bias})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning("flex_attention not found. Ensure flex_attention: true is set.")
|
||||||
|
|
||||||
|
|
||||||
|
def ssmax_flex_attention_forward(
|
||||||
|
module: torch.nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask,
|
||||||
|
scaling: float | None = None,
|
||||||
|
softcap: float | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
"""FlexAttention forward with SSMax: score * (s * log(n) + b)."""
|
||||||
|
|
||||||
|
if kwargs.get("dropout", 0.0) > 0:
|
||||||
|
raise ValueError("flex_attention does not support dropout")
|
||||||
|
|
||||||
|
ssmax_s = _ssmax_config.get("ssmax_s", 0.43)
|
||||||
|
ssmax_b = _ssmax_config.get("ssmax_b", 0.0)
|
||||||
|
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
position_ids_flat = position_ids.view(-1) if position_ids is not None else None
|
||||||
|
|
||||||
|
block_mask = attention_mask if isinstance(attention_mask, BlockMask) else None
|
||||||
|
score_mask = None if block_mask else attention_mask
|
||||||
|
|
||||||
|
if score_mask is not None:
|
||||||
|
score_mask = score_mask[:, :, :, : key.shape[-2]]
|
||||||
|
|
||||||
|
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
|
||||||
|
"""
|
||||||
|
Apply SSMax scaling: score * (s * log(n) + b)
|
||||||
|
where n is the relative position within each packed sequence.
|
||||||
|
"""
|
||||||
|
if position_ids_flat is not None:
|
||||||
|
relative_pos = position_ids_flat[q_idx]
|
||||||
|
n = (relative_pos + 1).float()
|
||||||
|
else:
|
||||||
|
n = (q_idx + 1).float()
|
||||||
|
|
||||||
|
n = torch.clamp(n, min=2.0)
|
||||||
|
|
||||||
|
ssmax_scale = ssmax_s * torch.log(n) + ssmax_b
|
||||||
|
score = score * ssmax_scale
|
||||||
|
|
||||||
|
if softcap is not None:
|
||||||
|
score = softcap * torch.tanh(score / softcap)
|
||||||
|
|
||||||
|
if score_mask is not None:
|
||||||
|
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
enable_gqa = True
|
||||||
|
if (query.shape[1] & (query.shape[1] - 1)) != 0:
|
||||||
|
key = repeat_kv(key, query.shape[1] // key.shape[1])
|
||||||
|
value = repeat_kv(value, query.shape[1] // value.shape[1])
|
||||||
|
enable_gqa = False
|
||||||
|
|
||||||
|
return_lse = query.device.type != "cpu"
|
||||||
|
flex_output = compile_friendly_flex_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
score_mod=score_mod,
|
||||||
|
block_mask=block_mask,
|
||||||
|
enable_gqa=enable_gqa,
|
||||||
|
scale=scaling,
|
||||||
|
kernel_options=kwargs.get("kernel_options"),
|
||||||
|
return_lse=return_lse,
|
||||||
|
training=module.training,
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_lse:
|
||||||
|
attention_output, lse = flex_output
|
||||||
|
lse = lse.to(value.dtype)
|
||||||
|
else:
|
||||||
|
attention_output, lse = flex_output, None
|
||||||
|
|
||||||
|
return attention_output.transpose(1, 2).contiguous(), lse
|
||||||
|
|
||||||
|
|
||||||
|
def unpatch_scaled_softmax_attention():
|
||||||
|
"""Restore the original FlexAttention function."""
|
||||||
|
global _ssmax_config
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
if "original_flex_fn" in _ssmax_config:
|
||||||
|
ALL_ATTENTION_FUNCTIONS["flex_attention"] = _ssmax_config["original_flex_fn"]
|
||||||
|
_ssmax_config.clear()
|
||||||
|
LOG.info("Unpatched flex_attention, restored original")
|
||||||
@@ -8,6 +8,7 @@ from PIL.Image import Resampling
|
|||||||
from torch import Tensor, zeros_like
|
from torch import Tensor, zeros_like
|
||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
from transformers.image_utils import load_image
|
from transformers.image_utils import load_image
|
||||||
|
from transformers.models.internvl import InternVLProcessor
|
||||||
from transformers.models.smolvlm import SmolVLMProcessor
|
from transformers.models.smolvlm import SmolVLMProcessor
|
||||||
from transformers.models.voxtral import VoxtralProcessor
|
from transformers.models.voxtral import VoxtralProcessor
|
||||||
|
|
||||||
@@ -454,6 +455,37 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
|
|||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
class InternVLProcessingStrategy(ProcessingStrategy):
|
||||||
|
"""Processing Strategy class for InternVL"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: ProcessorMixin,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
image_size: int | tuple[int, int] | None = None,
|
||||||
|
image_resize_algorithm: Resampling | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||||
|
|
||||||
|
if not hasattr(processor, "image_ids"):
|
||||||
|
raise ValueError("'image_ids' missing from InternVL Processor.")
|
||||||
|
|
||||||
|
self.image_token_ids = processor.image_ids
|
||||||
|
|
||||||
|
def process_labels(self, input_ids):
|
||||||
|
labels = input_ids.clone()
|
||||||
|
|
||||||
|
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||||
|
|
||||||
|
for ids in self.image_token_ids:
|
||||||
|
labels[labels == ids] = -100
|
||||||
|
|
||||||
|
# Note: Check if need to mask 'video_token' as it gets converted to
|
||||||
|
# image patches during media processing
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
def get_processing_strategy(
|
def get_processing_strategy(
|
||||||
processor: ProcessorMixin,
|
processor: ProcessorMixin,
|
||||||
chat_template,
|
chat_template,
|
||||||
@@ -501,6 +533,11 @@ def get_processing_strategy(
|
|||||||
**processing_kwargs,
|
**processing_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(processor, InternVLProcessor):
|
||||||
|
return InternVLProcessingStrategy(
|
||||||
|
**processing_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# llama3_2_vision, llama4, llava
|
# llama3_2_vision, llama4, llava
|
||||||
# mistral_v7_tekken, pixtral, lfm2vl
|
# mistral_v7_tekken, pixtral, lfm2vl
|
||||||
return ProcessingStrategy(
|
return ProcessingStrategy(
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ def is_opentelemetry_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_trackio_available():
|
||||||
|
return importlib.util.find_spec("trackio") is not None
|
||||||
|
|
||||||
|
|
||||||
def get_pytorch_version() -> tuple[int, int, int]:
|
def get_pytorch_version() -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Get Pytorch version as a tuple of (major, minor, patch).
|
Get Pytorch version as a tuple of (major, minor, patch).
|
||||||
|
|||||||
248
src/axolotl/utils/callbacks/swanlab.py
Normal file
248
src/axolotl/utils/callbacks/swanlab.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
"""Callbacks for SwanLab integration"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from shutil import copyfile
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSwanLabCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
Lightweight SwanLab callback that directly logs metrics without using
|
||||||
|
SwanLab's transformers integration (which requires omegaconf).
|
||||||
|
|
||||||
|
This avoids the antlr4 version conflict between omegaconf and axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._initialized = False
|
||||||
|
self.swanlab = None
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
"""Lazy initialization of SwanLab"""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
self.swanlab = swanlab
|
||||||
|
|
||||||
|
# Check if SwanLab run is initialized
|
||||||
|
if swanlab.get_run() is None:
|
||||||
|
LOG.warning("SwanLab run is not initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
LOG.info("CustomSwanLabCallback initialized successfully")
|
||||||
|
except ImportError:
|
||||||
|
LOG.error("SwanLab is not installed")
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called at the beginning of training"""
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return control
|
||||||
|
|
||||||
|
self.setup()
|
||||||
|
|
||||||
|
if not self._initialized:
|
||||||
|
return control
|
||||||
|
|
||||||
|
# Log training configuration
|
||||||
|
try:
|
||||||
|
self.swanlab.config.update(
|
||||||
|
{
|
||||||
|
"train_batch_size": args.per_device_train_batch_size,
|
||||||
|
"eval_batch_size": args.per_device_eval_batch_size,
|
||||||
|
"learning_rate": args.learning_rate,
|
||||||
|
"num_train_epochs": args.num_train_epochs,
|
||||||
|
"max_steps": args.max_steps,
|
||||||
|
"warmup_steps": args.warmup_steps,
|
||||||
|
"logging_steps": args.logging_steps,
|
||||||
|
"save_steps": args.save_steps,
|
||||||
|
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
LOG.debug("Training configuration logged to SwanLab")
|
||||||
|
except Exception as err:
|
||||||
|
LOG.warning(f"Failed to log training config: {err}")
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
def on_log(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
logs=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called when logging metrics"""
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return control
|
||||||
|
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup()
|
||||||
|
|
||||||
|
if not self._initialized or logs is None:
|
||||||
|
return control
|
||||||
|
|
||||||
|
# Log metrics to SwanLab
|
||||||
|
try:
|
||||||
|
# Filter out non-numeric values and prepare for logging
|
||||||
|
metrics = {}
|
||||||
|
for key, value in logs.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
# Use step from state
|
||||||
|
metrics[key] = value
|
||||||
|
|
||||||
|
if metrics and state.global_step is not None:
|
||||||
|
self.swanlab.log(metrics, step=state.global_step)
|
||||||
|
except Exception as err:
|
||||||
|
LOG.warning(f"Failed to log metrics to SwanLab: {err}")
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
def on_train_end(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Called at the end of training"""
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return control
|
||||||
|
|
||||||
|
if self._initialized:
|
||||||
|
LOG.info("Training completed. SwanLab logs are available.")
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class SaveAxolotlConfigtoSwanLabCallback(TrainerCallback):
|
||||||
|
"""Callback to save axolotl config to SwanLab"""
|
||||||
|
|
||||||
|
def __init__(self, axolotl_config_path):
|
||||||
|
self.axolotl_config_path = axolotl_config_path
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: AxolotlTrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
# Check if SwanLab is initialized
|
||||||
|
if swanlab.get_run() is None:
|
||||||
|
LOG.warning(
|
||||||
|
"SwanLab run is not initialized. Please initialize SwanLab before training."
|
||||||
|
)
|
||||||
|
return control
|
||||||
|
|
||||||
|
# Log Axolotl config as artifact
|
||||||
|
with NamedTemporaryFile(
|
||||||
|
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||||
|
) as temp_file:
|
||||||
|
copyfile(self.axolotl_config_path, temp_file.name)
|
||||||
|
|
||||||
|
# Log config file to SwanLab
|
||||||
|
with open(temp_file.name, "r", encoding="utf-8") as config_file:
|
||||||
|
swanlab.log(
|
||||||
|
{
|
||||||
|
"axolotl_config": swanlab.Text(
|
||||||
|
config_file.read(), caption="Axolotl Config"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"The Axolotl config has been saved to the SwanLab run under logs."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up temp file
|
||||||
|
os.unlink(temp_file.name)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning(
|
||||||
|
"SwanLab is not installed. Install it with: pip install swanlab"
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
|
LOG.warning(f"Error while saving Axolotl config to SwanLab: {err}")
|
||||||
|
|
||||||
|
# Log DeepSpeed config if available
|
||||||
|
if args.deepspeed:
|
||||||
|
try:
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
with NamedTemporaryFile(
|
||||||
|
mode="w",
|
||||||
|
delete=False,
|
||||||
|
suffix=".json",
|
||||||
|
prefix="deepspeed_config_",
|
||||||
|
) as temp_file:
|
||||||
|
skip_upload = False
|
||||||
|
if isinstance(args.deepspeed, dict):
|
||||||
|
json.dump(args.deepspeed, temp_file, indent=4)
|
||||||
|
elif isinstance(args.deepspeed, str) and os.path.exists(
|
||||||
|
args.deepspeed
|
||||||
|
):
|
||||||
|
copyfile(args.deepspeed, temp_file.name)
|
||||||
|
else:
|
||||||
|
skip_upload = True
|
||||||
|
|
||||||
|
if not skip_upload:
|
||||||
|
temp_file.flush()
|
||||||
|
with open(
|
||||||
|
temp_file.name, "r", encoding="utf-8"
|
||||||
|
) as ds_config_file:
|
||||||
|
swanlab.log(
|
||||||
|
{
|
||||||
|
"deepspeed_config": swanlab.Text(
|
||||||
|
ds_config_file.read(),
|
||||||
|
caption="DeepSpeed Config",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
"The DeepSpeed config has been saved to the SwanLab run under logs."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up temp file
|
||||||
|
os.unlink(temp_file.name)
|
||||||
|
|
||||||
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
|
LOG.warning(
|
||||||
|
f"Error while saving DeepSpeed config to SwanLab: {err}"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return control
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user