Compare commits: uv-fixup...coderabbit
1 commit: 0fccbadb79
.github/PULL_REQUEST_TEMPLATE.md (5 changes)

@@ -15,11 +15,6 @@
 <!--- Include details of your testing environment, tests ran to see how -->
 <!--- your change affects other areas of the code, etc. -->

-## AI Usage Disclaimer
-
-<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
-<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
-
 ## Screenshots (if appropriate)

 ## Types of changes
.github/workflows/base.yml (145 changes)

@@ -21,12 +21,31 @@ jobs:
     timeout-minutes: 480
     # this job needs to be run on self-hosted GPU runners...
     runs-on: ubuntu-latest-m
-    env:
-      HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
     strategy:
       fail-fast: false
       matrix:
         include:
+          - cuda: "126"
+            cuda_version: 12.6.3
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.7.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
+          - cuda: "126"
+            cuda_version: 12.6.3
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.7.1
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.7.1
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -34,15 +53,6 @@ jobs:
             pytorch: 2.8.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -50,31 +60,6 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.10.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.12"
-            pytorch: 2.10.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
-          # - cuda: "129"
-          #   cuda_version: 12.9.1
-          #   cudnn_version: ""
-          #   python_version: "3.12"
-          #   pytorch: 2.9.1
-          #   torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          #   dockerfile: "Dockerfile-base"
-          #   platforms: "linux/amd64,linux/arm64"
           - cuda: "130"
             cuda_version: 13.0.0
             cudnn_version: ""
@@ -82,23 +67,6 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "9.0+PTX"
             dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "130"
-            cuda_version: 13.0.0
-            cudnn_version: ""
-            python_version: "3.12"
-            pytorch: 2.9.1
-            torch_cuda_arch_list: "9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "130"
-            cuda_version: 13.0.0
-            cudnn_version: ""
-            python_version: "3.12"
-            pytorch: 2.10.0
-            torch_cuda_arch_list: "9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
           # - cuda: "128"
           #   cuda_version: 12.8.1
           #   cudnn_version: ""
@@ -125,7 +93,6 @@ jobs:
             axolotlai/axolotl-base
       - name: Login to Docker Hub
         uses: docker/login-action@v2
-        if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -136,7 +103,6 @@ jobs:
         with:
           context: .
           file: ./docker/${{ matrix.dockerfile }}
-          platforms: ${{ matrix.platforms }}
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
@@ -151,12 +117,24 @@ jobs:
     if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
     timeout-minutes: 480
     runs-on: ubuntu-latest-m
-    env:
-      HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
     strategy:
       fail-fast: false
       matrix:
         include:
+          - cuda: "126"
+            cuda_version: 12.6.3
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.7.1
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-uv-base"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.7.1
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-uv-base"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -164,7 +142,6 @@ jobs:
             pytorch: 2.8.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -172,39 +149,6 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.10.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.12"
-            pytorch: 2.10.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
-          # - cuda: "129"
-          #   cuda_version: 12.9.1
-          #   cudnn_version: ""
-          #   python_version: "3.12"
-          #   pytorch: 2.9.1
-          #   torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          #   dockerfile: "Dockerfile-uv-base"
-          #   platforms: "linux/amd64,linux/arm64"
           - cuda: "130"
             cuda_version: 13.0.0
             cudnn_version: ""
@@ -212,23 +156,6 @@ jobs:
             pytorch: 2.9.1
             torch_cuda_arch_list: "9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "130"
-            cuda_version: 13.0.0
-            cudnn_version: ""
-            python_version: "3.12"
-            pytorch: 2.9.1
-            torch_cuda_arch_list: "9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: "130"
-            cuda_version: 13.0.0
-            cudnn_version: ""
-            python_version: "3.12"
-            pytorch: 2.10.0
-            torch_cuda_arch_list: "9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -240,7 +167,6 @@ jobs:
             axolotlai/axolotl-base-uv
       - name: Login to Docker Hub
         uses: docker/login-action@v2
-        if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -251,7 +177,6 @@ jobs:
         with:
           context: .
           file: ./docker/${{ matrix.dockerfile }}
-          platforms: ${{ matrix.platforms }}
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
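A note on the `HAS_DOCKERHUB_CREDS` lines dropped above: the usual reason for exporting a secrets check into a job-level `env` variable is that the `secrets` context cannot be referenced directly in a step-level `if:` expression, so forks without Docker Hub credentials would otherwise fail the login step. A minimal sketch of the pattern, with job name and runner chosen only for illustration:

```yaml
jobs:
  build:
    runs-on: ubuntu-latest
    env:
      # secrets can be read here; a step-level `if:` cannot read the secrets context directly
      HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
    steps:
      - name: Login to Docker Hub
        # skip the login on pull requests and on forks that do not define the secrets
        if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
```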
.github/workflows/docs.yml (3 changes)

@@ -12,9 +12,6 @@ jobs:
   build-deploy:
     runs-on: ubuntu-latest
     steps:
-      - name: cleanup node
-        run: |
-          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
      - name: Check out repository
        uses: actions/checkout@v4
      - name: Set up Quarto
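The removed `cleanup node` step is a common disk-space workaround on hosted runners: it deletes preinstalled toolchains before a large build. A minimal sketch of the same idea, with step placement illustrative only:

```yaml
steps:
  - name: cleanup node
    run: |
      # free several GB on the hosted runner before the build starts
      sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
  - name: Check out repository
    uses: actions/checkout@v4
```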
.github/workflows/main.yml (255 changes)

@@ -15,49 +15,37 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.0
+            axolotl_extras:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras: vllm
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
             pytorch: 2.8.0
             axolotl_extras:
-            platforms: "linux/amd64"
+            is_latest: true
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
             pytorch: 2.9.0
             axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
             pytorch: 2.9.1
             axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-            is_latest: true
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          # - cuda: 129
-          #   cuda_version: 12.9.1
-          #   python_version: "3.12"
-          #   pytorch: 2.9.1
-          #   axolotl_extras:
-          #   platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -83,7 +71,6 @@ jobs:
         uses: docker/build-push-action@v5
         with:
           context: .
-          platforms: ${{ matrix.platforms }}
           build-args: |
             BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
             CUDA=${{ matrix.cuda }}
@@ -98,77 +85,6 @@ jobs:
             ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
           labels: ${{ steps.metadata.outputs.labels }}

-  build-axolotl-uv:
-    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-            is_latest: true
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-    runs-on: axolotl-gpu-runner
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Docker metadata
-        id: metadata
-        uses: docker/metadata-action@v5
-        with:
-          images: |
-            axolotlai/axolotl-uv
-          tags: |
-            type=ref,event=branch
-            type=pep440,pattern={{version}}
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-      - name: Login to Docker Hub
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKERHUB_USERNAME }}
-          password: ${{ secrets.DOCKERHUB_TOKEN }}
-      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
-      - name: Build and export to Docker
-        uses: docker/build-push-action@v5
-        with:
-          context: .
-          platforms: ${{ matrix.platforms }}
-          build-args: |
-            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
-            CUDA=${{ matrix.cuda }}
-            PYTORCH_VERSION=${{ matrix.pytorch }}
-            AXOLOTL_ARGS=${{ matrix.axolotl_args }}
-            AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
-          file: ./docker/Dockerfile-uv
-          push: ${{ github.event_name != 'pull_request' }}
-          tags: |
-            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
-            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
-            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
-          labels: ${{ steps.metadata.outputs.labels }}
-
   build-axolotl-cloud:
     needs: build-axolotl
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -176,49 +92,43 @@ jobs:
     strategy:
       matrix:
         include:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.0
+            axolotl_extras:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras:
+            is_latest:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras: vllm
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
             pytorch: 2.8.0
             axolotl_extras:
-            platforms: "linux/amd64"
+            is_latest: true
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
             pytorch: 2.9.0
             axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
             pytorch: 2.9.1
             axolotl_extras:
-            is_latest: true
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          # - cuda: 129
-          #   cuda_version: 12.9.1
-          #   python_version: "3.12"
-          #   pytorch: 2.9.1
-          #   axolotl_extras:
-          #   platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -243,7 +153,6 @@ jobs:
         uses: docker/build-push-action@v5
         with:
           context: .
-          platforms: ${{ matrix.platforms }}
           build-args: |
             BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
             CUDA=${{ matrix.cuda }}
@@ -254,73 +163,6 @@ jobs:
             ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
           labels: ${{ steps.metadata.outputs.labels }}

-  build-axolotl-cloud-uv:
-    needs: build-axolotl-uv
-    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
-    # this job needs to be run on self-hosted GPU runners...
-    strategy:
-      matrix:
-        include:
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            is_latest: true
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.12"
-            pytorch: 2.10.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
-    runs-on: axolotl-gpu-runner
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Docker metadata
-        id: metadata
-        uses: docker/metadata-action@v5
-        with:
-          images: |
-            axolotlai/axolotl-cloud-uv
-          tags: |
-            type=ref,event=branch
-            type=pep440,pattern={{version}}
-      - name: Login to Docker Hub
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKERHUB_USERNAME }}
-          password: ${{ secrets.DOCKERHUB_TOKEN }}
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-      - name: Build
-        uses: docker/build-push-action@v5
-        with:
-          context: .
-          platforms: ${{ matrix.platforms }}
-          build-args: |
-            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
-            CUDA=${{ matrix.cuda }}
-          file: ./docker/Dockerfile-cloud-uv
-          push: ${{ github.event_name != 'pull_request' }}
-          tags: |
-            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
-            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
-          labels: ${{ steps.metadata.outputs.labels }}
-
   build-axolotl-cloud-no-tmux:
     needs: build-axolotl
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -328,16 +170,22 @@ jobs:
     strategy:
       matrix:
         include:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras:
+            is_latest:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras: vllm
+            is_latest: true
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            is_latest: true
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.11"
-            pytorch: 2.9.1
+            pytorch: 2.8.0
             axolotl_extras:
             is_latest:
     runs-on: axolotl-gpu-runner
@@ -364,7 +212,6 @@ jobs:
         uses: docker/build-push-action@v5
         with:
           context: .
-          platforms: linux/amd64,linux/arm64
           build-args: |
             BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
             CUDA=${{ matrix.cuda }}
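For readers following the tag expressions in this workflow: the final image tag is the docker/metadata-action output concatenated with the matrix values, and the `!= '' && '-' || ''` fragment only inserts a hyphen when `axolotl_extras` is set. A rough illustration below; the image name and branch value are assumed for the example, not taken from the visible hunks.

```yaml
# Illustrative expansion for one matrix entry, assuming
# steps.metadata.outputs.tags == "axolotlai/axolotl:main" (a branch build),
# matrix.python_version == "3.11", matrix.cuda == 128, matrix.pytorch == 2.8.0.
tags: |
  ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
  # axolotl_extras empty   -> axolotlai/axolotl:main-py3.11-cu128-2.8.0
  # axolotl_extras "vllm"  -> axolotlai/axolotl:main-py3.11-cu128-2.8.0-vllm
```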
.github/workflows/multi-gpu-e2e.yml (36 changes)

@@ -19,9 +19,6 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
   cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

-env:
-  MODAL_IMAGE_BUILDER_VERSION: "2025.06"
-
 jobs:
   test-axolotl-multigpu:
     if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -29,32 +26,27 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.1
+            axolotl_extras: vllm
+            num_gpus: 2
+            nightly_build: "true"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
             pytorch: 2.8.0
             axolotl_extras: fbgemm-gpu
             num_gpus: 2
+            nightly_build: "true"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras: "fbgemm-gpu"
-            num_gpus: 2
-          - cuda: 129
-            cuda_version: 12.9.1
-            python_version: "3.12"
-            pytorch: 2.9.1
-            axolotl_extras: "fbgemm-gpu"
-            num_gpus: 2
-            dockerfile: "Dockerfile-uv.jinja"
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.11"
-            pytorch: 2.9.1
-            axolotl_extras:
-            # axolotl_extras: fbgemm-gpu
+            pytorch: 2.9.0
+            axolotl_extras: fbgemm-gpu
             num_gpus: 2
+            nightly_build: "true"
     runs-on: [self-hosted, modal]
     timeout-minutes: 120
     steps:
@@ -67,7 +59,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==1.3.0.post1 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -76,8 +68,8 @@ jobs:
           echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+          echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
           echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
-          echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
-          modal run -m cicd.multigpu
+          modal run cicd.multigpu
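On the `modal run` change in the last hunk: the `-m` flag asks the Modal CLI to resolve `cicd.multigpu` as an importable Python module (analogous to `python -m`), which pairs with the newer Modal pin on one side of this diff, while the plain path form is what the older 1.0.x pin on the other side expects. A minimal sketch of the step, hedged as an illustration of the two invocation styles rather than a statement of which Modal versions require which:

```yaml
  - name: Run tests job on Modal
    run: |
      # module-style entrypoint (newer CLI), analogous to `python -m cicd.multigpu`
      modal run -m cicd.multigpu
      # path-style entrypoint used with the older pin:
      # modal run cicd.multigpu
```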
.github/workflows/nightlies.yml (16 changes)

@@ -12,15 +12,15 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
-            pytorch: 2.8.0
+            pytorch: 2.7.1
             axolotl_extras:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
-            pytorch: 2.9.1
+            pytorch: 2.8.0
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -64,15 +64,15 @@ jobs:
     strategy:
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
-            pytorch: 2.8.0
+            pytorch: 2.7.1
             axolotl_extras:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
-            pytorch: 2.9.1
+            pytorch: 2.8.0
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
.github/workflows/preview-docs.yml (5 changes)

@@ -11,7 +11,6 @@ on:
       - '_quarto.yml'
       - docs/scripts/generate_config_docs.py
      - src/axolotl/utils/schemas/**.py
-      - .github/workflows/preview-docs.yml

 permissions:
   checks: write
@@ -28,10 +27,6 @@ jobs:
     runs-on: ubuntu-latest
     if: ${{ !github.event.pull_request.draft }}
     steps:
-      - name: cleanup node
-        run: |
-          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
-
      - name: Check out repository
        uses: actions/checkout@v4
        with:
.github/workflows/pypi.yml (6 changes)

@@ -40,7 +40,7 @@ jobs:

       - name: Install dependencies
         run: |
-          pip3 install wheel packaging==26.0
+          pip3 install wheel packaging==23.2
           pip3 install --no-build-isolation -e .
           pip3 install -r requirements-dev.txt -r requirements-tests.txt

@@ -48,9 +48,9 @@ jobs:
         id: tag
         run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)

-      - name: Update version in VERSION file
+      - name: Update version in setup.py
         run: |
-          echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
+          sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py

       - name: Build a source dist
         run: |
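The two sides of the release hunk bump the package version differently: one writes the tag (minus its leading `v`) into a VERSION file, the other rewrites `version="..."` in `setup.py` with sed. Which one is correct depends on where the project actually reads its version from, which this diff alone does not show. A sketch of the VERSION-file variant for reference:

```yaml
  - name: Update version in VERSION file
    run: |
      # a tag like "v0.9.3" becomes "0.9.3"; sed strips only a leading "v"
      echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
```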
.github/workflows/tests-nightly.yml (24 changes)

@@ -26,7 +26,7 @@ jobs:
       max-parallel: 2
       matrix:
         python_version: ["3.11"]
-        pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
+        pytorch_version: ["2.7.1", "2.8.0"]
     timeout-minutes: 20

     steps:
@@ -37,7 +37,7 @@ jobs:
         id: hf-cache-restore-s3
         run: |
           mkdir -p /home/runner/.cache/huggingface/hub
-          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
+          curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd

       - name: Setup Python
         uses: actions/setup-python@v5
@@ -48,7 +48,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel

       - name: Install PyTorch
         run: |
@@ -99,17 +99,17 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
-            pytorch: 2.8.0
+            pytorch: 2.7.1
             num_gpus: 1
             axolotl_extras:
             nightly_build: "true"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
-            pytorch: 2.9.1
+            pytorch: 2.8.0
             num_gpus: 1
             axolotl_extras:
             nightly_build: "true"
@@ -123,7 +123,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==1.3.0.post1 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -148,10 +148,10 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
-            pytorch: 2.9.1
+            pytorch: 2.7.1
             num_gpus: 2
             axolotl_extras:
             nightly_build: "true"
@@ -165,7 +165,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==1.3.0.post1 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
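The cache-restore step that changes host between the two sides streams a pre-built Hugging Face hub cache and decompresses it with zstd on the fly, so the test jobs do not re-download model fixtures. A sketch of the pattern as it appears here (only the CDN host differs across the diff):

```yaml
  - name: Restore Cache from S3
    run: |
      mkdir -p /home/runner/.cache/huggingface/hub
      # stream the archive and unpack it directly into the HF hub cache;
      # --use-compress-program unzstd handles the .tar.zst compression
      curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst \
        | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
```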
.github/workflows/tests.yml (115 changes)

@@ -54,13 +54,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
-        pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
-        # exclude:
-        #   - python_version: "3.14"
-        #     pytorch_version: "2.8.0"
-        #   - python_version: "3.14"
-        #     pytorch_version: "2.9.1"
+        python_version: ["3.11"]
+        pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
     timeout-minutes: 20

     steps:
@@ -71,13 +66,12 @@ jobs:
       - name: Check out repository code
         uses: actions/checkout@v4

-      - name: Restore Cache from S3
-        id: hf-cache-restore-s3
-        run: |
-          mkdir -p ~/.cache/huggingface/hub
-          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
-          ls -ltr ~/.cache/huggingface/hub/
+      # - name: Restore Cache from S3
+      #   id: hf-cache-restore-s3
+      #   run: |
+      #     mkdir -p ~/.cache/huggingface/hub
+      #     curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
+      #

       - name: Setup Python
         uses: actions/setup-python@v5
         with:
@@ -87,7 +81,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel

       - name: Install PyTorch
         run: |
@@ -115,15 +109,12 @@ jobs:

       - name: Pre-Download dataset fixture
         run: |
-          hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
+          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures

-      - name: Show HF cache
-        run: hf cache ls
-
       - name: Run tests
         run: |
           df -h
-          pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
           df -h
           pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
           df -h
@@ -131,9 +122,6 @@ jobs:
           df -h
           pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml

-      - name: Show HF cache
-        run: hf cache ls
-
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v5
         with:
@@ -149,13 +137,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
-        pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
-        # exclude:
-        #   - python_version: "3.14"
-        #     pytorch_version: "2.8.0"
-        #   - python_version: "3.14"
-        #     pytorch_version: "2.9.1"
+        python_version: ["3.11"]
+        pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
     timeout-minutes: 20

     steps:
@@ -166,13 +149,12 @@ jobs:
       - name: Check out repository code
         uses: actions/checkout@v4

-      - name: Restore Cache from S3
-        id: hf-cache-restore-s3
-        run: |
-          mkdir -p ~/.cache/huggingface/hub
-          curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
-          ls -ltr ~/.cache/huggingface/hub/
+      # - name: Restore Cache from S3
+      #   id: hf-cache-restore-s3
+      #   run: |
+      #     mkdir -p ~/.cache/huggingface/hub
+      #     curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
+      #

       - name: Setup Python
         uses: actions/setup-python@v5
         with:
@@ -182,7 +164,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil

       - name: Install PyTorch
         run: |
@@ -210,19 +192,16 @@ jobs:
           axolotl --help

       - name: Show HF cache
-        run: hf cache ls
+        run: hf cache scan

       - name: Run tests
         run: |
-          pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
           pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
           pytest -v --durations=10 tests/cli/

-      - name: Show HF cache
-        run: hf cache ls
-
   gate-skip-e2e:
-    needs: [pre-commit]
+    needs: [pre-commit, pytest, pytest-sdist]
     runs-on: ubuntu-latest
     outputs:
       skip: ${{ steps.compute.outputs.skip }}
@@ -258,16 +237,16 @@ jobs:
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
     timeout-minutes: 120
-    needs: [pre-commit, pytest]
+    needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]

     strategy:
       fail-fast: false
       matrix:
         include:
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.12"
-            pytorch: 2.9.1
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.8.0
             num_gpus: 1
             axolotl_extras:
             dockerfile: "Dockerfile-uv.jinja"
@@ -281,7 +260,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==1.3.0.post1 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -313,6 +292,18 @@ jobs:
       fail-fast: false
       matrix:
         include:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.1
+            num_gpus: 1
+            axolotl_extras:
+          # - cuda: 128
+          #   cuda_version: 12.8.1
+          #   python_version: "3.11"
+          #   pytorch: 2.7.1
+          #   num_gpus: 1
+          #   axolotl_extras:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
@@ -323,19 +314,7 @@ jobs:
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
-            pytorch: 2.9.1
-            num_gpus: 1
-            axolotl_extras:
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.11"
-            pytorch: 2.10.0
-            num_gpus: 1
-            axolotl_extras:
-          - cuda: 130
-            cuda_version: 13.0.0
-            python_version: "3.11"
-            pytorch: 2.9.1
+            pytorch: 2.9.0
             num_gpus: 1
             axolotl_extras:
     steps:
@@ -348,7 +327,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==1.3.0.post1 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -375,10 +354,10 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 129
-            cuda_version: 12.9.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
-            pytorch: 2.9.1
+            pytorch: 2.7.1
             num_gpus: 1
             axolotl_extras:
     steps:
@@ -391,7 +370,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==1.3.0.post1 jinja2
+          pip install modal==1.0.2 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
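On the test command that changes worker count above: the unit-test step shards with pytest-xdist, where `-n` sets the number of workers (4 on one side of this diff, 8 on the other) and `--dist loadfile` keeps every test from a given file on the same worker, which matters for tests sharing module-scoped fixtures. The invocation as it appears in the workflow, reformatted only for readability:

```yaml
  - name: Run tests
    run: |
      # -n4: four pytest-xdist workers (the other side of this diff uses -n8)
      # --dist loadfile: all tests in one file run on the same worker
      pytest -v --durations=10 -n4 --dist loadfile \
        --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ \
        tests/ --cov=axolotl --cov-report=xml
```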
@@ -11,13 +11,13 @@ repos:
       - id: no-commit-to-branch
         args: ['--branch', 'main']
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.10
+    rev: v0.14.7
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.19.1
+    rev: v1.19.0
     hooks:
       - id: mypy
         additional_dependencies:
@@ -123,7 +123,7 @@ datasets:
 | --------------------------------- | -------------------------- | ----------------------------------- |
 | `dataset_prepared_path`           | `"data/last_run_prepared"` | Path for prepared dataset           |
 | `push_dataset_to_hub`             | `""`                       | Push dataset to HF hub              |
-| `dataset_num_proc`                | `4`                        | Number of preprocessing processes   |
+| `dataset_processes`               | `4`                        | Number of preprocessing processes   |
 | `dataset_keep_in_memory`          | `false`                    | Keep dataset in memory              |
 | `shuffle_merged_datasets`         | `true`                     | Shuffle merged datasets             |
 | `shuffle_before_merging_datasets` | `false`                    | Shuffle each dataset before merging |
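The renamed row above is the dataset preprocessing parallelism knob; one side of the diff documents it as `dataset_num_proc`, the other as the older `dataset_processes` name. A minimal, hypothetical config fragment showing where the option sits (model and dataset values are placeholders, not from this diff):

```yaml
base_model: NousResearch/Llama-2-7b-hf   # placeholder
datasets:
  - path: teknium/GPT4-LLM-Cleaned       # placeholder
    type: alpaca
dataset_prepared_path: data/last_run_prepared
dataset_num_proc: 4   # number of preprocessing processes; defaults to os.cpu_count() if unset
```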
@@ -39,6 +39,7 @@
 # type: # linear | dynamic
 # factor: # float

+
 # # Whether you are training a 4-bit GPTQ quantized model
 # gptq: true
 # gptq_groupsize: 128 # group size
@@ -106,7 +107,7 @@
 # push_dataset_to_hub: # repo path
 # # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
 # # if not set.
-# dataset_num_proc: # defaults to os.cpu_count() if not set
+# dataset_processes: # defaults to os.cpu_count() if not set
 # # push checkpoints to hub
 # hub_model_id: # repo path to push finetuned model
 # # how to push checkpoints to hub
@@ -223,6 +224,9 @@
 # eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 # eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

+# # Save model as safetensors (require safetensors package)
+# save_safetensors:
+
 # # Whether to mask out or include the human's prompt from the training labels
 # train_on_inputs: false
 # # Group similarly sized data to minimize padding.
@@ -348,6 +352,8 @@
 # # Allow overwrite yml config using from cli
 # strict:

+
+
 base_model: ${BASE_MODEL}
 base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
 base_model_config: ${BASE_MODEL_CONFIG}
@@ -406,7 +412,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
 default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
 dataset_prepared_path: ${DATASET_PREPARED_PATH}
 push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
-dataset_num_proc: ${DATASET_NUM_PROC}
+dataset_processes: ${DATASET_PROCESSES}
 dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
 hub_model_id: ${HUB_MODEL_ID}
 hub_strategy: ${HUB_STRATEGY}
@@ -506,6 +512,7 @@ profiler_steps: ${PROFILER_STEPS}
 loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
 loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}

+save_safetensors: ${SAVE_SAFETENSORS}
 train_on_inputs: ${TRAIN_ON_INPUTS}
 group_by_length: ${GROUP_BY_LENGTH}
 gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
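The `${...}` placeholders in the template hunks above are environment-style substitution variables. As a rough illustration only (not the project's actual tooling, and `config.template.yml` is a hypothetical filename), such a template could be rendered with Python's `string.Template`:

```python
# Illustrative sketch only: fill ${VAR} placeholders from the environment.
# The real CI pipeline may use envsubst or its own renderer instead.
import os
import string

with open("config.template.yml") as f:           # hypothetical input filename
    template = string.Template(f.read())

rendered = template.safe_substitute(os.environ)  # unset ${VARS} are left as-is
with open("config.yml", "w") as f:
    f.write(rendered)
```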
README.md (16 changed lines)
@@ -29,15 +29,15 @@

 ## 🎉 Latest Updates

-- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
+- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
-- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
+- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
 - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
 - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
 - 2025/07:
   - ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
-  - Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
+  - Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
   - FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
-  - [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
+  - [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
   - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
 - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!

@@ -46,8 +46,8 @@
 <summary>Expand older updates</summary>

 - 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
-- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
+- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
-- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
+- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
 - 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
 - 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
 - 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
@@ -77,7 +77,7 @@ Features:

 - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
 - Python 3.11
-- PyTorch ≥2.8.0
+- PyTorch ≥2.7.1

 ### Google Colab

@@ -88,7 +88,7 @@ Features:
 #### Using pip

 ```bash
-pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
+pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

 # Download example axolotl configs, deepspeed configs
_quarto.yml (45 changed lines)
@@ -1,8 +1,6 @@
 project:
 type: website
-pre-render:
+pre-render: docs/scripts/generate_config_docs.py
-- docs/scripts/generate_config_docs.py
-- docs/scripts/generate_examples_docs.py

 quartodoc:
 dir: docs/api
@@ -242,46 +240,6 @@ website:
 - docs/getting-started.qmd
 - docs/installation.qmd
 - docs/inference.qmd
-- section: "Model Guides"
-contents:
-- docs/models/kimi-linear.qmd
-- docs/models/plano.qmd
-- docs/models/mimo.qmd
-- docs/models/internvl3_5.qmd
-- docs/models/olmo3.qmd
-- docs/models/trinity.qmd
-- docs/models/arcee.qmd
-- section: "Ministral3"
-contents:
-- docs/models/ministral3.qmd
-- docs/models/ministral3/think.qmd
-- docs/models/ministral3/vision.qmd
-- section: "Magistral"
-contents:
-- docs/models/magistral.qmd
-- docs/models/magistral/think.qmd
-- docs/models/magistral/vision.qmd
-- docs/models/ministral.qmd
-- docs/models/mistral-small.qmd
-- docs/models/voxtral.qmd
-- docs/models/devstral.qmd
-- docs/models/mistral.qmd
-- docs/models/llama-4.qmd
-- docs/models/llama-2.qmd
-- docs/models/qwen3-next.qmd
-- docs/models/qwen3.qmd
-- docs/models/gemma3n.qmd
-- docs/models/apertus.qmd
-- docs/models/gpt-oss.qmd
-- docs/models/seed-oss.qmd
-- docs/models/phi.qmd
-- docs/models/smolvlm2.qmd
-- docs/models/granite4.qmd
-- docs/models/LiquidAI.qmd
-- docs/models/hunyuan.qmd
-- docs/models/jamba.qmd
-- docs/models/orpheus.qmd
-
 - docs/cli.qmd
 - docs/telemetry.qmd
 - docs/config-reference.qmd
@@ -320,7 +278,6 @@ website:
 - docs/multipack.qmd
 - docs/mixed_precision.qmd
 - docs/optimizers.qmd
-- docs/attention.qmd

 - section: "Advanced Features"
 contents:
|||||||
@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN uv pip install packaging==26.0 setuptools==75.8.0
|
RUN uv pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN uv pip install torchvision
|
RUN uv pip install torchvision
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
|
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
 template_env = jinja2.Environment(
     loader=template_loader, autoescape=select_autoescape()
 )
-dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
-df_template = template_env.get_template(dockerfile)
+df_template = template_env.get_template("Dockerfile.jinja")

 df_args = {
     "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
@@ -28,11 +27,8 @@ df_args = {
     "CUDA": os.environ.get("CUDA", "126"),
     "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
     "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
-    "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
     "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
     "HF_HOME": "/workspace/data/huggingface-cache/hub",
-    "PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
-    "DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
 }

 dockerfile_contents = df_template.render(**df_args)
|||||||
@@ -2,7 +2,7 @@
 set -e

 # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
-pytest -v --durations=10 -n2 --maxfail=3 \
+pytest -v --durations=10 -n2 \
 --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
 --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
 /workspace/axolotl/tests/e2e/multigpu/ \
|||||||
@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
|
|||||||
ARG AXOLOTL_ARGS=""
|
ARG AXOLOTL_ARGS=""
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
ARG TARGETARCH
|
|
||||||
|
|
||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
@@ -21,17 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
|||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi && \
|
fi && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
python scripts/unsloth_install.py | sh && \
|
||||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
|
||||||
else \
|
|
||||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
|
||||||
fi && \ python scripts/unsloth_install.py | sh && \
|
|
||||||
python scripts/cutcrossentropy_install.py | sh && \
|
python scripts/cutcrossentropy_install.py | sh && \
|
||||||
pip install pytest && \
|
pip install pytest && \
|
||||||
pip cache purge
|
pip cache purge
|
||||||
|
|||||||
@@ -2,16 +2,14 @@ ARG CUDA_VERSION="11.8.0"
|
|||||||
ARG CUDNN_VERSION="8"
|
ARG CUDNN_VERSION="8"
|
||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
ARG TARGETARCH
|
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
ARG TARGETARCH
|
ARG PYTHON_VERSION="3.10"
|
||||||
ARG PYTHON_VERSION="3.11"
|
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
ARG CUDA="128"
|
ARG CUDA="118"
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
@@ -24,17 +22,11 @@ RUN apt-get update \
|
|||||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||||
&& rm -rf /var/cache/apt/archives \
|
&& rm -rf /var/cache/apt/archives \
|
||||||
&& rm -rf /var/lib/apt/lists/* \
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
&& if [ "$TARGETARCH" = "amd64" ]; then \
|
&& wget \
|
||||||
MINICONDA_ARCH="x86_64"; \
|
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
|
||||||
MINICONDA_ARCH="aarch64"; \
|
|
||||||
else \
|
|
||||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
|
||||||
fi \
|
|
||||||
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
|
||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
|
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||||
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||||
@@ -43,7 +35,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
|
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
python3 -m pip cache purge
|
python3 -m pip cache purge
|
||||||
|
|
||||||
@@ -59,18 +51,8 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
# Map Python version (e.g., 3.12 -> cp312)
|
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
|
||||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
# Map architecture
|
fi
|
||||||
case "$TARGETARCH" in \
|
|
||||||
amd64) ARCH_TAG="x86_64" ;; \
|
|
||||||
arm64) ARCH_TAG="aarch64" ;; \
|
|
||||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
|
||||||
esac && \
|
|
||||||
WHL_VERSION="v0.7.16" && \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
|
||||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
|
||||||
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
|
||||||
rm "${WHL_FILE}"
|
|
||||||
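Worked example of the wheel-name mapping in the hunk above: with `PYTHON_VERSION=3.11`, `PYTORCH_VERSION=2.9.1`, `CUDA=128`, and `TARGETARCH=amd64`, the computed tags are `cp311`, `torch2.9`, and `x86_64`, so the downloaded file is `flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl` — the same wheel the non-parameterized variant of this step installs by its fixed name.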
|
|||||||
@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
|
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||||
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
||||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
ARG BASE_TAG=main
|
|
||||||
FROM axolotlai/axolotl-uv:$BASE_TAG
|
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
|
||||||
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
|
||||||
|
|
||||||
EXPOSE 8888
|
|
||||||
EXPOSE 22
|
|
||||||
|
|
||||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
|
||||||
COPY scripts/motd /etc/motd
|
|
||||||
|
|
||||||
RUN uv pip install jupyterlab notebook ipywidgets && \
|
|
||||||
jupyter lab clean
|
|
||||||
RUN apt update && \
|
|
||||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
|
||||||
rm -rf /var/cache/apt/archives && \
|
|
||||||
rm -rf /var/lib/apt/lists/* && \
|
|
||||||
mkdir -p ~/.ssh && \
|
|
||||||
chmod 700 ~/.ssh && \
|
|
||||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
|
||||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
|
||||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
|
||||||
chmod +x /root/cloud-entrypoint.sh && \
|
|
||||||
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
|
|
||||||
|
|
||||||
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
|
||||||
CMD ["sleep", "infinity"]
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
ARG BASE_TAG=main-base
|
|
||||||
FROM axolotlai/axolotl-base-uv:$BASE_TAG
|
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
|
||||||
ARG AXOLOTL_EXTRAS=""
|
|
||||||
ARG AXOLOTL_ARGS=""
|
|
||||||
ARG CUDA="118"
|
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
|
||||||
ARG TARGETARCH
|
|
||||||
|
|
||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
|
||||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
|
|
||||||
rm -rf /var/cache/apt/archives && \
|
|
||||||
rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
|
||||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
|
||||||
else \
|
|
||||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
|
||||||
fi && \
|
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
|
||||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
|
||||||
else \
|
|
||||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
|
||||||
fi && \
|
|
||||||
python scripts/unsloth_install.py --uv | sh && \
|
|
||||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
|
||||||
uv pip install pytest && \
|
|
||||||
uv cache clean
|
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works with shallow clone
|
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
|
||||||
git config --get remote.origin.fetch && \
|
|
||||||
git config --global credential.helper store
|
|
||||||
|
|
||||||
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
|
|
||||||
RUN chmod +x /root/.axolotl-complete.bash && \
|
|
||||||
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
|
|
||||||
@@ -2,11 +2,9 @@ ARG CUDA_VERSION="12.6.3"
|
|||||||
ARG CUDNN_VERSION=""
|
ARG CUDNN_VERSION=""
|
||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
ARG TARGETARCH
|
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
ARG TARGETARCH
|
|
||||||
ARG PYTHON_VERSION="3.11"
|
ARG PYTHON_VERSION="3.11"
|
||||||
ARG PYTORCH_VERSION="2.6.0"
|
ARG PYTORCH_VERSION="2.6.0"
|
||||||
ARG CUDA="126"
|
ARG CUDA="126"
|
||||||
@@ -33,25 +31,12 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
|||||||
|
|
||||||
RUN uv pip install packaging setuptools wheel psutil \
|
RUN uv pip install packaging setuptools wheel psutil \
|
||||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||||
|
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||||
|
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||||
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
|
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Map Python version (e.g., 3.12 -> cp312)
|
|
||||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
|
||||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
|
||||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
|
||||||
# Map architecture
|
|
||||||
case "$TARGETARCH" in \
|
|
||||||
amd64) ARCH_TAG="x86_64" ;; \
|
|
||||||
arm64) ARCH_TAG="aarch64" ;; \
|
|
||||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
|
||||||
esac && \
|
|
||||||
WHL_VERSION="v0.7.16" && \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
|
||||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
|
||||||
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
|
||||||
rm "${WHL_FILE}"
|
|
||||||
|
docs/.gitignore (2 changed lines, vendored)
@@ -3,5 +3,3 @@ _site/
 /api/*.qmd
 /api/*.html
 config-reference.qmd
-models/**/*.qmd
-models/**/*.html
@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
 Download a base model using the Hugging Face CLI:

 ```bash
-hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
+huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
 ```

 ### 10. Create Axolotl Configuration
|||||||
@@ -1,140 +0,0 @@
|
|||||||
---
|
|
||||||
title: Attention
|
|
||||||
description: Supported attention modules in Axolotl
|
|
||||||
---
|
|
||||||
|
|
||||||
## SDP Attention
|
|
||||||
|
|
||||||
This is the default built-in attention in PyTorch.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
sdp_attention: true
|
|
||||||
```
|
|
||||||
|
|
||||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
|
||||||
|
|
||||||
## Flash Attention 2
|
|
||||||
|
|
||||||
Uses efficient kernels to compute attention.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
flash_attention: true
|
|
||||||
```
|
|
||||||
|
|
||||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
|
||||||
|
|
||||||
### Nvidia
|
|
||||||
|
|
||||||
Requirements: Ampere, Ada, or Hopper GPUs
|
|
||||||
|
|
||||||
Note: For Turing GPUs or lower, please use other attention methods.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install flash-attn --no-build-isolation
|
|
||||||
```
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
|
|
||||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
#### Flash Attention 3
|
|
||||||
|
|
||||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
|
||||||
cd flash-attention/hopper
|
|
||||||
|
|
||||||
python setup.py install
|
|
||||||
```
|
|
||||||
|
|
||||||
### AMD
|
|
||||||
|
|
||||||
Requirements: ROCm 6.0 and above.
|
|
||||||
|
|
||||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
|
||||||
|
|
||||||
## Flex Attention
|
|
||||||
|
|
||||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
flex_attention: true
|
|
||||||
|
|
||||||
# recommended
|
|
||||||
torch_compile: true
|
|
||||||
```
|
|
||||||
|
|
||||||
::: {.callout-note}
|
|
||||||
|
|
||||||
We recommend using latest stable version of PyTorch for best performance.
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
|
||||||
|
|
||||||
## SageAttention
|
|
||||||
|
|
||||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
sage_attention: true
|
|
||||||
```
|
|
||||||
|
|
||||||
Requirements: Ampere, Ada, or Hopper GPUs
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install sageattention==2.2.0 --no-build-isolation
|
|
||||||
```
|
|
||||||
|
|
||||||
::: {.callout-warning}
|
|
||||||
|
|
||||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
|
||||||
|
|
||||||
::: {.callout-note}
|
|
||||||
|
|
||||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
|
|
||||||
## xFormers
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
xformers_attention: true
|
|
||||||
```
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
|
|
||||||
We recommend using with Turing GPUs or below (such as on Colab).
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
|
||||||
|
|
||||||
## Shifted Sparse Attention
|
|
||||||
|
|
||||||
::: {.callout-warning}
|
|
||||||
|
|
||||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
Requirements: LLaMA model architecture
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
flash_attention: true
|
|
||||||
s2_attention: true
|
|
||||||
```
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
|
|
||||||
No sample packing support!
|
|
||||||
|
|
||||||
:::
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
---
|
|
||||||
title: "Checkpoint Saving"
|
|
||||||
format:
|
|
||||||
html:
|
|
||||||
toc: true
|
|
||||||
toc-depth: 2
|
|
||||||
number-sections: true
|
|
||||||
execute:
|
|
||||||
enabled: false
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
|
|
||||||
|
|
||||||
## File-Based Checkpoint Trigger
|
|
||||||
|
|
||||||
### Configuration
|
|
||||||
|
|
||||||
Enable in your config:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
dynamic_checkpoint:
|
|
||||||
enabled: true
|
|
||||||
check_interval: 100 # Optional: check every N steps (default: 100)
|
|
||||||
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
|
|
||||||
```
|
|
||||||
|
|
||||||
**Options:**
|
|
||||||
- `enabled`: `true` to enable (required)
|
|
||||||
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
|
|
||||||
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
|
|
||||||
|
|
||||||
### How It Works
|
|
||||||
|
|
||||||
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
|
|
||||||
2. When detected, file is deleted and checkpoint is saved
|
|
||||||
3. In distributed training, rank 0 broadcasts to synchronize all ranks
|
|
||||||
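A minimal sketch of that flow (illustrative only, not the actual Axolotl implementation; the helper name `should_save_checkpoint` is made up):

```python
import os
import torch.distributed as dist

def should_save_checkpoint(output_dir, trigger_file="axolotl_checkpoint.save"):
    """Rank 0 polls for the trigger file; the decision is synced to all ranks."""
    flag = False
    if not dist.is_initialized() or dist.get_rank() == 0:
        path = os.path.join(output_dir, trigger_file)
        if os.path.exists(path):
            os.remove(path)  # delete so the trigger can be re-created later
            flag = True
    if dist.is_initialized():
        decision = [flag]
        dist.broadcast_object_list(decision, src=0)  # keep every rank in step
        flag = decision[0]
    return flag
```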
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
**Command line:**
|
|
||||||
```bash
|
|
||||||
touch /path/to/output_dir/axolotl_checkpoint.save
|
|
||||||
```
|
|
||||||
|
|
||||||
**Programmatic:**
|
|
||||||
```python
|
|
||||||
from pathlib import Path
|
|
||||||
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
|
|
||||||
```
|
|
||||||
|
|
||||||
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
|
|
||||||
|
|
||||||
**Custom filename:**
|
|
||||||
```yaml
|
|
||||||
dynamic_checkpoint:
|
|
||||||
enabled: true
|
|
||||||
trigger_file_path: "my_trigger.save"
|
|
||||||
```
|
|
||||||
```bash
|
|
||||||
touch /path/to/output_dir/my_trigger.save
|
|
||||||
```
|
|
||||||
|
|
||||||
## Control+C (SIGINT) Checkpoint
|
|
||||||
|
|
||||||
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
|
|
||||||
|
|
||||||
## Best Practices
|
|
||||||
|
|
||||||
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
|
|
||||||
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
|
|
||||||
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
|
|
||||||
|
|
||||||
## Example
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
output_dir: ./outputs/lora-out
|
|
||||||
save_steps: 500 # Scheduled checkpoints
|
|
||||||
|
|
||||||
dynamic_checkpoint:
|
|
||||||
enabled: true
|
|
||||||
check_interval: 50
|
|
||||||
```
|
|
||||||
|
|
||||||
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).
|
|
||||||
@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
 Configuration options:

 ```yaml
-lm_eval_model: # model to evaluate (local or hf path)
-
 # List of tasks to evaluate
 lm_eval_tasks:
   - arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
 output_dir: # Directory to save evaluation results
 ```

-See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
+See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.

 ### delinearize-llama4

|||||||
@@ -32,8 +32,11 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}

 Tags examples:

-- `main-base-py3.11-cu128-2.8.0`
+- `main-base-py3.11-cu128-2.7.1`
-- `main-base-py3.11-cu128-2.9.1`
+- `main-base-py3.11-cu126-2.7.1`
+- `main-base-py3.11-cu126-2.7.0`
+- `main-base-py3.11-cu126-2.6.0`
+- `main-base-py3.11-cu124-2.6.0`

 ## Main

@@ -71,12 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs

 Tags examples:

-- `main-py3.11-cu128-2.8.0`
+- `main-py3.11-cu128-2.7.1`
-- `main-py3.11-cu128-2.9.1`
+- `main-py3.11-cu126-2.7.1`
+- `main-py3.11-cu126-2.7.0`
+- `main-py3.11-cu126-2.6.0`
+- `main-py3.11-cu124-2.6.0`
 - `main-latest`
 - `main-20250303-py3.11-cu124-2.6.0`
 - `main-20250303-py3.11-cu126-2.6.0`
-- `0.12.0`
+- `0.10.1`

 ## Cloud

|||||||
@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
 :::

 ::: {.callout-important}
-For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
+For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
 :::

 ### PyPI Installation (Recommended) {#sec-pypi}
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
 :::

 ::: {.callout-important}
-For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
+For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
 :::

 Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
 ```
 4. (Optional) Login to Hugging Face:
 ```{.bash}
-hf auth login
+huggingface-cli login
 ```

 ## Troubleshooting {#sec-troubleshooting}
|
|||||||
@@ -89,10 +89,6 @@ lora_o_kernel: true
 Currently, LoRA kernels are not supported for RLHF training, only SFT.
 :::

-::: {.callout-warning}
-LoRA kernels do not support remote modeling code.
-:::
-
 ## Requirements

 - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|||||||
@@ -19,10 +19,8 @@ format:
 - [Gemma-3n](#sec-gemma-3n)
 - [Qwen2-VL](#sec-qwen2-vl)
 - [Qwen2.5-VL](#sec-qwen25-vl)
-- [GLM-4.6V](#sec-glm-4-6v)
 - [SmolVLM2](#sec-smolvlm2)
 - [LFM2-VL](#sec-lfm2-vl)
-- [Intern-VL](#sec-intern-vl)

 ## Usage

@@ -184,18 +182,6 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
 chat_template: qwen2_vl # same as qwen2-vl
 ```

-### GLM-4.6V {#sec-glm-4-6v}
-
-Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
-
-```yaml
-# GLM-4.6V (106B MoE version)
-base_model: zai-org/GLM-4.6V
-
-# OR GLM-4.6V-Flash (9B version)
-base_model: zai-org/GLM-4.6V-Flash
-```
-
 ### SmolVLM2 {#sec-smolvlm2}

 ::: {.callout-tip}
@@ -216,16 +202,6 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
 base_model: LiquidAI/LFM2-VL-450M
 ```

-### Intern-VL {#sec-intern-vl}
-
-::: {.callout-tip}
-Please make sure to install `timm` via `pip3 install timm==1.0.19`
-:::
-
-```yaml
-base_model: OpenGVLab/InternVL3_5-8B
-```
-
 ## Dataset Format

 For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
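As a rough illustration of that shape (field names follow the OpenAI convention and are only indicative; check the dataset docs for the exact keys Axolotl expects), one training sample might look like:

```python
# Illustrative only: an OpenAI-style multi-modal conversation record.
sample = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "example.jpg"},   # local image file
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "A short description of the image."}],
        },
    ]
}
```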
|||||||
@@ -17,7 +17,6 @@ feedback. Various methods include, but not limited to:
|
|||||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||||
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
||||||
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
|
|
||||||
|
|
||||||
|
|
||||||
## RLHF using Axolotl
|
## RLHF using Axolotl
|
||||||
@@ -721,102 +720,6 @@ trl:
|
|||||||
|
|
||||||
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
|
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
|
||||||
|
|
||||||
### GDPO
|
|
||||||
|
|
||||||
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results.
|
|
||||||
:::
|
|
||||||
|
|
||||||
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
|
|
||||||
|
|
||||||
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
|
||||||
|
|
||||||
vllm:
|
|
||||||
host: 0.0.0.0
|
|
||||||
port: 8000
|
|
||||||
tensor_parallel_size: 2
|
|
||||||
gpu_memory_utilization: 0.85
|
|
||||||
|
|
||||||
rl: gdpo
|
|
||||||
|
|
||||||
trl:
|
|
||||||
beta: 0.001
|
|
||||||
max_completion_length: 256
|
|
||||||
use_vllm: true
|
|
||||||
num_generations: 4
|
|
||||||
reward_funcs:
|
|
||||||
- rewards.format_reward
|
|
||||||
- rewards.correctness_reward
|
|
||||||
reward_weights: [1.0, 2.0]
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: openai/gsm8k
|
|
||||||
name: main
|
|
||||||
type: rewards.oai_gsm8k_transform
|
|
||||||
```
|
|
||||||
|
|
||||||
You can also use GRPO with explicit aggregation control:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
rl: grpo
|
|
||||||
trl:
|
|
||||||
multi_objective_aggregation: normalize_then_sum # GDPO behavior
|
|
||||||
# or: sum_then_normalize # Default GRPO behavior
|
|
||||||
```
|
|
||||||
|
|
||||||
#### GDPO vs GRPO
|
|
||||||
|
|
||||||
| Aspect | GRPO | GDPO |
|
|
||||||
|--------|------|------|
|
|
||||||
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
|
|
||||||
| **Multi-reward** | May collapse advantages | Preserves reward signals |
|
|
||||||
| **Single reward** | Standard behavior | Equivalent to GRPO |
|
|
||||||
|
|
||||||
#### Why GDPO?
|
|
||||||
|
|
||||||
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
|
|
||||||
|
|
||||||
```
|
|
||||||
# Example: format + correctness rewards
|
|
||||||
[format=0, correct=3] → sum=3
|
|
||||||
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
|
|
||||||
[format=2, correct=1] → sum=3
|
|
||||||
[format=3, correct=0] → sum=3
|
|
||||||
```
|
|
||||||
|
|
||||||
GDPO normalizes each reward independently, preserving their relative differences.
|
|
||||||
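The difference between the two aggregation orders can be sketched in a few lines of NumPy (illustrative only, not the TRL implementation; the reward values below are made up):

```python
import numpy as np

def sum_then_normalize(rewards, weights):
    # GRPO-style: combine weighted rewards first, then normalize the totals.
    total = (weights[:, None] * rewards).sum(axis=0)
    return (total - total.mean()) / (total.std() + 1e-8)

def normalize_then_sum(rewards, weights):
    # GDPO-style: normalize each reward across the group, then combine.
    normed = (rewards - rewards.mean(axis=1, keepdims=True)) / (
        rewards.std(axis=1, keepdims=True) + 1e-8
    )
    return (weights[:, None] * normed).sum(axis=0)

# Four completions: a 0/1 format reward and a 0-3 correctness reward.
rewards = np.array([[0.0, 1.0, 1.0, 1.0],    # format
                    [3.0, 2.0, 1.0, 0.0]])   # correctness
weights = np.array([1.0, 1.0])

print(sum_then_normalize(rewards, weights))   # first two completions tie
print(normalize_then_sum(rewards, weights))   # their advantages stay distinct
```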
|
|
||||||
#### Reward Functions
|
|
||||||
|
|
||||||
GDPO uses the same reward function format as GRPO:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# rewards.py
|
|
||||||
def format_reward(completions, **kwargs) -> list[float]:
|
|
||||||
return [1.0 if len(c) > 10 else 0.0 for c in completions]
|
|
||||||
|
|
||||||
def correctness_reward(completions, answers, **kwargs) -> list[float]:
|
|
||||||
rewards = []
|
|
||||||
for completion, answer in zip(completions, answers):
|
|
||||||
# Your scoring logic here
|
|
||||||
rewards.append(score)
|
|
||||||
return rewards
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Sequence Parallelism
|
|
||||||
|
|
||||||
GDPO supports sequence parallelism for long-context training:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
rl: gdpo
|
|
||||||
context_parallel_size: 2
|
|
||||||
```
|
|
||||||
|
|
||||||
### SimPO
|
### SimPO
|
||||||
|
|
||||||
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
||||||
|
|||||||
@@ -1,90 +0,0 @@
|
|||||||
examples:
|
|
||||||
# December 2025
|
|
||||||
- name: kimi-linear
|
|
||||||
title: Kimi Linear
|
|
||||||
- name: plano
|
|
||||||
title: Plano Orchestrator
|
|
||||||
- name: mimo
|
|
||||||
title: MiMo
|
|
||||||
- name: internvl3_5
|
|
||||||
title: InternVL 3.5
|
|
||||||
|
|
||||||
# AllenAI
|
|
||||||
- name: olmo3
|
|
||||||
title: OLMo 3
|
|
||||||
|
|
||||||
# ArceeAI
|
|
||||||
- name: trinity
|
|
||||||
title: Trinity
|
|
||||||
- name: arcee
|
|
||||||
title: Arcee AFM
|
|
||||||
|
|
||||||
# MistralAI
|
|
||||||
- name: ministral3/think
|
|
||||||
title: Ministral 3 Thinking
|
|
||||||
- name: ministral3/vision
|
|
||||||
title: Ministral 3 Vision
|
|
||||||
- name: magistral/think
|
|
||||||
title: Magistral Thinking
|
|
||||||
- name: magistral/vision
|
|
||||||
title: Magistral Vision
|
|
||||||
- name: ministral
|
|
||||||
title: Ministral
|
|
||||||
- name: mistral-small
|
|
||||||
title: Mistral Small 3.1/3.2
|
|
||||||
- name: voxtral
|
|
||||||
title: Voxtral
|
|
||||||
- name: devstral
|
|
||||||
title: Devstral
|
|
||||||
- name: mistral
|
|
||||||
title: Mistral 7B
|
|
||||||
|
|
||||||
# Meta
|
|
||||||
- name: llama-4
|
|
||||||
title: Llama 4
|
|
||||||
- name: llama-2
|
|
||||||
title: Llama 2
|
|
||||||
|
|
||||||
# Alibaba
|
|
||||||
- name: qwen3-next
|
|
||||||
title: Qwen 3 Next
|
|
||||||
- name: qwen3
|
|
||||||
title: Qwen 3
|
|
||||||
|
|
||||||
# Google
|
|
||||||
- name: gemma3n
|
|
||||||
title: Gemma 3n
|
|
||||||
|
|
||||||
# Swiss AI
|
|
||||||
- name: apertus
|
|
||||||
title: Apertus
|
|
||||||
|
|
||||||
# GPT-OSS
|
|
||||||
- name: gpt-oss
|
|
||||||
title: GPT-OSS
|
|
||||||
- name: seed-oss
|
|
||||||
title: Seed-OSS
|
|
||||||
|
|
||||||
# Microsoft
|
|
||||||
- name: phi
|
|
||||||
title: Phi
|
|
||||||
|
|
||||||
# SmolVLM
|
|
||||||
- name: smolvlm2
|
|
||||||
title: SmolVLM 2
|
|
||||||
|
|
||||||
# IBM
|
|
||||||
- name: granite4
|
|
||||||
title: Granite 4
|
|
||||||
|
|
||||||
# LiquidAI
|
|
||||||
- name: LiquidAI
|
|
||||||
title: Liquid Foundation Models 2
|
|
||||||
|
|
||||||
# Other
|
|
||||||
- name: hunyuan
|
|
||||||
title: Hunyuan
|
|
||||||
- name: jamba
|
|
||||||
title: Jamba
|
|
||||||
- name: orpheus
|
|
||||||
title: Orpheus
|
|
||||||
@@ -1,424 +0,0 @@
|
|||||||
"""
|
|
||||||
auto generate example docs from allowlist
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
# Paths
|
|
||||||
THIS = Path(__file__).resolve()
|
|
||||||
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
|
|
||||||
EXAMPLES_DIR = ROOT / "examples"
|
|
||||||
OUTPUT_DIR = ROOT / "docs" / "models"
|
|
||||||
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
|
|
||||||
|
|
||||||
|
|
||||||
def slugify(name: str) -> str:
|
|
||||||
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
|
|
||||||
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
|
|
||||||
s = re.sub(r"\s+", "-", s).strip("-").lower()
|
|
||||||
return s or "example"
|
|
||||||
|
|
||||||
|
|
||||||
def read_allowlist():
|
|
||||||
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
|
|
||||||
data = yaml.safe_load(f) or {}
|
|
||||||
items = data.get("examples", [])
|
|
||||||
if not isinstance(items, list):
|
|
||||||
raise ValueError("`examples` must be a list in examples-allowlist.yml")
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
def find_readme(folder: Path) -> Path | None:
|
|
||||||
for name in ("README.md", "Readme.md", "readme.md"):
|
|
||||||
p = folder / name
|
|
||||||
if p.exists():
|
|
||||||
return p
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def remove_first_h1(md: str) -> tuple[str, str | None]:
|
|
||||||
"""
|
|
||||||
Remove the first H1 from markdown and return (modified_md, h1_title).
|
|
||||||
The H1 is removed since we use the frontmatter title instead.
|
|
||||||
"""
|
|
||||||
lines = md.splitlines()
|
|
||||||
result = []
|
|
||||||
h1_title = None
|
|
||||||
skipped_first = False
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
if not skipped_first and line.startswith("# "):
|
|
||||||
h1_title = line[2:].strip()
|
|
||||||
skipped_first = True
|
|
||||||
continue
|
|
||||||
result.append(line)
|
|
||||||
|
|
||||||
return "\n".join(result), h1_title
|
|
||||||
|
|
||||||
|
|
||||||
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
|
|
||||||
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
|
|
||||||
|
|
||||||
|
|
||||||
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
|
|
||||||
"""
|
|
||||||
Copy local image assets referenced in markdown to
|
|
||||||
docs/examples/assets/... and rewrite the links.
|
|
||||||
"""
|
|
||||||
dest_assets = dest_assets_root / "assets"
|
|
||||||
|
|
||||||
def repl(m):
|
|
||||||
url = m.group(1).strip()
|
|
||||||
if re.match(r"^(https?:)?//", url):
|
|
||||||
return m.group(0) # leave remote URLs
|
|
||||||
src_path = (src_dir / url).resolve()
|
|
||||||
if not src_path.exists():
|
|
||||||
return m.group(0) # leave as-is if not found
|
|
||||||
rel = src_path.relative_to(src_dir)
|
|
||||||
# Create a unique asset path based on source directory name
|
|
||||||
asset_name = src_dir.name.replace("/", "-")
|
|
||||||
dest_path = dest_assets / asset_name / rel
|
|
||||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
shutil.copy2(src_path, dest_path)
|
|
||||||
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
|
|
||||||
return m.group(0).replace(url, new_rel)
|
|
||||||
|
|
||||||
return IMG_RE.sub(repl, md)
|
|
||||||
|
|
||||||
|
|
||||||
def rewrite_readme_links(
|
|
||||||
md: str,
|
|
||||||
src_dir: Path,
|
|
||||||
examples_dir: Path,
|
|
||||||
parent_index_only: set,
|
|
||||||
current_src_path: str,
|
|
||||||
allowlist_entries: set,
|
|
||||||
current_output_path: str,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Rewrite links between README.md files to point to the correct .qmd files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def repl(m):
|
|
||||||
text = m.group(1)
|
|
||||||
url = m.group(2).strip()
|
|
||||||
|
|
||||||
# Skip remote URLs and anchor links
|
|
||||||
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
|
|
||||||
return m.group(0)
|
|
||||||
|
|
||||||
# Skip non-markdown files
|
|
||||||
if not url.lower().endswith(".md"):
|
|
||||||
return m.group(0)
|
|
||||||
|
|
||||||
# Resolve the target path
|
|
||||||
try:
|
|
||||||
target_path = (src_dir / url).resolve()
|
|
||||||
|
|
||||||
# Check if target is outside examples_dir
|
|
||||||
try:
|
|
||||||
rel_path = target_path.relative_to(examples_dir)
|
|
||||||
except ValueError:
|
|
||||||
# Target is outside examples_dir, leave as-is
|
|
||||||
return m.group(0)
|
|
||||||
|
|
||||||
parts = list(rel_path.parts)
|
|
||||||
|
|
||||||
# Determine the output path for the target
|
|
||||||
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
|
|
||||||
# This is a README link
|
|
||||||
if len(parts) == 1:
|
|
||||||
# Link to root README -> index.qmd
|
|
||||||
target_output = "index.qmd"
|
|
||||||
elif len(parts) == 2:
|
|
||||||
if parts[0] == ".":
|
|
||||||
# Current directory README
|
|
||||||
target_output = "index.qmd"
|
|
||||||
else:
|
|
||||||
# subdir/README.md
|
|
||||||
parent_dir = parts[0]
|
|
||||||
if parent_dir in parent_index_only:
|
|
||||||
target_output = f"{parent_dir}/index.qmd"
|
|
||||||
else:
|
|
||||||
target_output = f"{parent_dir}.qmd"
|
|
||||||
else:
|
|
||||||
# Deeper nesting: parent/subdir/README.md
|
|
||||||
# Build the full path like "parent/subdir"
|
|
||||||
full_path = "/".join(parts[:-1]) # Remove README.md
|
|
||||||
# Check if this exact path is in allowlist
|
|
||||||
if full_path in allowlist_entries:
|
|
||||||
# This is a sub-entry with its own entry -> use .qmd
|
|
||||||
target_output = f"{full_path}.qmd"
|
|
||||||
elif parts[0] == ".":
|
|
||||||
# ./subdir/README.md -> check if subdir has own entry
|
|
||||||
subdir = parts[1]
|
|
||||||
if subdir in parent_index_only:
|
|
||||||
target_output = f"{subdir}/index.qmd"
|
|
||||||
else:
|
|
||||||
target_output = f"{subdir}.qmd"
|
|
||||||
else:
|
|
||||||
# parent/subdir where parent doesn't have own entry
|
|
||||||
target_output = f"{full_path}/index.qmd"
|
|
||||||
else:
|
|
||||||
# Regular .md file -> convert to .qmd, keep path structure
|
|
||||||
target_output = "/".join(parts)[:-2] + "qmd"
|
|
||||||
|
|
||||||
# Compute relative path from current output file to target
|
|
||||||
current_parts = current_output_path.split("/")
|
|
||||||
target_parts = target_output.split("/")
|
|
||||||
|
|
||||||
# Special case: if current is a subdir file and target is a single-component file at root
|
|
||||||
# Example: current="magistral/vision", target="magistral.qmd"
|
|
||||||
if len(current_parts) > 1 and len(target_parts) == 1:
|
|
||||||
# Current is in subdir, target is at root level
|
|
||||||
# Go up to root: ../ for each level
|
|
||||||
up_count = len(current_parts) - 1
|
|
||||||
rel_parts = [".."] * up_count + [target_parts[0]]
|
|
||||||
new_url = "/".join(rel_parts)
|
|
||||||
else:
|
|
||||||
# Find common prefix
|
|
||||||
i = 0
|
|
||||||
while (
|
|
||||||
i < min(len(current_parts) - 1, len(target_parts))
|
|
||||||
and current_parts[i] == target_parts[i]
|
|
||||||
):
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
# Build relative path: go up (../) then down to target
|
|
||||||
up_count = len(current_parts) - 1 - i
|
|
||||||
rel_parts = [".."] * up_count + target_parts[i:]
|
|
||||||
|
|
||||||
if not rel_parts or rel_parts == [".."]:
|
|
||||||
# Points to same directory or parent
|
|
||||||
new_url = "/".join(rel_parts) if rel_parts else "."
|
|
||||||
else:
|
|
||||||
new_url = "/".join(rel_parts)
|
|
||||||
|
|
||||||
return f"[{text}]({new_url})"
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
return m.group(0)
|
|
||||||
|
|
||||||
return LINK_RE.sub(repl, md)
|
|
||||||
|
|
||||||
|
|
||||||
def write_qmd(out_path: Path, title: str, body_md: str):
|
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
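    # The frontmatter written below looks like this for title="Gemma 3" (illustrative):
    #   ---
    #   title: 'Gemma 3'
    #   execute:
    #     eval: false
    #   format:
    #     html:
    #       toc: true
    #   ---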
|
|
||||||
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
|
||||||
out_path.write_text(fm + body_md, encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def update_quarto_yml(generated: list[tuple[str, str, str]]):
|
|
||||||
"""
|
|
||||||
Update _quarto.yml with the generated example files in the correct order.
|
|
||||||
This keeps the sidebar in sync with the allowlist.
|
|
||||||
|
|
||||||
Model Guides is now nested under the "Getting Started" section.
|
|
||||||
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
|
|
||||||
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
|
|
||||||
"""
|
|
||||||
quarto_yml = ROOT / "_quarto.yml"
|
|
||||||
if not quarto_yml.exists():
|
|
||||||
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
|
|
||||||
return
|
|
||||||
|
|
||||||
content = quarto_yml.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
# First pass: find all parents that have sub-entries
|
|
||||||
parents_with_subs = set()
|
|
||||||
for path, _name, _title in generated:
|
|
||||||
if "/" in path:
|
|
||||||
parent = path.split("/")[0]
|
|
||||||
parents_with_subs.add(parent)
|
|
||||||
|
|
||||||
# Build the YAML contents while preserving allowlist order
|
|
||||||
lines = []
|
|
||||||
processed_sections = set()
|
|
||||||
|
|
||||||
for path, _name, title in generated:
|
|
||||||
# Check if this is a parent page that has sub-pages
|
|
||||||
if path in parents_with_subs:
|
|
||||||
# This is a parent page with sub-pages - create a nested section
|
|
||||||
if path not in processed_sections:
|
|
||||||
processed_sections.add(path)
|
|
||||||
section_title = (
|
|
||||||
title or path.replace("-", " ").replace("_", " ").title()
|
|
||||||
)
|
|
||||||
lines.append(f' - section: "{section_title}"')
|
|
||||||
lines.append(" contents:")
|
|
||||||
# Add the parent page first
|
|
||||||
lines.append(f" - docs/models/{path}.qmd")
|
|
||||||
# Then add all sub-pages
|
|
||||||
for sub_path, _sub_name, _sub_title in generated:
|
|
||||||
if "/" in sub_path and sub_path.split("/")[0] == path:
|
|
||||||
lines.append(
|
|
||||||
f" - docs/models/{sub_path}.qmd"
|
|
||||||
)
|
|
||||||
elif "/" not in path:
|
|
||||||
# This is a flat item with no sub-pages
|
|
||||||
# Skip if it was already included as part of a parent section
|
|
||||||
if path not in processed_sections:
|
|
||||||
lines.append(f" - docs/models/{path}.qmd")
|
|
||||||
|
|
||||||
yaml_content = "\n".join(lines) + "\n"
|
|
||||||
|
|
||||||
# Pattern to match only the Model Guides contents, stopping at the next item
|
|
||||||
# in Getting Started (lines starting with 12 spaces: same level as the section)
|
|
||||||
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
|
|
||||||
|
|
||||||
def replacement(match):
|
|
||||||
prefix = match.group(1)
|
|
||||||
return prefix + "\n" + yaml_content
|
|
||||||
|
|
||||||
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
|
||||||
|
|
||||||
if new_content != content:
|
|
||||||
quarto_yml.write_text(new_content, encoding="utf-8")
|
|
||||||
print(f"Updated {quarto_yml}")
|
|
||||||
else:
|
|
||||||
print(f"No changes needed for {quarto_yml}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
allow = read_allowlist()
|
|
||||||
if not EXAMPLES_DIR.exists():
|
|
||||||
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
|
|
||||||
return
|
|
||||||
|
|
||||||
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# First pass: identify which parents have their own entry vs only sub-entries
|
|
||||||
parent_entries = set() # Parents that have their own entry
|
|
||||||
parent_with_subs = set() # Parents that have sub-entries
|
|
||||||
allowlist_entries = set() # All entries in allowlist
|
|
||||||
|
|
||||||
for item in allow:
|
|
||||||
if isinstance(item, str):
|
|
||||||
name = item
|
|
||||||
else:
|
|
||||||
name = item.get("name")
|
|
||||||
|
|
||||||
allowlist_entries.add(name)
|
|
||||||
|
|
||||||
if "/" in name:
|
|
||||||
parent = name.split("/")[0]
|
|
||||||
parent_with_subs.add(parent)
|
|
||||||
else:
|
|
||||||
parent_entries.add(name)
|
|
||||||
|
|
||||||
# Parents with subs that DON'T have their own entry -> use index.qmd
|
|
||||||
parent_index_only = parent_with_subs - parent_entries
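    # e.g. allowlist {"gemma3", "magistral/vision"} gives parent_entries={"gemma3"},
    # parent_with_subs={"magistral"}, and parent_index_only={"magistral"}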
|
|
||||||
|
|
||||||
generated = []
|
|
||||||
seen_dirs = set() # Track which parent directories we've created index for
|
|
||||||
|
|
||||||
for item in allow:
|
|
||||||
if isinstance(item, str):
|
|
||||||
name = item
|
|
||||||
title = None
|
|
||||||
else:
|
|
||||||
name = item.get("name")
|
|
||||||
title = item.get("title")
|
|
||||||
|
|
||||||
if not name:
|
|
||||||
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
|
|
||||||
continue
|
|
||||||
|
|
||||||
src_dir = EXAMPLES_DIR / name
|
|
||||||
if not src_dir.exists() or not src_dir.is_dir():
|
|
||||||
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
|
|
||||||
continue
|
|
||||||
|
|
||||||
readme = find_readme(src_dir)
|
|
||||||
if not readme:
|
|
||||||
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
|
|
||||||
continue
|
|
||||||
|
|
||||||
md = readme.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
# Determine output path first (needed for link rewriting)
|
|
||||||
parts = name.split("/")
|
|
||||||
if len(parts) == 1:
|
|
||||||
# Simple case: no subdirectory
|
|
||||||
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
|
|
||||||
sidebar_path = parts[0]
|
|
||||||
else:
|
|
||||||
# Has subdirectory: e.g., magistral/think
|
|
||||||
parent = parts[0]
|
|
||||||
child = "-".join(parts[1:]) # handle nested subdirs
|
|
||||||
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
|
|
||||||
sidebar_path = f"{parent}/{child}"
|
|
||||||
|
|
||||||
# Remove the first H1 (we use frontmatter title instead)
|
|
||||||
md, _ = remove_first_h1(md)
|
|
||||||
# Rewrite links between README files
|
|
||||||
md = rewrite_readme_links(
|
|
||||||
md,
|
|
||||||
src_dir,
|
|
||||||
EXAMPLES_DIR,
|
|
||||||
parent_index_only,
|
|
||||||
name,
|
|
||||||
allowlist_entries,
|
|
||||||
sidebar_path,
|
|
||||||
)
|
|
||||||
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
|
|
||||||
|
|
||||||
# Handle parent page generation for sub-entries
|
|
||||||
if len(parts) > 1:
|
|
||||||
# Has subdirectory: e.g., magistral/think
|
|
||||||
parent = parts[0]
|
|
||||||
|
|
||||||
# Create parent.qmd if not already done and parent doesn't have own entry
|
|
||||||
if parent not in seen_dirs and parent in parent_index_only:
|
|
||||||
parent_readme = find_readme(EXAMPLES_DIR / parent)
|
|
||||||
if parent_readme:
|
|
||||||
parent_md = parent_readme.read_text(encoding="utf-8")
|
|
||||||
parent_md, _ = remove_first_h1(parent_md)
|
|
||||||
parent_md = rewrite_readme_links(
|
|
||||||
parent_md,
|
|
||||||
EXAMPLES_DIR / parent,
|
|
||||||
EXAMPLES_DIR,
|
|
||||||
parent_index_only,
|
|
||||||
parent,
|
|
||||||
allowlist_entries,
|
|
||||||
parent,
|
|
||||||
)
|
|
||||||
parent_md = rewrite_and_copy_assets(
|
|
||||||
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
|
|
||||||
)
|
|
||||||
parent_title = parent.replace("-", " ").replace("_", " ").title()
|
|
||||||
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
|
|
||||||
generated.append((parent, parent, parent_title))
|
|
||||||
seen_dirs.add(parent)
|
|
||||||
|
|
||||||
if not title:
|
|
||||||
title = name.replace("/", " ").replace("-", " ").title()
|
|
||||||
|
|
||||||
write_qmd(out_path, title, md)
|
|
||||||
generated.append((sidebar_path, name, title))
|
|
||||||
|
|
||||||
# Index page - preserve allowlist order
|
|
||||||
if generated:
|
|
||||||
listing = "\n".join(
|
|
||||||
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
|
|
||||||
)
|
|
||||||
index_md = (
|
|
||||||
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
|
|
||||||
+ listing
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
index_fm = (
|
|
||||||
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
|
||||||
)
|
|
||||||
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
|
|
||||||
|
|
||||||
# Auto-update _quarto.yml to keep sidebar in sync
|
|
||||||
update_quarto_yml(generated)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
|||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,6 @@ gradient_checkpointing: true
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
scaling_softmax: true
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
loss_watchdog_threshold: 5.0
|
||||||
loss_watchdog_patience: 3
|
loss_watchdog_patience: 3
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
base_model: google/gemma-3-1b-it
|
|
||||||
|
|
||||||
model_type: Gemma3ForCausalLM
|
|
||||||
cls_model_config: Gemma3TextConfig
|
|
||||||
|
|
||||||
# gemma3 doesn't seem to play nice with ddp
|
|
||||||
ddp_find_unused_parameters: true
|
|
||||||
|
|
||||||
chat_template: gemma3
|
|
||||||
eot_tokens:
|
|
||||||
- <end_of_turn>
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0
|
|
||||||
output_dir: ./outputs/eaft-gemma-3-1b
|
|
||||||
|
|
||||||
use_eaft: true
|
|
||||||
eaft_alpha: 1.0
|
|
||||||
eaft_k: 20
|
|
||||||
|
|
||||||
sequence_len: 1024
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
adapter:
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
eval_batch_size: 1
|
|
||||||
max_steps: 1000
|
|
||||||
evaluation_strategy: "no"
|
|
||||||
optimizer: adamw_torch_fused
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 5e-5
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
weight_decay: 0.0
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
base_model: google/gemma-3-1b-it
|
base_model: google/gemma-3-1b-it
|
||||||
|
|
||||||
model_type: Gemma3ForCausalLM
|
model_type: Gemma3ForCausalLM
|
||||||
cls_model_config: Gemma3TextConfig
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
|
|||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0
|
lora_dropout: 0.05
|
||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
base_model: google/gemma-3-270m-it
|
base_model: google/gemma-3-270m-it
|
||||||
|
|
||||||
model_type: Gemma3ForCausalLM
|
model_type: Gemma3ForCausalLM
|
||||||
cls_model_config: Gemma3TextConfig
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
|
|||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0
|
lora_dropout: 0.05
|
||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ base_model: google/gemma-3-4b-it
|
|||||||
|
|
||||||
# Need to set else transformers tries to load vision too
|
# Need to set else transformers tries to load vision too
|
||||||
model_type: Gemma3ForCausalLM
|
model_type: Gemma3ForCausalLM
|
||||||
cls_model_config: Gemma3TextConfig
|
|
||||||
|
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
|
|
||||||
@@ -33,8 +32,8 @@ sample_packing: true
|
|||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0
|
lora_dropout: 0.05
|
||||||
lora_target_linear: true
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ pad_to_sequence_len: false
|
|||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0
|
lora_dropout: 0.05
|
||||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
|
|
||||||
|
|
||||||
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
|
||||||
|
|
||||||
3. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# QLoRA
|
|
||||||
# - no target experts (1x48GB @ ~24GiB/GPU)
|
|
||||||
# - target experts (1x48GB @ ~34GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/qlora.yaml
|
|
||||||
|
|
||||||
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# LoRA
|
|
||||||
# - no target experts (1x48GB @ ~35GiB/GPU)
|
|
||||||
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/lora.yaml
|
|
||||||
|
|
||||||
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### Expert LoRA
|
|
||||||
|
|
||||||
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
|
|
||||||
|
|
||||||
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
lora_target_parameters:
|
|
||||||
- mlp.experts.gate_up_proj
|
|
||||||
- mlp.experts.down_proj
|
|
||||||
# - mlp.gate.weight # router, untested but should work, not normally targeted
|
|
||||||
```
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
|
|
||||||
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
|
|
||||||
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2; otherwise loading hangs.
|
|
||||||
- **lora_target_linear**: Incompatible with this model.
|
|
||||||
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
|
|
||||||
|
|
||||||
|
|
||||||
## Tips
|
|
||||||
|
|
||||||
- For inference, the official Z.ai team recommends these default settings for most tasks (see the sketch after this list):
|
|
||||||
- `temperature: 1.0`
|
|
||||||
- `top_p: 0.95`
|
|
||||||
- `max_new_tokens: 131072`
|
|
||||||
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. Full finetuning is much more resource-intensive, so we have not tested it.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
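As a rough illustration of those sampling defaults, here is a minimal generation sketch using the Hugging Face `transformers` API. It is untested, and the loading details (dtype, device placement, whether `trust_remote_code` is required) are assumptions; adapt them to your setup.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "zai-org/GLM-4.7-Flash"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

# Placeholder prompt; build inputs with the model's chat template.
messages = [{"role": "user", "content": "Explain what a mixture-of-experts model is."}]
inputs = tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

# Recommended defaults from the tip above: temperature 1.0, top_p 0.95 (max_new_tokens up to 131072).
# max_new_tokens is kept small here for a quick smoke test.
out = model.generate(inputs, do_sample=True, temperature=1.0, top_p=0.95, max_new_tokens=1024)
print(tok.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True))
```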
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
|
|
||||||
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
offload_params: false
|
|
||||||
cpu_ram_efficient_loading: false
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
sharding_strategy: FULL_SHARD
|
|
||||||
reshard_after_forward: true
|
|
||||||
activation_checkpointing: true
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-qlora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
offload_params: false
|
|
||||||
cpu_ram_efficient_loading: false
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
sharding_strategy: FULL_SHARD
|
|
||||||
reshard_after_forward: true
|
|
||||||
activation_checkpointing: true
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
# Finetune GLM-4.6V with Axolotl
|
|
||||||
|
|
||||||
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
|
||||||
|
|
||||||
|
|
||||||
3. Run the fine-tuning:
|
|
||||||
|
|
||||||
GLM-4.6V-Flash (9B):
|
|
||||||
```bash
|
|
||||||
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
|
||||||
|
|
||||||
## Tips
|
|
||||||
|
|
||||||
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format); a rough sketch of a single sample is shown after this list.
|
|
||||||
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
|
||||||
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
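For orientation only, a single sample in that messages-style format might look roughly like the sketch below. The exact field names (`messages`, `content`, `type`, and how images are referenced) are assumptions made for illustration; treat the multimodal docs linked above as the authoritative schema.

```python
# Hypothetical multimodal chat sample; field names are illustrative, not authoritative.
sample = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "images/receipt_001.png"},  # assumed image reference key
                {"type": "text", "text": "What is the total amount on this receipt?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "The total is 42.17."}],
        },
    ]
}
```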
|
|
||||||
|
|
||||||
## Supported Models
|
|
||||||
|
|
||||||
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
|
|
||||||
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.6V-Flash
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates with images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
ddp_find_unused_parameters: true
|
|
||||||
|
|
||||||
output_dir: ./outputs/glm-4-6v-flash-qlora
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 16
|
|
||||||
lora_alpha: 32
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
logging_steps: 1
|
|
||||||
sdp_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 0
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.6V-Flash
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates with images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
output_dir: ./outputs/glm-4-6v-flash-qlora
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 16
|
|
||||||
lora_alpha: 32
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
logging_steps: 1
|
|
||||||
sdp_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 0
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -32,10 +32,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
trackio_project_name:
|
|
||||||
trackio_run_name:
|
|
||||||
trackio_space_id:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,10 +28,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
trackio_project_name:
|
|
||||||
trackio_run_name:
|
|
||||||
trackio_space_id:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -29,10 +29,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
trackio_project_name:
|
|
||||||
trackio_run_name:
|
|
||||||
trackio_space_id:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -28,10 +28,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
trackio_project_name:
|
|
||||||
trackio_run_name:
|
|
||||||
trackio_space_id:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,10 +41,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
trackio_project_name:
|
|
||||||
trackio_run_name:
|
|
||||||
trackio_space_id:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -41,10 +41,6 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
trackio_project_name:
|
|
||||||
trackio_run_name:
|
|
||||||
trackio_space_id:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
gradient_accumulation_steps: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
|
|||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
# Finetune OpenGVLab's InternVL with Axolotl
|
|
||||||
|
|
||||||
[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGVLab. It features a ViT-style vision encoder and a strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
2. Install `timm` for vision model support:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install timm==1.0.19
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
|
||||||
|
|
||||||
4. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
This config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
|
||||||
|
|
||||||
### Tips
|
|
||||||
|
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
|
||||||
- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [InternVL Paper](https://huggingface.co/papers/2508.18265)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
base_model: OpenGVLab/InternVL3_5-8B-HF
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates with images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
@@ -19,6 +19,7 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: jamba-large-fsdp-qlora-ft
|
output_dir: jamba-large-fsdp-qlora-ft
|
||||||
|
save_safetensors: true
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
# Finetune MoonshotAI's Kimi Linear with Axolotl
|
|
||||||
|
|
||||||
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
|
||||||
|
|
||||||
**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
|
||||||
|
|
||||||
3. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
This config uses about 98.7 GiB VRAM.
|
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning!
|
|
||||||
|
|
||||||
### TIPS
|
|
||||||
|
|
||||||
- Kimi Linear requires `trust_remote_code: true`.
|
|
||||||
- You can run a full finetuning by removing `adapter: lora` and `load_in_8bit: true` from the config.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)
|
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
This is not yet compatible with MoE kernels from transformers v5.
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
|
|
||||||
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
split: train
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.2
|
|
||||||
output_dir: ./outputs/lora-out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
lora_r: 16
|
|
||||||
lora_alpha: 32
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 2
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -29,6 +29,7 @@ flex_attention: true
|
|||||||
flex_attn_compile_kwargs:
|
flex_attn_compile_kwargs:
|
||||||
dynamic: false
|
dynamic: false
|
||||||
mode: max-autotune-no-cudagraphs
|
mode: max-autotune-no-cudagraphs
|
||||||
|
save_strategy: no
|
||||||
torch_compile: true
|
torch_compile: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
base_model: meta-llama/Llama-3.2-1B-Instruct
|
|
||||||
|
|
||||||
chat_template: llama3
|
|
||||||
|
|
||||||
rl: gdpo
|
|
||||||
|
|
||||||
trl:
|
|
||||||
beta: 0.001
|
|
||||||
max_completion_length: 128
|
|
||||||
num_generations: 2
|
|
||||||
temperature: 0.7
|
|
||||||
top_p: 0.95
|
|
||||||
|
|
||||||
use_vllm: false
|
|
||||||
|
|
||||||
|
|
||||||
multi_objective_aggregation: normalize_then_sum
|
|
||||||
|
|
||||||
reward_funcs:
|
|
||||||
- rwd.format_reward
|
|
||||||
- rwd.correctness_reward
|
|
||||||
reward_weights: [1.0, 2.0]
|
|
||||||
|
|
||||||
log_completions: true
|
|
||||||
num_completions_to_print: 3
|
|
||||||
scale_rewards: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: openai/gsm8k
|
|
||||||
name: main
|
|
||||||
split: train[:1000]
|
|
||||||
type: rwd.gsm8k_transform
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/llama3-gdpo-out
|
|
||||||
|
|
||||||
sequence_len: 512
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
max_steps: 100
|
|
||||||
|
|
||||||
optimizer: adamw_torch_fused
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 5e-5
|
|
||||||
weight_decay: 0.01
|
|
||||||
warmup_steps: 10
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
|
|
||||||
flash_attention: true
|
|
||||||
logging_steps: 1
|
|
||||||
save_steps: 50
|
|
||||||
save_safetensors: true
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
pad_token: "<|end_of_text|>"
|
|
||||||
|
|
||||||
|
|
||||||
seed: 42
|
|
||||||
@@ -12,6 +12,7 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out/qlora-llama3_1-405b
|
output_dir: ./outputs/out/qlora-llama3_1-405b
|
||||||
|
save_safetensors: true
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mist
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
|
||||||
- Installed Axolotl (see [main README](../README.md))
|
- Installed Axolotl (see [main README](../README.md))
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mist
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||||
- Installed Axolotl from source (see [main README](../README.md))
|
|
||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
|
|||||||
@@ -47,5 +47,6 @@ saves_per_epoch: 1
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
special_tokens:
|
special_tokens:
|
||||||
tokens:
|
tokens:
|
||||||
|
save_safetensors: False
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -1,39 +0,0 @@
|
|||||||
# Finetune Xiaomi's MiMo with Axolotl
|
|
||||||
|
|
||||||
[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
2. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/mimo/mimo-7b-qlora.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
This config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
|
||||||
|
|
||||||
### Tips
|
|
||||||
|
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [MiMo Paper](https://arxiv.org/abs/2505.07608)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
base_model: XiaomiMiMo/MiMo-7B-RL
|
|
||||||
trust_remote_code: true
|
|
||||||
revision_of_model: 6299b5a
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
# CCE - N/A as of now
|
|
||||||
# plugins:
|
|
||||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/lora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
@@ -59,7 +59,6 @@ gradient_checkpointing: true
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
scaling_softmax: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
|
||||||
- Installed Axolotl (see [main README](../README.md))
|
- Installed Axolotl (see [main README](../README.md))
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||||
- Installed Axolotl from source (see [main README](../README.md))
|
|
||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24
|
|||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Before starting, ensure you have:
|
Before starting, ensure you have:
|
||||||
|
|
||||||
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
@@ -16,7 +16,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
### TIPS
|
### TIPS
|
||||||
|
|
||||||
|
|||||||
@@ -42,10 +42,10 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: adamw_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
|||||||
@@ -1,42 +0,0 @@
|
|||||||
# Finetune Katanemo's Plano-Orchestrator with Axolotl
|
|
||||||
|
|
||||||
[Plano-Orchestrator](https://huggingface.co/collections/katanemo/plano-orchestrator) is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
|
||||||
|
|
||||||
3. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/plano/plano-4b-qlora.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
This config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
|
||||||
|
|
||||||
### Orchestration Prompt
|
|
||||||
|
|
||||||
Plano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the [official model card](https://huggingface.co/katanemo/Plano-Orchestrator-4B) for proper prompt formatting and the `ORCHESTRATION_PROMPT` template.
|
|
||||||
|
|
||||||
### Tips
|
|
||||||
|
|
||||||
- To use the larger [Plano-Orchestrator-30B-A3B](https://huggingface.co/katanemo/Plano-Orchestrator-30B-A3B) MoE model, simply change `base_model: katanemo/Plano-Orchestrator-30B-A3B` in the config and enable multi-GPU training if needed.
|
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [Plano GitHub](https://github.com/katanemo/plano)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
base_model: katanemo/Plano-Orchestrator-4B
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
chat_template: qwen3
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/lora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
base_model: Qwen/Qwen2.5-0.5B
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
# Use random initialization for fair comparison
|
|
||||||
reinit_weights: true
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# Pretraining dataset
|
|
||||||
pretraining_dataset:
|
|
||||||
- path: allenai/c4
|
|
||||||
name: en
|
|
||||||
type: pretrain
|
|
||||||
split: train
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/compare-adamw-pretrain
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project: dist_muon
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name: adamw
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 4
|
|
||||||
num_epochs: 1
|
|
||||||
max_steps: 305
|
|
||||||
|
|
||||||
# AdamW optimizer settings (standard LR for AdamW)
|
|
||||||
optimizer: adamw_torch_fused
|
|
||||||
learning_rate: 0.0002
|
|
||||||
weight_decay: 0.01
|
|
||||||
lr_scheduler: cosine
|
|
||||||
|
|
||||||
train_on_inputs: true
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: false
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 0
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
seed: 42
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_offload_params: false
|
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
fsdp_cpu_ram_efficient_loading: false
|
|
||||||
fsdp_reshard_after_forward: true
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
base_model: Qwen/Qwen2.5-0.5B
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
# Use random initialization for fair comparison
|
|
||||||
reinit_weights: true
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# Pretraining dataset
|
|
||||||
pretraining_dataset:
|
|
||||||
- path: allenai/c4
|
|
||||||
name: en
|
|
||||||
type: pretrain
|
|
||||||
split: train
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/compare-muon-pretrain
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project: dist_muon
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name: muon
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 4
|
|
||||||
num_epochs: 1
|
|
||||||
max_steps: 305
|
|
||||||
|
|
||||||
# Muon optimizer settings
|
|
||||||
optimizer: muon
|
|
||||||
learning_rate: 0.02
|
|
||||||
weight_decay: 0.01
|
|
||||||
lr_scheduler: cosine
|
|
||||||
|
|
||||||
train_on_inputs: true
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: false
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 0
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
seed: 42
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_offload_params: false
|
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
fsdp_cpu_ram_efficient_loading: false
|
|
||||||
fsdp_reshard_after_forward: true
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
@@ -6,13 +6,30 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install Qwen3-Next transformers commit
|
||||||
|
```bash
|
||||||
|
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||||
|
```
|
||||||
|
|
||||||
3. Install FLA for improved performance
|
3. Install FLA for improved performance
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Run the finetuning example:
|
4. Run the finetuning example:
|
||||||
@@ -21,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
|||||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
|
This config uses about 45.62 GiB VRAM.
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ plugins:
|
|||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
@@ -27,7 +25,7 @@ sample_packing: true
|
|||||||
|
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 8
|
lora_alpha: 8
|
||||||
lora_dropout: 0
|
lora_dropout: 0.05
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
- linear_attn.in_proj_ba
|
- linear_attn.in_proj_ba
|
||||||
- linear_attn.in_proj_qkvz
|
- linear_attn.in_proj_qkvz
|
||||||
@@ -36,19 +34,12 @@ lora_target_modules:
|
|||||||
- shared_expert.down_proj
|
- shared_expert.down_proj
|
||||||
- shared_expert.gate_proj
|
- shared_expert.gate_proj
|
||||||
- shared_expert_gate
|
- shared_expert_gate
|
||||||
|
- mlp.gate
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
- k_proj
|
- k_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
|
|||||||
@@ -1,285 +0,0 @@
|
|||||||
# SwanLab Integration Examples
|
|
||||||
|
|
||||||
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
|
|
||||||
|
|
||||||
## Examples Overview
|
|
||||||
|
|
||||||
### 1. DPO with Completion Logging
|
|
||||||
**File**: `dpo-swanlab-completions.yml`
|
|
||||||
|
|
||||||
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
|
|
||||||
|
|
||||||
**Features**:
|
|
||||||
- Basic SwanLab experiment tracking
|
|
||||||
- Completion table logging (prompts, chosen/rejected responses, rewards)
|
|
||||||
- Memory-bounded buffer for long training runs
|
|
||||||
- Cloud sync configuration
|
|
||||||
|
|
||||||
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
|
|
||||||
|
|
||||||
**Quick start**:
|
|
||||||
```bash
|
|
||||||
export SWANLAB_API_KEY=your-api-key
|
|
||||||
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 2. LoRA with Performance Profiling
|
|
||||||
**File**: `lora-swanlab-profiling.yml`
|
|
||||||
|
|
||||||
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
|
|
||||||
|
|
||||||
**Features**:
|
|
||||||
- SwanLab experiment tracking
|
|
||||||
- Automatic profiling of trainer methods
|
|
||||||
- Profiling metrics visualization
|
|
||||||
- Performance optimization guidance
|
|
||||||
|
|
||||||
**Best for**: Engineers optimizing training performance and comparing different configurations
|
|
||||||
|
|
||||||
**Quick start**:
|
|
||||||
```bash
|
|
||||||
export SWANLAB_API_KEY=your-api-key
|
|
||||||
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 3. Full-Featured DPO Production Setup
|
|
||||||
**File**: `dpo-swanlab-full-featured.yml`
|
|
||||||
|
|
||||||
Comprehensive production-ready configuration with ALL SwanLab features enabled.
|
|
||||||
|
|
||||||
**Features**:
|
|
||||||
- Experiment tracking with team workspace
|
|
||||||
- RLHF completion logging
|
|
||||||
- Performance profiling
|
|
||||||
- Lark (Feishu) team notifications
|
|
||||||
- Private deployment support
|
|
||||||
- Production checklist and troubleshooting
|
|
||||||
|
|
||||||
**Best for**: Production RLHF training with team collaboration
|
|
||||||
|
|
||||||
**Quick start**:
|
|
||||||
```bash
|
|
||||||
export SWANLAB_API_KEY=your-api-key
|
|
||||||
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
|
||||||
export SWANLAB_LARK_SECRET=your-webhook-secret
|
|
||||||
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 4. Custom Trainer Profiling (Python)
|
|
||||||
**File**: `custom_trainer_profiling.py`
|
|
||||||
|
|
||||||
Python code examples showing how to add SwanLab profiling to custom trainers.
|
|
||||||
|
|
||||||
**Features**:
|
|
||||||
- `@swanlab_profile` decorator examples
|
|
||||||
- Context manager profiling for fine-grained timing
|
|
||||||
- `ProfilingConfig` for advanced filtering and throttling
|
|
||||||
- Multiple profiling patterns and best practices
|
|
||||||
|
|
||||||
**Best for**: Advanced users creating custom trainers
|
|
||||||
|
|
||||||
**Usage**:
|
|
||||||
```python
|
|
||||||
from custom_trainer_profiling import CustomTrainerWithProfiling
|
|
||||||
# See file for detailed examples and patterns
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Feature Matrix
|
|
||||||
|
|
||||||
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|
|
||||||
|---------|----------|-------------------|-----------|-------------------|----------------|
|
|
||||||
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
|
||||||
| lora-swanlab-profiling.yml | ✅ | ➖ (disabled) | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
|
||||||
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
|
|
||||||
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Configuration Quick Reference
|
|
||||||
|
|
||||||
### Basic SwanLab Setup
|
|
||||||
```yaml
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
|
||||||
|
|
||||||
use_swanlab: true
|
|
||||||
swanlab_project: my-project
|
|
||||||
swanlab_experiment_name: my-experiment
|
|
||||||
swanlab_mode: cloud # cloud, local, offline, disabled
|
|
||||||
```
|
|
||||||
|
|
||||||
### RLHF Completion Logging
|
|
||||||
```yaml
|
|
||||||
swanlab_log_completions: true
|
|
||||||
swanlab_completion_log_interval: 100 # Log every 100 steps
|
|
||||||
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
|
|
||||||
```
|
|
||||||
|
|
||||||
### Lark Team Notifications
|
|
||||||
```yaml
|
|
||||||
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
|
||||||
swanlab_lark_secret: your-webhook-secret # Required for production
|
|
||||||
```
|
|
||||||
|
|
||||||
### Team Workspace
|
|
||||||
```yaml
|
|
||||||
swanlab_workspace: my-research-team
|
|
||||||
```
|
|
||||||
|
|
||||||
### Private Deployment
|
|
||||||
```yaml
|
|
||||||
swanlab_web_host: https://swanlab.yourcompany.com
|
|
||||||
swanlab_api_host: https://api.swanlab.yourcompany.com
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Authentication
|
|
||||||
|
|
||||||
### Recommended: Environment Variable
|
|
||||||
```bash
|
|
||||||
export SWANLAB_API_KEY=your-api-key
|
|
||||||
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
|
||||||
export SWANLAB_LARK_SECRET=your-webhook-secret
|
|
||||||
```
|
|
||||||
|
|
||||||
### Alternative: Config File (less secure)
|
|
||||||
```yaml
|
|
||||||
swanlab_api_key: your-api-key
|
|
||||||
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
|
||||||
swanlab_lark_secret: your-webhook-secret
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Common Use Cases
|
|
||||||
|
|
||||||
### Use Case 1: Migrate from WandB to SwanLab
|
|
||||||
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
|
|
||||||
```yaml
|
|
||||||
use_swanlab: true
|
|
||||||
use_wandb: false
|
|
||||||
```
|
|
||||||
|
|
||||||
### Use Case 2: Analyze DPO Model Outputs
|
|
||||||
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
|
|
||||||
```yaml
|
|
||||||
swanlab_completion_log_interval: 50 # More frequent for short training
|
|
||||||
swanlab_completion_log_interval: 200 # Less frequent for long training
|
|
||||||
```
|
|
||||||
|
|
||||||
### Use Case 3: Optimize Training Performance
|
|
||||||
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
|
|
||||||
- Baseline: `flash_attention: false, gradient_checkpointing: false`
|
|
||||||
- Flash Attention: `flash_attention: true`
|
|
||||||
- Gradient Checkpointing: `gradient_checkpointing: true`
|
|
||||||
- Both: `flash_attention: true, gradient_checkpointing: true`
|
|
||||||
|
|
||||||
Compare profiling metrics in SwanLab dashboard.
|
|
||||||
|
|
||||||
### Use Case 4: Production RLHF with Team Collaboration
|
|
||||||
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
|
|
||||||
```yaml
|
|
||||||
swanlab_workspace: ml-team
|
|
||||||
swanlab_lark_webhook_url: ...
|
|
||||||
swanlab_lark_secret: ...
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Viewing Your Experiments
|
|
||||||
|
|
||||||
### Cloud Mode
|
|
||||||
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
|
|
||||||
|
|
||||||
**Dashboard sections**:
|
|
||||||
- **Metrics**: Training loss, learning rate, profiling metrics
|
|
||||||
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
|
|
||||||
- **Config**: Hyperparameters and configuration
|
|
||||||
- **System**: Resource usage (GPU, memory, CPU)
|
|
||||||
- **Files**: Logged artifacts
|
|
||||||
|
|
||||||
### Local Mode
|
|
||||||
```bash
|
|
||||||
swanlab watch ./swanlog
|
|
||||||
# Open browser to http://localhost:5092
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### SwanLab not initializing
|
|
||||||
```bash
|
|
||||||
# Check API key
|
|
||||||
echo $SWANLAB_API_KEY
|
|
||||||
|
|
||||||
# Verify SwanLab is installed
|
|
||||||
pip show swanlab
|
|
||||||
|
|
||||||
# Check config
|
|
||||||
grep -A 5 "use_swanlab" your-config.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
### Completions not appearing
|
|
||||||
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
|
|
||||||
- Check `swanlab_log_completions: true`
|
|
||||||
- Wait for `swanlab_completion_log_interval` steps
|
|
||||||
- Look for "Registered SwanLab RLHF completion logging" in logs
|
|
||||||
|
|
||||||
### Lark notifications not working
|
|
||||||
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
|
|
||||||
- Verify `SWANLAB_LARK_SECRET` is set correctly
|
|
||||||
- Check bot is added to Lark group chat
|
|
||||||
- Look for "Registered Lark notification callback" in logs
|
|
||||||
|
|
||||||
### Profiling metrics not appearing
|
|
||||||
- Verify `use_swanlab: true`
|
|
||||||
- Check SwanLab is initialized (look for init log message)
|
|
||||||
- Profiling metrics are under "profiling/" namespace
|
|
||||||
- Profiling auto-enabled when SwanLab is enabled
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Performance Notes
|
|
||||||
|
|
||||||
### Overhead Comparison
|
|
||||||
|
|
||||||
| Feature | Overhead per Step | Memory Usage |
|
|
||||||
|---------|------------------|--------------|
|
|
||||||
| Basic tracking | < 0.1% | ~10 MB |
|
|
||||||
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
|
|
||||||
| Profiling | < 0.1% | ~1 KB |
|
|
||||||
| **Total** | **< 0.7%** | **~10 MB** |
|
|
||||||
|
|
||||||
### Best Practices
|
|
||||||
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
|
|
||||||
2. Adjust completion log interval based on training length (100-200 steps)
|
|
||||||
3. Keep completion buffer size reasonable (128-512)
|
|
||||||
4. Profile critical path methods first (training_step, compute_loss)
|
|
||||||
5. Use ProfilingConfig to throttle high-frequency operations
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Further Reading
|
|
||||||
|
|
||||||
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
|
|
||||||
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
|
|
||||||
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
|
|
||||||
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Contributing
|
|
||||||
|
|
||||||
Found an issue or have an improvement? Please submit a PR or open an issue:
|
|
||||||
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
|
|
||||||
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)
|
|
||||||
@@ -1,299 +0,0 @@
|
|||||||
"""Example: Custom Trainer with SwanLab Profiling
|
|
||||||
|
|
||||||
This example demonstrates how to add SwanLab profiling to your custom trainer.
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- @swanlab_profile decorator for automatic profiling
|
|
||||||
- swanlab_profiling_context for fine-grained profiling
|
|
||||||
- ProfilingConfig for advanced filtering and throttling
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
1. Create your custom trainer extending AxolotlTrainer
|
|
||||||
2. Add @swanlab_profile decorators to methods you want to profile
|
|
||||||
3. Use swanlab_profiling_context for fine-grained profiling within methods
|
|
||||||
4. Enable SwanLab in your config (use_swanlab: true)
|
|
||||||
|
|
||||||
See also:
|
|
||||||
- examples/swanlab/lora-swanlab-profiling.yml for config
|
|
||||||
- src/axolotl/integrations/swanlab/profiling.py for implementation
|
|
||||||
"""
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
from axolotl.integrations.swanlab.profiling import (
|
|
||||||
ProfilingConfig,
|
|
||||||
swanlab_profile,
|
|
||||||
swanlab_profiling_context,
|
|
||||||
swanlab_profiling_context_advanced,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomTrainerWithProfiling(AxolotlTrainer):
|
|
||||||
"""Custom trainer with SwanLab profiling enabled.
|
|
||||||
|
|
||||||
This trainer demonstrates three profiling patterns:
|
|
||||||
1. Decorator-based profiling (@swanlab_profile)
|
|
||||||
2. Context manager profiling (swanlab_profiling_context)
|
|
||||||
3. Advanced profiling with filtering (ProfilingConfig)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
# Create custom profiling config for high-frequency operations
|
|
||||||
self.fast_op_config = ProfilingConfig(
|
|
||||||
enabled=True,
|
|
||||||
min_duration_ms=0.5, # Only log if duration > 0.5ms
|
|
||||||
log_interval=50, # Log every 50th call
|
|
||||||
)
|
|
||||||
|
|
||||||
# ========================================================================
|
|
||||||
# Pattern 1: Decorator-based Profiling
|
|
||||||
# ========================================================================
|
|
||||||
# Best for: Methods you always want to profile
|
|
||||||
# Overhead: ~2-5 microseconds per call (negligible)
|
|
||||||
|
|
||||||
@swanlab_profile
|
|
||||||
def training_step(self, model, inputs):
|
|
||||||
"""Main training step - always profile.
|
|
||||||
|
|
||||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
|
|
||||||
"""
|
|
||||||
return super().training_step(model, inputs)
|
|
||||||
|
|
||||||
@swanlab_profile
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
|
||||||
"""Loss computation - always profile.
|
|
||||||
|
|
||||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
|
|
||||||
"""
|
|
||||||
return super().compute_loss(model, inputs, return_outputs)
|
|
||||||
|
|
||||||
@swanlab_profile
|
|
||||||
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
|
||||||
"""Prediction step - always profile.
|
|
||||||
|
|
||||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
|
|
||||||
"""
|
|
||||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
|
|
||||||
|
|
||||||
# ========================================================================
|
|
||||||
# Pattern 2: Fine-grained Context Manager Profiling
|
|
||||||
# ========================================================================
|
|
||||||
# Best for: Profiling specific code blocks within a method
|
|
||||||
# Use case: When you want to profile forward vs backward separately
|
|
||||||
|
|
||||||
def complex_training_step(self, model, inputs):
|
|
||||||
"""Training step with fine-grained profiling.
|
|
||||||
|
|
||||||
Profiling metrics:
|
|
||||||
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
|
||||||
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
|
||||||
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
|
|
||||||
"""
|
|
||||||
# Profile just the forward pass
|
|
||||||
with swanlab_profiling_context(self, "forward_pass"):
|
|
||||||
outputs = model(**inputs)
|
|
||||||
loss = outputs.loss
|
|
||||||
|
|
||||||
# Profile just the backward pass
|
|
||||||
with swanlab_profiling_context(self, "backward_pass"):
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# Profile optimizer step
|
|
||||||
with swanlab_profiling_context(self, "optimizer_step"):
|
|
||||||
self.optimizer.step()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
# ========================================================================
|
|
||||||
# Pattern 3: Advanced Profiling with Filtering
|
|
||||||
# ========================================================================
|
|
||||||
# Best for: High-frequency operations where you want to throttle logging
|
|
||||||
# Use case: Methods called 100+ times per step
|
|
||||||
|
|
||||||
def _prepare_inputs(self, inputs):
|
|
||||||
"""Prepare inputs - throttled profiling.
|
|
||||||
|
|
||||||
This method is called frequently (once per batch), so we throttle
|
|
||||||
profiling to reduce overhead:
|
|
||||||
- Only log if duration > 0.5ms (skip very fast operations)
|
|
||||||
- Only log every 50th call (reduce logging frequency)
|
|
||||||
|
|
||||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
|
|
||||||
"""
|
|
||||||
with swanlab_profiling_context_advanced(
|
|
||||||
self, "prepare_inputs", config=self.fast_op_config
|
|
||||||
):
|
|
||||||
return super()._prepare_inputs(inputs)
|
|
||||||
|
|
||||||
def _prepare_input_for_model(self, input_ids):
|
|
||||||
"""Another high-frequency operation - throttled profiling.
|
|
||||||
|
|
||||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
|
|
||||||
"""
|
|
||||||
with swanlab_profiling_context_advanced(
|
|
||||||
self, "prepare_input_for_model", config=self.fast_op_config
|
|
||||||
):
|
|
||||||
# Your custom input preparation logic
|
|
||||||
return input_ids
|
|
||||||
|
|
||||||
# ========================================================================
|
|
||||||
# Pattern 4: Exception-safe Profiling
|
|
||||||
# ========================================================================
|
|
||||||
# Profiling is exception-safe: duration is logged even if method raises
|
|
||||||
|
|
||||||
@swanlab_profile
|
|
||||||
def potentially_failing_method(self):
|
|
||||||
"""This method may raise an exception.
|
|
||||||
|
|
||||||
SwanLab profiling will still log the duration before re-raising.
|
|
||||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
|
|
||||||
"""
|
|
||||||
# Do some work
|
|
||||||
result = self._do_risky_computation()
|
|
||||||
|
|
||||||
# If this raises, profiling duration is still logged
|
|
||||||
if result < 0:
|
|
||||||
raise ValueError("Invalid result")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _do_risky_computation(self):
|
|
||||||
"""Placeholder for risky computation."""
|
|
||||||
return 42
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Advanced Example: Custom ProfilingConfig Per Method
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class AdvancedProfilingTrainer(AxolotlTrainer):
|
|
||||||
"""Trainer with method-specific profiling configurations."""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
# Different profiling configs for different method types
|
|
||||||
self.critical_path_config = ProfilingConfig(
|
|
||||||
enabled=True,
|
|
||||||
min_duration_ms=0.0, # Log everything on critical path
|
|
||||||
log_interval=1, # Log every call
|
|
||||||
)
|
|
||||||
|
|
||||||
self.fast_path_config = ProfilingConfig(
|
|
||||||
enabled=True,
|
|
||||||
min_duration_ms=1.0, # Only log if > 1ms
|
|
||||||
log_interval=100, # Log every 100th call
|
|
||||||
)
|
|
||||||
|
|
||||||
self.debug_config = ProfilingConfig(
|
|
||||||
enabled=True,
|
|
||||||
min_duration_ms=0.0, # Log everything
|
|
||||||
log_interval=1, # Log every call
|
|
||||||
)
|
|
||||||
|
|
||||||
def training_step(self, model, inputs):
|
|
||||||
"""Critical path - log everything."""
|
|
||||||
with swanlab_profiling_context_advanced(
|
|
||||||
self, "training_step", config=self.critical_path_config
|
|
||||||
):
|
|
||||||
return super().training_step(model, inputs)
|
|
||||||
|
|
||||||
def _prepare_inputs(self, inputs):
|
|
||||||
"""Fast path - throttle logging."""
|
|
||||||
with swanlab_profiling_context_advanced(
|
|
||||||
self, "prepare_inputs", config=self.fast_path_config
|
|
||||||
):
|
|
||||||
return super()._prepare_inputs(inputs)
|
|
||||||
|
|
||||||
def _debug_method(self, data):
|
|
||||||
"""Debug-only method - verbose logging."""
|
|
||||||
with swanlab_profiling_context_advanced(
|
|
||||||
self, "debug_method", config=self.debug_config
|
|
||||||
):
|
|
||||||
# Your debug logic
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# How to Use This Custom Trainer
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""
|
|
||||||
To use this custom trainer:
|
|
||||||
|
|
||||||
1. Save this file to your project (e.g., my_custom_trainer.py)
|
|
||||||
|
|
||||||
2. Create a config file that uses your custom trainer:
|
|
||||||
|
|
||||||
# config.yml
|
|
||||||
base_model: NousResearch/Llama-3.2-1B
|
|
||||||
|
|
||||||
# ... other config ...
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
|
||||||
|
|
||||||
use_swanlab: true
|
|
||||||
swanlab_project: my-profiling-experiment
|
|
||||||
|
|
||||||
# Optional: Specify custom trainer
|
|
||||||
# (Or modify axolotl to use your custom trainer class)
|
|
||||||
|
|
||||||
3. Run training:
|
|
||||||
|
|
||||||
export SWANLAB_API_KEY=your-api-key
|
|
||||||
accelerate launch -m axolotl.cli.train config.yml
|
|
||||||
|
|
||||||
4. View profiling metrics in SwanLab dashboard:
|
|
||||||
- profiling/Time taken: CustomTrainerWithProfiling.training_step
|
|
||||||
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
|
||||||
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
|
||||||
- etc.
|
|
||||||
|
|
||||||
5. Compare profiling metrics across runs:
|
|
||||||
- Run baseline without optimizations
|
|
||||||
- Run with flash_attention enabled
|
|
||||||
- Run with gradient_checkpointing enabled
|
|
||||||
- Compare profiling metrics to see performance impact
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tips for Effective Profiling
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""
|
|
||||||
1. Profile the critical path first:
|
|
||||||
- training_step, compute_loss, prediction_step
|
|
||||||
- These methods are called most frequently and have biggest impact
|
|
||||||
|
|
||||||
2. Use throttling for high-frequency operations:
|
|
||||||
- Methods called 100+ times per step
|
|
||||||
- Use log_interval=50 or log_interval=100
|
|
||||||
- Reduces profiling overhead and dashboard clutter
|
|
||||||
|
|
||||||
3. Filter noise with min_duration_ms:
|
|
||||||
- Set min_duration_ms=1.0 to skip very fast operations
|
|
||||||
- Focus on operations that actually take time
|
|
||||||
|
|
||||||
4. Compare across runs:
|
|
||||||
- Run same config multiple times to check consistency
|
|
||||||
- Compare different optimization strategies
|
|
||||||
- Track profiling trends over time
|
|
||||||
|
|
||||||
5. Monitor distributed training:
|
|
||||||
- Check for per-rank timing differences
|
|
||||||
- Look for stragglers (slower ranks)
|
|
||||||
- Identify synchronization bottlenecks
|
|
||||||
|
|
||||||
6. Disable profiling in production:
|
|
||||||
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
|
|
||||||
- DEFAULT_PROFILING_CONFIG.enabled = False
|
|
||||||
|
|
||||||
7. Exception handling:
|
|
||||||
- Profiling is exception-safe
|
|
||||||
- Duration logged even if method raises
|
|
||||||
- Useful for debugging methods that fail intermittently
|
|
||||||
"""
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
# SwanLab DPO Training Example with Completion Logging
|
|
||||||
#
|
|
||||||
# This example demonstrates DPO (Direct Preference Optimization) training
|
|
||||||
# with SwanLab integration for experiment tracking and completion table logging.
|
|
||||||
#
|
|
||||||
# Features enabled:
|
|
||||||
# - SwanLab experiment tracking
|
|
||||||
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
|
|
||||||
# - Lark (Feishu) team notifications (optional)
|
|
||||||
#
|
|
||||||
# To run:
|
|
||||||
# export SWANLAB_API_KEY=your-api-key
|
|
||||||
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
|
||||||
|
|
||||||
# Model Configuration
|
|
||||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|finetune_right_pad_id|>
|
|
||||||
eos_token: <|eot_id|>
|
|
||||||
|
|
||||||
# Quantization
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
|
|
||||||
# LoRA Configuration
|
|
||||||
adapter: lora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
# DPO Configuration
|
|
||||||
chat_template: llama3
|
|
||||||
rl: dpo
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
|
||||||
type: chat_template.default
|
|
||||||
field_messages: conversation
|
|
||||||
field_chosen: chosen
|
|
||||||
field_rejected: rejected
|
|
||||||
message_property_mappings:
|
|
||||||
role: role
|
|
||||||
content: content
|
|
||||||
roles:
|
|
||||||
system:
|
|
||||||
- system
|
|
||||||
user:
|
|
||||||
- user
|
|
||||||
assistant:
|
|
||||||
- assistant
|
|
||||||
|
|
||||||
# Dataset and Output
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./outputs/dpo-swanlab-out
|
|
||||||
|
|
||||||
# Training Configuration
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: false
|
|
||||||
micro_batch_size: 2
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
num_epochs: 4
|
|
||||||
|
|
||||||
# Optimization
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
weight_decay: 0.0
|
|
||||||
|
|
||||||
# Precision
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
# Performance
|
|
||||||
gradient_checkpointing: true
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
# Checkpointing and Logging
|
|
||||||
logging_steps: 1
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# SwanLab Integration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
|
||||||
|
|
||||||
# Basic SwanLab Configuration
|
|
||||||
use_swanlab: true
|
|
||||||
swanlab_project: dpo-training
|
|
||||||
swanlab_experiment_name: llama-3-dpo-completions-demo
|
|
||||||
swanlab_description: "DPO training with completion table logging"
|
|
||||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
|
||||||
|
|
||||||
# SwanLab Authentication
|
|
||||||
# Recommended: Set via environment variable
|
|
||||||
# export SWANLAB_API_KEY=your-api-key
|
|
||||||
# Or set in config (less secure):
|
|
||||||
# swanlab_api_key: your-api-key
|
|
||||||
|
|
||||||
# Optional: Team workspace
|
|
||||||
# swanlab_workspace: my-research-team
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# RLHF Completion Table Logging
|
|
||||||
# ============================================================================
|
|
||||||
#
|
|
||||||
# Automatically logs model completions to SwanLab for qualitative analysis:
|
|
||||||
# - Prompts from your DPO dataset
|
|
||||||
# - Chosen responses (preferred)
|
|
||||||
# - Rejected responses (non-preferred)
|
|
||||||
# - Reward differences
|
|
||||||
#
|
|
||||||
# View the table in SwanLab dashboard under "rlhf_completions"
|
|
||||||
|
|
||||||
swanlab_log_completions: true
|
|
||||||
swanlab_completion_log_interval: 100 # Log every 100 training steps
|
|
||||||
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
|
|
||||||
|
|
||||||
# Memory Usage Notes:
|
|
||||||
# - Buffer size 128: ~64 KB (default, recommended)
|
|
||||||
# - Buffer size 512: ~256 KB (for more historical completions)
|
|
||||||
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
|
|
||||||
|
|
||||||
# Performance Notes:
|
|
||||||
# - Completion logging overhead: < 0.5% per training step
|
|
||||||
# - Only logs every N steps to minimize impact
|
|
||||||
# - Memory-bounded buffer prevents memory leaks
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Optional: Lark (Feishu) Team Notifications
|
|
||||||
# ============================================================================
|
|
||||||
#
|
|
||||||
# Get real-time training notifications in your team chat
|
|
||||||
# Uncomment to enable:
|
|
||||||
|
|
||||||
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
|
||||||
# swanlab_lark_secret: your-webhook-secret # Recommended for production
|
|
||||||
|
|
||||||
# Notifications sent for:
|
|
||||||
# - Training start
|
|
||||||
# - Training completion
|
|
||||||
# - Training errors
|
|
||||||
# - Metric milestones (if configured)
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Optional: Private SwanLab Deployment
|
|
||||||
# ============================================================================
|
|
||||||
#
|
|
||||||
# For enterprise users with private SwanLab deployment:
|
|
||||||
|
|
||||||
# swanlab_web_host: https://swanlab.yourcompany.com
|
|
||||||
# swanlab_api_host: https://api.swanlab.yourcompany.com
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Disable WandB if you're migrating from it
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
# wandb_project:
|
|
||||||
# wandb_entity:
|
|
||||||
# use_wandb: false
|
|
||||||
@@ -1,329 +0,0 @@
|
|||||||
# SwanLab Full-Featured DPO Training Example
|
|
||||||
#
|
|
||||||
# This example demonstrates ALL SwanLab integration features:
|
|
||||||
# - Experiment tracking with cloud sync
|
|
||||||
# - RLHF completion table logging
|
|
||||||
# - Performance profiling
|
|
||||||
# - Lark (Feishu) team notifications
|
|
||||||
# - Team workspace collaboration
|
|
||||||
#
|
|
||||||
# Use this as a reference for production RLHF training setups.
|
|
||||||
#
|
|
||||||
# To run:
|
|
||||||
# export SWANLAB_API_KEY=your-api-key
|
|
||||||
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
|
||||||
# export SWANLAB_LARK_SECRET=your-webhook-secret
|
|
||||||
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Model Configuration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|finetune_right_pad_id|>
|
|
||||||
eos_token: <|eot_id|>
|
|
||||||
|
|
||||||
# Quantization for efficient training
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# LoRA Configuration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true # Target all linear layers
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# DPO (Direct Preference Optimization) Configuration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
chat_template: llama3
|
|
||||||
rl: dpo # Enable DPO trainer
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
|
||||||
type: chat_template.default
|
|
||||||
field_messages: conversation
|
|
||||||
field_chosen: chosen
|
|
||||||
field_rejected: rejected
|
|
||||||
message_property_mappings:
|
|
||||||
role: role
|
|
||||||
content: content
|
|
||||||
roles:
|
|
||||||
system:
|
|
||||||
- system
|
|
||||||
user:
|
|
||||||
- user
|
|
||||||
assistant:
|
|
||||||
- assistant
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Dataset and Output Configuration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./outputs/dpo-swanlab-full-featured-out
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Training Configuration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
micro_batch_size: 2
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
num_epochs: 4
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Optimization
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
weight_decay: 0.0
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Precision and Performance
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Checkpointing and Logging
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# SwanLab Integration - Full Configuration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
# Basic SwanLab Configuration
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
use_swanlab: true
|
|
||||||
swanlab_project: dpo-production
|
|
||||||
swanlab_experiment_name: llama-3-dpo-full-featured-v1
|
|
||||||
swanlab_description: |
|
|
||||||
Production DPO training with all SwanLab features enabled:
|
|
||||||
- Completion table logging for qualitative analysis
|
|
||||||
- Performance profiling for optimization
|
|
||||||
- Lark notifications for team collaboration
|
|
||||||
|
|
||||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
# Team Collaboration
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Workspace for team collaboration (shared experiments)
|
|
||||||
swanlab_workspace: ml-research-team
|
|
||||||
|
|
||||||
# Authentication (recommended: use environment variable)
|
|
||||||
# export SWANLAB_API_KEY=your-api-key
|
|
||||||
# Or set in config (less secure):
|
|
||||||
# swanlab_api_key: your-api-key
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
# RLHF Completion Table Logging
|
|
||||||
# ------------------------------------------------------------------------------
|
|
||||||
# Automatically logs model completions for qualitative analysis:
|
|
||||||
# - Prompts from your DPO dataset
|
|
||||||
# - Chosen responses (preferred)
|
|
||||||
# - Rejected responses (non-preferred)
|
|
||||||
# - Reward differences
|
|
||||||
#
|
|
||||||
# View in SwanLab dashboard under "rlhf_completions" table
|
|
||||||
|
|
||||||
swanlab_log_completions: true
|
|
||||||
swanlab_completion_log_interval: 100 # Log every 100 steps
|
|
||||||
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
|
|
||||||
|
|
||||||
# Buffer size recommendations:
|
|
||||||
# - 128: Default, ~64 KB memory (recommended for most cases)
|
|
||||||
# - 256: ~128 KB memory (this config, good for longer training)
|
|
||||||
# - 512: ~256 KB memory (maximum for very long runs)

# ------------------------------------------------------------------------------
# Lark (Feishu) Team Notifications
# ------------------------------------------------------------------------------
# Get real-time training notifications in your team chat
#
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)

# Recommended: Set via environment variables
#   export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
#   export SWANLAB_LARK_SECRET=your-webhook-secret

# Or set in config (less secure):
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret  # REQUIRED for production

# Security note: ALWAYS use swanlab_lark_secret in production to prevent
# unauthorized parties from sending fake notifications to your team chat.
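# Illustrative note (not part of the original config): Lark custom bots with
# signature verification enabled expect each message to carry a timestamp and
# a sign field, where sign = Base64(HMAC-SHA256(key = "<timestamp>\n<secret>",
# message = "")). When swanlab_lark_secret is set, the integration's HMAC
# authentication is expected to handle this signing for you.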

# ------------------------------------------------------------------------------
# Performance Profiling
# ------------------------------------------------------------------------------
# Profiling is automatically enabled when SwanLab is enabled.
# Metrics logged to SwanLab under "profiling/" namespace:
#   profiling/Time taken: AxolotlTrainer.training_step
#   profiling/Time taken: AxolotlTrainer.compute_loss
#   profiling/Time taken: AxolotlTrainer.prediction_step
#
# Use these metrics to:
# - Identify bottlenecks in training loop
# - Compare performance across different configurations
# - Monitor performance regressions over time
# - Debug unexpected slowdowns

# For custom profiling in your own trainer, see:
#   examples/swanlab/custom_trainer_profiling.py

# ------------------------------------------------------------------------------
# Optional: Private SwanLab Deployment
# ------------------------------------------------------------------------------
# For enterprise users with private SwanLab deployment:

# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com

# ------------------------------------------------------------------------------
# Optional: Model Checkpointing to SwanLab
# ------------------------------------------------------------------------------
# Log model checkpoints to SwanLab (coming soon)

swanlab_log_model: false

# ============================================================================
# Disable Other Logging Tools (Recommended)
# ============================================================================
# Using multiple logging tools simultaneously can impact performance:
# - Expected overhead: ~1-2% per logger
# - Potential config/callback conflicts
#
# For production training, use ONLY SwanLab:

# wandb_project:
# use_wandb: false
#
# use_mlflow: false
#
# use_comet: false

# ============================================================================
# Expected Training Behavior
# ============================================================================

# With this configuration, you should see:
#
# 1. SwanLab Initialization (rank 0 only):
#    INFO: SwanLab initialized for project: dpo-production
#    INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
#    INFO: SwanLab mode: cloud
#    INFO: SwanLab workspace: ml-research-team
#
# 2. Completion Logging (rank 0 only):
#    INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
#          (log_interval=100, max_buffer=256)
#
# 3. Lark Notifications (rank 0 only):
#    INFO: Registered Lark notification callback with HMAC authentication
#
# 4. Distributed Training Detection (if multi-GPU):
#    INFO: Distributed training detected (world_size=N)
#    INFO: Only rank 0 will initialize SwanLab
#    INFO: Other ranks will skip SwanLab to avoid conflicts
#
# 5. Training Start Notification (Lark):
#    Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
#
# 6. Periodic Completion Logging:
#    Every 100 steps, the completion table is updated in the SwanLab dashboard
#
# 7. Training Complete Notification (Lark):
#    Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
#    with a link to the SwanLab dashboard and final metrics
#
# 8. SwanLab Dashboard Shows:
#    - Training metrics (loss, learning rate, etc.)
#    - Completion table (rlhf_completions)
#    - Profiling metrics (profiling/Time taken: ...)
#    - Hyperparameters and configuration
#    - System resource usage

# ============================================================================
# Production Checklist
# ============================================================================

# Before deploying to production, verify:
# ✅ SwanLab API key is set via environment variable (not in config)
# ✅ Lark webhook secret is set (required for HMAC authentication)
# ✅ Workspace is set to your team's workspace
# ✅ Experiment name is descriptive and unique
# ✅ Only SwanLab is enabled (other loggers disabled)
# ✅ Completion logging buffer size is appropriate for your training duration
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
# ✅ Test run completes successfully and shows up in SwanLab dashboard
# ✅ Lark notifications are received in team chat
# ✅ Profiling metrics are logged correctly

# ============================================================================
# Troubleshooting
# ============================================================================

# If SwanLab initialization fails:
# 1. Check SWANLAB_API_KEY environment variable is set
# 2. Verify swanlab_project is set in config
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
# 4. Verify internet connectivity (for cloud mode)

# If Lark notifications not received:
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
#    (see the illustrative example below)
# 4. Check training logs for "Registered Lark notification callback"
# 5. Verify bot is added to the target Lark group chat
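#
# Illustrative manual test for an unsigned bot (the JSON body follows the
# standard Lark custom-bot text-message format; adjust the message text as needed):
#   curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" \
#     -H "Content-Type: application/json" \
#     -d '{"msg_type": "text", "content": {"text": "webhook test"}}'
# If signature verification is enabled on the bot, the request must also
# include "timestamp" and "sign" fields (see the signing note above).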

# If completions not appearing in SwanLab:
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
# 2. Check swanlab_log_completions is true
# 3. Wait for log_interval steps (default: 100)
# 4. Check training logs for "Registered SwanLab RLHF completion logging"

# If profiling metrics not appearing:
# 1. Verify use_swanlab is true
# 2. Check SwanLab is initialized (check logs)
# 3. Look under "profiling/" namespace in dashboard
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False

# For more help:
# - SwanLab docs: https://docs.swanlab.cn
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues

@@ -1,178 +0,0 @@
# SwanLab LoRA Training Example with Performance Profiling
#
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
# for performance profiling and optimization.
#
# Features enabled:
# - SwanLab experiment tracking
# - Performance profiling (training step, forward/backward pass timing)
# - Real-time metrics visualization
#
# To run:
#   export SWANLAB_API_KEY=your-api-key
#   accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml

# Model Configuration
base_model: NousResearch/Llama-3.2-1B

# Dataset Configuration
datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca

val_set_size: 0.1
output_dir: ./outputs/lora-swanlab-profiling-out

# LoRA Configuration
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Training Configuration
sequence_len: 2048
sample_packing: true
eval_sample_packing: true

micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1

# Optimization
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0

# Precision
bf16: auto
tf32: false

# Performance
gradient_checkpointing: true
flash_attention: true

# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1

# Loss Monitoring
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

special_tokens:
  pad_token: "<|end_of_text|>"

# ============================================================================
# SwanLab Integration
# ============================================================================

plugins:
  - axolotl.integrations.swanlab.SwanLabPlugin

# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: lora-profiling
swanlab_experiment_name: llama-3.2-1b-profiling-demo
swanlab_description: "LoRA fine-tuning with performance profiling"
swanlab_mode: cloud  # Options: cloud, local, offline, disabled

# SwanLab Authentication
# Recommended: Set via environment variable
#   export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key

# Optional: Team workspace
# swanlab_workspace: my-ml-team

# ============================================================================
# Performance Profiling
# ============================================================================
#
# SwanLab automatically profiles trainer methods when enabled.
# Profiling metrics appear in the SwanLab dashboard under the "profiling/" namespace.
#
# Built-in profiling:
# - Minimal overhead (< 0.1% per step)
# - High-precision timing (microsecond accuracy)
# - Exception-safe (logs duration even if a method fails)
#
# View profiling metrics in the SwanLab dashboard:
#   profiling/Time taken: AxolotlTrainer.training_step
#   profiling/Time taken: AxolotlTrainer.compute_loss
#   profiling/Time taken: AxolotlTrainer.prediction_step
#
# For custom profiling in your own trainer, see:
#   examples/swanlab/custom_trainer_profiling.py
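#
# Illustrative sketch of manual step timing with plain swanlab calls (the
# shipped example file may differ; MyTrainer and the metric name below are
# made up for illustration):
#   import time, swanlab
#   # inside MyTrainer.training_step:
#   start = time.perf_counter()
#   loss = super().training_step(model, inputs)
#   swanlab.log({"profiling/Time taken: MyTrainer.training_step": time.perf_counter() - start})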

# Completion logging is disabled for non-RLHF trainers
swanlab_log_completions: false  # Only works with DPO/KTO/ORPO/GRPO

# ============================================================================
# Optional: Compare with Multiple Runs
# ============================================================================
#
# To compare profiling metrics across different configurations:
#
# 1. Run baseline without flash attention:
#    swanlab_experiment_name: llama-3.2-1b-no-flash-attn
#    flash_attention: false
#
# 2. Run with gradient checkpointing:
#    swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
#    gradient_checkpointing: true
#
# 3. Run with both:
#    swanlab_experiment_name: llama-3.2-1b-optimized
#    flash_attention: true
#    gradient_checkpointing: true
#
# Then compare profiling metrics in the SwanLab dashboard to see the performance impact

# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get notified when profiling experiments complete:

# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret

# ============================================================================
# Profiling Best Practices
# ============================================================================
#
# 1. Run multiple epochs to see profiling trends over time
# 2. Ignore the first ~10 steps (warmup period, slower)
# 3. Look for outliers (steps that take significantly longer)
# 4. Compare profiling metrics before/after optimization changes
# 5. Monitor per-rank profiling in distributed training
#
# Common bottlenecks to profile:
# - training_step: Overall step time (should be consistent)
# - compute_loss: Loss computation (scales with sequence length)
# - prediction_step: Evaluation time (can be slow for large val sets)
#
# If you see inconsistent timing:
# - Check for data loading bottlenecks
# - Monitor GPU utilization (may be CPU-bound)
# - Check for gradient accumulation effects
# - Verify CUDA kernel synchronization
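#
# Illustrative knobs to try if data loading turns out to be the bottleneck
# (standard dataloader options, not specific to SwanLab; values are examples):
# dataloader_num_workers: 4
# dataloader_pin_memory: true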

# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================

# wandb_project:
# use_wandb: false
@@ -8,15 +8,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 1. Install Axolotl from `main` following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).

-2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
-
-3. Run the finetuning example:
+2. Run the finetuning example:

 ```bash
 axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
 ```

-This config uses about 24.9 GiB VRAM (w/o CCE).
+This config uses about 24.9 GiB VRAM.

 Let us know how it goes. Happy finetuning! 🚀

@@ -1,5 +1,5 @@
 base_model: arcee-ai/Trinity-Nano-Preview
-revision_of_model: 2ee94b0
+trust_remote_code: true

 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
Some files were not shown because too many files have changed in this diff.