Compare commits

...

37 Commits

Author SHA1 Message Date
Salman Mohammadi
0a0115493d remove monkeypatch 2026-01-21 10:06:58 +00:00
Salman Mohammadi
7a4f33802d try run regular CE loss on eval 2026-01-19 22:00:48 +00:00
Salman Mohammadi
170dca9bb9 WIP DFT 2026-01-15 18:43:31 +00:00
Wing Lian
d282f32481 don't install deepspeed in arm64 images (#3357) 2026-01-14 12:03:55 -05:00
Wing Lian
6331e4a130 fix amd64 and set 2.9.1 as latest cloud image (#3356) 2026-01-14 11:56:36 -05:00
salman
1410e4474e update PR template (#3349) [skip ci] 2026-01-14 09:39:21 -05:00
Wing Lian
dc77b5bf42 fix arm64 builds (#3355)
* fix syntax  for secrets in gha yaml

* setup env for uv too

* arm64 for base  uv too

* don't build causal-conv1d or mamba for arm64 and use arm64 wheels

* fix dockerfile syntax

* fix shell syntax
2026-01-14 09:38:48 -05:00
NanoCode012
359b7ad85e fix: gemma3_text model loading vision config (#3354)
* fix: gemma3-text mode loading vision config

* fix: improve defaults to use lora kernels
2026-01-13 09:49:23 -05:00
VED
258ce8d4fa feat : scaled softmax support (#3338)
* scaled softmax

* comment

* lint

* remove egear

* validation for flash

* lint

* val imporve + neet

* fix correct softmax scale val(learned)

* learned scale val 4 ssm

* lint

* fix model_type rmv

* sdpa_atten

* test fix + lint

* test fix

* sdp_a val rmv

* flex fix

* main flash

* lint

* flex attn

* lint comment

* fix score_mod

* Update src/axolotl/utils/schemas/validation.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-01-13 14:33:11 +07:00
@TT
3e0bbd33ec feat: add ARM64/AArch64 build support to Dockerfile-base (#3346)
* Add support for capability to build arm64 image

* Fixing wrong variable TARGETPLATFORM bug

* Adding missing semicolons

* skip docker hub login if PR (no push) or no credentials

* Enabling arm64 builds for Dockerfile-base in Github actions

* TARGETARCH automatically default to platform arch under build

* Enabling arm64 builds for axolotl docker builds

* Enabling arm64 builds for axolotl-cloud docker build Github actions

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-12 12:00:02 -05:00
salman
4ae6f766ad bump bnb to v0.49.1 (#3351) 2026-01-12 09:42:04 -05:00
VED
e7f0d4ba5b Increased test coverage for lora/qlora (#3147)
* config_val tests

* remove config val(not needed)

* config validation

* parameter freeze validation

* merge/unmerge tests

* removal unwanted

* rename

* lint

* updated lint

* Update tests/utils/lora/test_config_validation_lora.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* pytest skip + mock fix

* nitpicks

* revert some nitpicks

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-01-06 11:44:48 -05:00
VED
7bf6f70e96 fix total/trainable tokens log (#3344)
* fix total/trainable tokens log

* fix total/trainable tokens log
2026-01-06 09:25:17 -05:00
PraMamba
8aab807e67 feat: Add SwanLab integration for experiment tracking (#3334)
* feat(swanlab): add SwanLab integration for experiment tracking

SwanLab integration provides comprehensive experiment tracking and monitoring for Axolotl training.

Features:
- Hyperparameter logging
- Training metrics tracking
- RLHF completion logging
- Performance profiling
- Configuration validation and conflict detection

Includes:
- Plugin in src/axolotl/integrations/swanlab/
- Callback in src/axolotl/utils/callbacks/swanlab.py
- Tests in tests/integrations/test_swanlab.py
- Examples in examples/swanlab/

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(swanlab): address PR #3334 review feedback from winglian and CodeRabbit

- Change use_swanlab default to True (winglian)
- Clear buffer after periodic logging to prevent duplicates (CodeRabbit Major)
- Add safe exception handling in config fallback (CodeRabbit)
- Use context managers for file operations (CodeRabbit)
- Replace LOG.error with LOG.exception for better debugging (CodeRabbit)
- Sort __all__ alphabetically (CodeRabbit)
- Add language specifiers to README code blocks (CodeRabbit)
- Fix end-of-file newline in README (pre-commit)

Resolves actionable comments and nitpicks from CodeRabbit review.
Addresses reviewer feedback from @winglian.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* only run swanlab integration tests if package is available

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-06 09:19:18 -05:00
Wing Lian
ee59e4de97 add cu130 + torch 2.9.1 to test matrices (#3343)
* add cu130 + torch 2.9.1 to test matrices

* uv can't use pip3 directly
2026-01-05 15:24:29 -05:00
Wing Lian
4e61b8aa23 use updated version of prebuilt wheels for flash attention for cu130 (#3342)
* use updated version of prebuilt wheels for flash attention for cu130

* use elif

* fix the uv base installs of FA also

* make wget less verbose
2026-01-05 13:48:12 -05:00
Wing Lian
b26ba3a5cb don't build images w cuda 130 since we don't have flash attention wheels (#3341) 2026-01-03 18:08:28 -05:00
Wing Lian
afe18ace35 deprecate torch 2.7.1 (#3339) 2026-01-01 06:52:45 -05:00
github-actions[bot]
2b199f9915 chore: update pre-commit hooks (#3340) [skip ci]
Co-authored-by: SalmanMohammadi <25081738+SalmanMohammadi@users.noreply.github.com>
2026-01-01 06:52:28 -05:00
Wing Lian
e73dab6df9 support pydantic 2.12 (#3328)
* upgrade pydantic to 2.12

* use latest modal version

* upgrade modal

* update modal in requirements and loosen pydantic

* upgrade modal too
2025-12-30 12:41:07 -05:00
VED
f45a97a9ff docs for checkpiont saving (#3335) [skip ci]
Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-30 12:40:32 -05:00
Wing Lian
11c0b5b256 bartch upgrade dependencies (#3299)
* upgrade dependencies

* don't use reset sessions

* downgrade transformers, upgrade other deps

* upgrade bnb to 0.49.0

* restore s3 cache

* explicit use local files w hub

* decompress and strip top level dir

* use 2 levels for strip components

* try to preserve permissions for symlinks

* use updated tar

* fix #3293 for distributed

* downgrade bnb

* fast fail after 4

* fix total tokens device

* patch accelerate CP/SP (#3309)

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-12-30 09:02:49 -05:00
Wing Lian
66a3de3629 build examples readmes with quarto (#3046)
* build examples readmes with quarto

* chore: formatting

* feat: dynamic build docs

* feat: add more model guides

* chore: format

* fix: collapse sidebar completely to have space for model guides

* fix: security protection for generated qmd

* fix: adjust collapse level, add new models, update links

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-25 19:17:25 +07:00
VED
a6080df73c compute loss only if training and update token metric naming (#3293) [skip ci]
* compute loss only if training

* save total_tokens for checkpiont

* check if string

* refactor total_tokens/ num_tokens

* refactor 2

* rplc trainable_step/trian_per_sec_per_gpu

* lint + log trainable/tokens

* consolidate it in the callback.

* test for total_tokes aftr remuse

* check if tokenstate exist after ckpt

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-25 18:38:17 +07:00
NanoCode012
4f5e8a328a Feat: add MiMo and Plano (#3332) [skip-ci]
* feat: add xiaomi's mimo 7b

* fix: pin revision

* fix: update trinity docs and pin revision

* fix: wrong config name

* feat: add vram usage

* feat: add plano

* feat: update plano vram usage

* chore: comments
2025-12-25 18:09:03 +07:00
NanoCode012
418933f0d1 feat: add internvl3_5 (#3141) [skip-ci]
* feat: add internvl3_5

* fix: add timm instructions

* chore: add kimi-linear to cce doc

* feat: update internvl example

* chore: pin revision

* chore: remove from multipack

* fix: add to multimodal array

* fix: internvl use hf version

* feat: update cce

* chore: lint

* fix: list for image_size

* chore: add docs vram usage

* feat: enable cce

* fix: no need trust remote code

* fix: inconsistent timm version
2025-12-25 18:07:59 +07:00
NanoCode012
372f664c63 feat: cleanup old flex mask patch, suppress Matmul bnb warn, and misc (#3330) [skip-ci]
* feat: add pos id to flex attention for packing part 1

* feat: update to include sliding window mask patch

* fix: suppress MatMul8bitLt: inputs will be cast from warnings

* fix: remove redundant flex attention patch

* chore: update olmo docs

* feat: add validator patch for cross entropy
2025-12-25 17:56:20 +07:00
NanoCode012
97f1b1758d Feat: add kimi linear support (#3257)
* feat: add custom kimi linear patch [skip ci]

* feat: add configuration file and fix import [skip ci]

* fix: hijack tokenizer temporarily [skip ci]

* chore: remove accidental commit

* fix: attempt patch kimi remote

* fix: kwargs passsed

* fix: device for tensor

* fix: aux loss calculation

* feat: cleaned up patches order

* fix: remove duplicate tokenizer patch

* chore: add debug logs

* chore: add debug logs

* chore: debug

* Revert "chore: add debug logs"

This reverts commit da372a5f67.

* Revert "chore: add debug logs"

This reverts commit 97d1de1d7c.

* fix: KeyError: 'tokenization_kimi'

* fix: support remote_model_id in cce patch

* feat: add config preload patch

* fix: use standard aux loss calc and updated modeling

* fix: import

* feat: add kimi-linear docs and example

* chore: add note about moe kernels

* feat: update cce to include kimi-linear

* chore: lint

* chore: update main readme

* fix: patch mechanism to address comments

* chore: lint

* fix: tests

* chore: cleanup comment
2025-12-25 17:53:52 +07:00
Abubakar Abid
f2155eaf79 feat: add trackio as experiment tracking integration (#3253)
* feat: add trackio as experiment tracking integration

- Add TrackioConfig to integrations schema with project_name, run_name, and space_id
- Create trackio_.py module for environment setup
- Add is_trackio_available() utility function
- Integrate trackio with report_to in trainer builder
- Add trackio callback for experiment tracking
- Add trackio config keys to gpt-oss example YAMLs
- Trackio runs locally by default, syncs to HF Space if space_id provided

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* Update requirements.txt

* don't allow pydantic 2.12 for now

---------

Co-authored-by: Abubakar Abid <aaabid93@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-23 08:49:07 -05:00
kallewoof
92ee4256f7 feature: raise on long sequence drop (#3321)
* feature: raise on long sequence drop

It is sometimes not desired that sequences are silently dropped from the dataset, especially when the dataset has been carefully crafted and pre-fitted for the training context. This would then suggest that an error occurred somewhere in the process. This feature adds a third value for excess_length_strategy called 'raise', which will raise a ValueError if a sequence is encountered that is too long and would have normally been dropped/truncated.

* tests: add excess_length_strategy tests

* doc: updated return value description for drop_long_seq_in_dataset

* add @enable_hf_offline

* fixed cfg modified after validate_config called

* hf offline fix

* fix tqdm desc when raise is used

* test: added test for non-batched case

* accidental code change revert

* test: use pytest.raises

* test: simplified drop_seq_len tests

* test: moved excess_length_strat test to test_data.py

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-12-22 13:59:49 -05:00
Wing Lian
efeb5a4e41 fix check for fp8 capability (#3324)
* fix check for fp8 capability

* handle non-cuda compute

* reduce concurrency of tests
2025-12-22 13:58:25 -05:00
VED
faaff6c792 allow users to set ndigits for rounding of metrics when logging (#3325)
* METRIC_PRECISION-> 8

* use ndigits and move env getter to top of log function

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-22 08:54:43 -05:00
Alexander Kozhevnikov
43cef27458 Fix typo in densemixer RuntimeError (#3327) [skip ci]
It offers installing densemizer while it should be densemixer
2025-12-22 08:53:58 -05:00
Wing Lian
07c41a6c2a fix preview docs failing due to running out of disk (#3326) [skip ci]
* fix preview docs failing due to running out of disk

* fix docs publish too
2025-12-19 11:34:55 -05:00
salman
bbd3486f57 Distributed Muon Optimizer (#3264)
* init

* working

* updating configs

* removing unneeded files

* lint

* comments

* lint

* fix regex match

* bump contribs version

* comments

* fixing tests and imports

* muon imports in test v2

* test cleanup

* bump contribs version

---------

Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
2025-12-19 10:43:47 -05:00
VED
3750d7dd64 add liger support kernal for dpo (#3302)
* add liger kernal 4 dpo

* revert grpo changes,add support in dpo

* revert grpo changes,add support in dpo

* dpo_use_liger_kernal

* fix liger_dpo

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-18 11:11:06 -05:00
xzuyn
2197b0bf89 feat: cheap ppl metric (#3317)
* Import math and compute perplexity from loss values

* lint

* coderabbit changes

* lint

* fix: add rounding to ppl

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-18 09:02:41 -05:00
124 changed files with 10609 additions and 491 deletions

View File

@@ -15,6 +15,11 @@
<!--- Include details of your testing environment, tests ran to see how --> <!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. --> <!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate) ## Screenshots (if appropriate)
## Types of changes ## Types of changes

View File

@@ -21,31 +21,12 @@ jobs:
timeout-minutes: 480 timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -53,6 +34,15 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -60,6 +50,7 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -67,6 +58,7 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128" # - cuda: "128"
# cuda_version: 12.8.1 # cuda_version: 12.8.1
# cudnn_version: "" # cudnn_version: ""
@@ -93,6 +85,7 @@ jobs:
axolotlai/axolotl-base axolotlai/axolotl-base
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -103,6 +96,7 @@ jobs:
with: with:
context: . context: .
file: ./docker/${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
@@ -117,24 +111,12 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }} if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480 timeout-minutes: 480
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -142,6 +124,7 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -149,6 +132,15 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -156,6 +148,7 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -167,6 +160,7 @@ jobs:
axolotlai/axolotl-base-uv axolotlai/axolotl-base-uv
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -177,6 +171,7 @@ jobs:
with: with:
context: . context: .
file: ./docker/${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -12,6 +12,9 @@ jobs:
build-deploy: build-deploy:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository - name: Check out repository
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Quarto - name: Set up Quarto

View File

@@ -15,37 +15,31 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
is_latest: true platforms: "linux/amd64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -71,6 +65,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: ${{ matrix.platforms }}
build-args: | build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
@@ -92,43 +87,31 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
is_latest: true platforms: "linux/amd64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -153,6 +136,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: ${{ matrix.platforms }}
build-args: | build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
@@ -170,22 +154,16 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.9.1
axolotl_extras:
is_latest: true
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: axolotl_extras:
is_latest: is_latest:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
@@ -212,6 +190,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64
build-args: | build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}

View File

@@ -19,6 +19,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
jobs: jobs:
test-axolotl-multigpu: test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }} if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -26,13 +29,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
@@ -43,7 +39,14 @@ jobs:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.1
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: fbgemm-gpu axolotl_extras: fbgemm-gpu
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
@@ -59,7 +62,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 pip install modal==1.3.0.post1 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -72,4 +75,4 @@ jobs:
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.multigpu modal run -m cicd.multigpu

View File

@@ -12,16 +12,16 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -64,16 +64,16 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -11,6 +11,7 @@ on:
- '_quarto.yml' - '_quarto.yml'
- docs/scripts/generate_config_docs.py - docs/scripts/generate_config_docs.py
- src/axolotl/utils/schemas/**.py - src/axolotl/utils/schemas/**.py
- .github/workflows/preview-docs.yml
permissions: permissions:
checks: write checks: write
@@ -27,6 +28,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }} if: ${{ !github.event.pull_request.draft }}
steps: steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository - name: Check out repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2 max-parallel: 2
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0"] pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -99,17 +99,17 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126 - cuda: 128
cuda_version: 12.6.3 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.8.0
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
nightly_build: "true" nightly_build: "true"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.9.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
nightly_build: "true" nightly_build: "true"
@@ -123,7 +123,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 pip install modal==1.3.0.post1 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -148,10 +148,10 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126 - cuda: 128
cuda_version: 12.6.3 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.9.1
num_gpus: 2 num_gpus: 2
axolotl_extras: axolotl_extras:
nightly_build: "true" nightly_build: "true"
@@ -165,7 +165,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 pip install modal==1.3.0.post1 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"] pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -66,12 +66,13 @@ jobs:
- name: Check out repository code - name: Check out repository code
uses: actions/checkout@v4 uses: actions/checkout@v4
# - name: Restore Cache from S3 - name: Restore Cache from S3
# id: hf-cache-restore-s3 id: hf-cache-restore-s3
# run: | run: |
# mkdir -p ~/.cache/huggingface/hub mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
# ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
@@ -111,10 +112,13 @@ jobs:
run: | run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
run: hf cache scan
- name: Run tests - name: Run tests
run: | run: |
df -h df -h
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
df -h df -h
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
df -h df -h
@@ -122,6 +126,9 @@ jobs:
df -h df -h
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
run: hf cache scan
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v5
with: with:
@@ -138,7 +145,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"] pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -149,12 +156,13 @@ jobs:
- name: Check out repository code - name: Check out repository code
uses: actions/checkout@v4 uses: actions/checkout@v4
# - name: Restore Cache from S3 - name: Restore Cache from S3
# id: hf-cache-restore-s3 id: hf-cache-restore-s3
# run: | run: |
# mkdir -p ~/.cache/huggingface/hub mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
# ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
@@ -196,10 +204,13 @@ jobs:
- name: Run tests - name: Run tests
run: | run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ pytest -v --durations=10 tests/cli/
- name: Show HF cache
run: hf cache scan
gate-skip-e2e: gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist] needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -260,7 +271,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 pip install modal==1.3.0.post1 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -292,18 +303,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
# - cuda: 128
# cuda_version: 12.8.1
# python_version: "3.11"
# pytorch: 2.7.1
# num_gpus: 1
# axolotl_extras:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
@@ -314,7 +313,13 @@ jobs:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
steps: steps:
@@ -327,7 +332,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 pip install modal==1.3.0.post1 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -354,10 +359,10 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 126 - cuda: 128
cuda_version: 12.6.3 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.9.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
steps: steps:
@@ -370,7 +375,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2 pip install modal==1.3.0.post1 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch - id: no-commit-to-branch
args: ['--branch', 'main'] args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.7 rev: v0.14.10
hooks: hooks:
- id: ruff - id: ruff
args: [--fix] args: [--fix]
- id: ruff-format - id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.0 rev: v1.19.1
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: additional_dependencies:

View File

@@ -29,15 +29,15 @@
## 🎉 Latest Updates ## 🎉 Latest Updates
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3). - 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss). - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion). - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107). - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07: - 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info. - ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm). - Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)! - FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl! - [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl! - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
@@ -46,8 +46,8 @@
<summary>Expand older updates</summary> <summary>Expand older updates</summary>
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning. - 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl! - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version! - 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own! - 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try. - 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun! - 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
@@ -77,7 +77,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11 - Python 3.11
- PyTorch ≥2.7.1 - PyTorch ≥2.8.0
### Google Colab ### Google Colab

View File

@@ -1,6 +1,8 @@
project: project:
type: website type: website
pre-render: docs/scripts/generate_config_docs.py pre-render:
- docs/scripts/generate_config_docs.py
- docs/scripts/generate_examples_docs.py
quartodoc: quartodoc:
dir: docs/api dir: docs/api
@@ -240,6 +242,46 @@ website:
- docs/getting-started.qmd - docs/getting-started.qmd
- docs/installation.qmd - docs/installation.qmd
- docs/inference.qmd - docs/inference.qmd
- section: "Model Guides"
contents:
- docs/models/kimi-linear.qmd
- docs/models/plano.qmd
- docs/models/mimo.qmd
- docs/models/internvl3_5.qmd
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
- docs/models/ministral3/think.qmd
- docs/models/ministral3/vision.qmd
- section: "Magistral"
contents:
- docs/models/magistral.qmd
- docs/models/magistral/think.qmd
- docs/models/magistral/vision.qmd
- docs/models/ministral.qmd
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
- docs/models/qwen3.qmd
- docs/models/gemma3n.qmd
- docs/models/apertus.qmd
- docs/models/gpt-oss.qmd
- docs/models/seed-oss.qmd
- docs/models/phi.qmd
- docs/models/smolvlm2.qmd
- docs/models/granite4.qmd
- docs/models/LiquidAI.qmd
- docs/models/hunyuan.qmd
- docs/models/jamba.qmd
- docs/models/orpheus.qmd
- docs/cli.qmd - docs/cli.qmd
- docs/telemetry.qmd - docs/telemetry.qmd
- docs/config-reference.qmd - docs/config-reference.qmd

View File

@@ -2,7 +2,7 @@
set -e set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection) # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 \ pytest -v --durations=10 -n2 --maxfail=4 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \ /workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -6,6 +6,7 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS="" ARG AXOLOTL_ARGS=""
ARG CUDA="118" ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -20,13 +21,17 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$TARGETARCH" = "arm64" ]; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \ else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \ fi && \
python scripts/unsloth_install.py | sh && \ if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \ python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \ pip install pytest && \
pip cache purge pip cache purge

View File

@@ -2,14 +2,16 @@ ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8" ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04" ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4 ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10" ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="118" ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -22,11 +24,17 @@ RUN apt-get update \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \ librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \ && rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& wget \ && if [ "$TARGETARCH" = "amd64" ]; then \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ MINICONDA_ARCH="x86_64"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
MINICONDA_ARCH="aarch64"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi \
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \ && bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \ && rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \ && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
@@ -51,8 +59,34 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \ pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \ RUN case "$PYTORCH_VERSION" in \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ 2.9.[0-9]*) \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ if [ "$CUDA" = "128" ]; then \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ if [ "$TARGETARCH" = "amd64" ]; then \
fi WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
elif [ "$CUDA" = "130" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
fi \
;; \
esac

View File

@@ -2,6 +2,7 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION="" ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04" ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4 ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
@@ -31,12 +32,35 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \ && uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic && uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \ RUN if [ "$TARGETARCH" = "amd64" ]; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi fi
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \
;; \
esac

2
docs/.gitignore vendored
View File

@@ -3,3 +3,5 @@ _site/
/api/*.qmd /api/*.qmd
/api/*.html /api/*.html
config-reference.qmd config-reference.qmd
models/**/*.qmd
models/**/*.html

View File

@@ -0,0 +1,86 @@
---
title: "Checkpoint Saving"
format:
html:
toc: true
toc-depth: 2
number-sections: true
execute:
enabled: false
---
## Overview
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
## File-Based Checkpoint Trigger
### Configuration
Enable in your config:
```yaml
dynamic_checkpoint:
enabled: true
check_interval: 100 # Optional: check every N steps (default: 100)
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
```
**Options:**
- `enabled`: `true` to enable (required)
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
### How It Works
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
2. When detected, file is deleted and checkpoint is saved
3. In distributed training, rank 0 broadcasts to synchronize all ranks
### Usage
**Command line:**
```bash
touch /path/to/output_dir/axolotl_checkpoint.save
```
**Programmatic:**
```python
from pathlib import Path
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
```
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
**Custom filename:**
```yaml
dynamic_checkpoint:
enabled: true
trigger_file_path: "my_trigger.save"
```
```bash
touch /path/to/output_dir/my_trigger.save
```
## Control+C (SIGINT) Checkpoint
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
## Best Practices
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
## Example
```yaml
output_dir: ./outputs/lora-out
save_steps: 500 # Scheduled checkpoints
dynamic_checkpoint:
enabled: true
check_interval: 50
```
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).

View File

@@ -32,11 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples: Tags examples:
- `main-base-py3.11-cu128-2.7.1` - `main-base-py3.11-cu128-2.8.0`
- `main-base-py3.11-cu126-2.7.1` - `main-base-py3.11-cu128-2.9.1`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
## Main ## Main
@@ -74,15 +71,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples: Tags examples:
- `main-py3.11-cu128-2.7.1` - `main-py3.11-cu128-2.8.0`
- `main-py3.11-cu126-2.7.1` - `main-py3.11-cu128-2.9.1`
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0`
- `main-latest` - `main-latest`
- `main-20250303-py3.11-cu124-2.6.0` - `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu126-2.6.0` - `main-20250303-py3.11-cu126-2.6.0`
- `0.10.1` - `0.12.0`
## Cloud ## Cloud

View File

@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
::: :::
::: {.callout-important} ::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8. For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
::: :::
### PyPI Installation (Recommended) {#sec-pypi} ### PyPI Installation (Recommended) {#sec-pypi}
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
::: :::
::: {.callout-important} ::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`. For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
::: :::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available. Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.

View File

@@ -21,6 +21,7 @@ format:
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2) - [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl) - [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
## Usage ## Usage
@@ -202,6 +203,16 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
base_model: LiquidAI/LFM2-VL-450M base_model: LiquidAI/LFM2-VL-450M
``` ```
### Intern-VL {#sec-intern-vl}
::: {.callout-tip}
Please make sure to install `timm` via `pip3 install timm==1.0.19`
:::
```yaml
base_model: OpenGVLab/InternVL3_5-8B
```
## Dataset Format ## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format. For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.

View File

@@ -0,0 +1,90 @@
examples:
# December 2025
- name: kimi-linear
title: Kimi Linear
- name: plano
title: Plano Orchestrator
- name: mimo
title: MiMo
- name: internvl3_5
title: InternVL 3.5
# AllenAI
- name: olmo3
title: OLMo 3
# ArceeAI
- name: trinity
title: Trinity
- name: arcee
title: Arcee AFM
# MistralAI
- name: ministral3/think
title: Ministral 3 Thinking
- name: ministral3/vision
title: Ministral 3 Vision
- name: magistral/think
title: Magistral Thinking
- name: magistral/vision
title: Magistral Vision
- name: ministral
title: Ministral
- name: mistral-small
title: Mistral Small 3.1/3.2
- name: voxtral
title: Voxtral
- name: devstral
title: Devstral
- name: mistral
title: Mistral 7B
# Meta
- name: llama-4
title: Llama 4
- name: llama-2
title: Llama 2
# Alibaba
- name: qwen3-next
title: Qwen 3 Next
- name: qwen3
title: Qwen 3
# Google
- name: gemma3n
title: Gemma 3n
# Swiss AI
- name: apertus
title: Apertus
# GPT-OSS
- name: gpt-oss
title: GPT-OSS
- name: seed-oss
title: Seed-OSS
# Microsoft
- name: phi
title: Phi
# SmolVLM
- name: smolvlm2
title: SmolVLM 2
# IBM
- name: granite4
title: Granite 4
# LiquidAI
- name: LiquidAI
title: Liquid Foundation Models 2
# Other
- name: hunyuan
title: Hunyuan
- name: jamba
title: Jamba
- name: orpheus
title: Orpheus

View File

@@ -0,0 +1,424 @@
"""
auto generate example docs from allowlist
"""
import re
import shutil
import sys
from pathlib import Path
import yaml
# Paths
THIS = Path(__file__).resolve()
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
EXAMPLES_DIR = ROOT / "examples"
OUTPUT_DIR = ROOT / "docs" / "models"
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
def slugify(name: str) -> str:
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
s = re.sub(r"\s+", "-", s).strip("-").lower()
return s or "example"
def read_allowlist():
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
items = data.get("examples", [])
if not isinstance(items, list):
raise ValueError("`examples` must be a list in examples-allowlist.yml")
return items
def find_readme(folder: Path) -> Path | None:
for name in ("README.md", "Readme.md", "readme.md"):
p = folder / name
if p.exists():
return p
return None
def remove_first_h1(md: str) -> tuple[str, str | None]:
"""
Remove the first H1 from markdown and return (modified_md, h1_title).
The H1 is removed since we use the frontmatter title instead.
"""
lines = md.splitlines()
result = []
h1_title = None
skipped_first = False
for line in lines:
if not skipped_first and line.startswith("# "):
h1_title = line[2:].strip()
skipped_first = True
continue
result.append(line)
return "\n".join(result), h1_title
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
"""
Copy local image assets referenced in markdown to
docs/examples/assets/... and rewrite the links.
"""
dest_assets = dest_assets_root / "assets"
def repl(m):
url = m.group(1).strip()
if re.match(r"^(https?:)?//", url):
return m.group(0) # leave remote URLs
src_path = (src_dir / url).resolve()
if not src_path.exists():
return m.group(0) # leave as-is if not found
rel = src_path.relative_to(src_dir)
# Create a unique asset path based on source directory name
asset_name = src_dir.name.replace("/", "-")
dest_path = dest_assets / asset_name / rel
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
return m.group(0).replace(url, new_rel)
return IMG_RE.sub(repl, md)
def rewrite_readme_links(
md: str,
src_dir: Path,
examples_dir: Path,
parent_index_only: set,
current_src_path: str,
allowlist_entries: set,
current_output_path: str,
) -> str:
"""
Rewrite links between README.md files to point to the correct .qmd files.
"""
def repl(m):
text = m.group(1)
url = m.group(2).strip()
# Skip remote URLs and anchor links
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
return m.group(0)
# Skip non-markdown files
if not url.lower().endswith(".md"):
return m.group(0)
# Resolve the target path
try:
target_path = (src_dir / url).resolve()
# Check if target is outside examples_dir
try:
rel_path = target_path.relative_to(examples_dir)
except ValueError:
# Target is outside examples_dir, leave as-is
return m.group(0)
parts = list(rel_path.parts)
# Determine the output path for the target
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
# This is a README link
if len(parts) == 1:
# Link to root README -> index.qmd
target_output = "index.qmd"
elif len(parts) == 2:
if parts[0] == ".":
# Current directory README
target_output = "index.qmd"
else:
# subdir/README.md
parent_dir = parts[0]
if parent_dir in parent_index_only:
target_output = f"{parent_dir}/index.qmd"
else:
target_output = f"{parent_dir}.qmd"
else:
# Deeper nesting: parent/subdir/README.md
# Build the full path like "parent/subdir"
full_path = "/".join(parts[:-1]) # Remove README.md
# Check if this exact path is in allowlist
if full_path in allowlist_entries:
# This is a sub-entry with its own entry -> use .qmd
target_output = f"{full_path}.qmd"
elif parts[0] == ".":
# ./subdir/README.md -> check if subdir has own entry
subdir = parts[1]
if subdir in parent_index_only:
target_output = f"{subdir}/index.qmd"
else:
target_output = f"{subdir}.qmd"
else:
# parent/subdir where parent doesn't have own entry
target_output = f"{full_path}/index.qmd"
else:
# Regular .md file -> convert to .qmd, keep path structure
target_output = "/".join(parts)[:-2] + "qmd"
# Compute relative path from current output file to target
current_parts = current_output_path.split("/")
target_parts = target_output.split("/")
# Special case: if current is a subdir file and target is a single-component file at root
# Example: current="magistral/vision", target="magistral.qmd"
if len(current_parts) > 1 and len(target_parts) == 1:
# Current is in subdir, target is at root level
# Go up to root: ../ for each level
up_count = len(current_parts) - 1
rel_parts = [".."] * up_count + [target_parts[0]]
new_url = "/".join(rel_parts)
else:
# Find common prefix
i = 0
while (
i < min(len(current_parts) - 1, len(target_parts))
and current_parts[i] == target_parts[i]
):
i += 1
# Build relative path: go up (../) then down to target
up_count = len(current_parts) - 1 - i
rel_parts = [".."] * up_count + target_parts[i:]
if not rel_parts or rel_parts == [".."]:
# Points to same directory or parent
new_url = "/".join(rel_parts) if rel_parts else "."
else:
new_url = "/".join(rel_parts)
return f"[{text}]({new_url})"
except (ValueError, IndexError):
return m.group(0)
return LINK_RE.sub(repl, md)
def write_qmd(out_path: Path, title: str, body_md: str):
out_path.parent.mkdir(parents=True, exist_ok=True)
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
out_path.write_text(fm + body_md, encoding="utf-8")
def update_quarto_yml(generated: list[tuple[str, str, str]]):
"""
Update _quarto.yml with the generated example files in the correct order.
This keeps the sidebar in sync with the allowlist.
Model Guides is now nested under "Getting Started" section.
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
"""
quarto_yml = ROOT / "_quarto.yml"
if not quarto_yml.exists():
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
return
content = quarto_yml.read_text(encoding="utf-8")
# First pass: find all parents that have sub-entries
parents_with_subs = set()
for path, _name, _title in generated:
if "/" in path:
parent = path.split("/")[0]
parents_with_subs.add(parent)
# Build the YAML contents while preserving allowlist order
lines = []
processed_sections = set()
for path, _name, title in generated:
# Check if this is a parent page that has sub-pages
if path in parents_with_subs:
# This is a parent page with sub-pages - create a nested section
if path not in processed_sections:
processed_sections.add(path)
section_title = (
title or path.replace("-", " ").replace("_", " ").title()
)
lines.append(f' - section: "{section_title}"')
lines.append(" contents:")
# Add the parent page first
lines.append(f" - docs/models/{path}.qmd")
# Then add all sub-pages
for sub_path, _sub_name, _sub_title in generated:
if "/" in sub_path and sub_path.split("/")[0] == path:
lines.append(
f" - docs/models/{sub_path}.qmd"
)
elif "/" not in path:
# This is a flat item with no sub-pages
# Skip if it was already included as part of a parent section
if path not in processed_sections:
lines.append(f" - docs/models/{path}.qmd")
yaml_content = "\n".join(lines) + "\n"
# Pattern to match only the Model Guides contents, stopping at the next item
# in Getting Started (lines starting with 12 spaces: same level as the section)
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
def replacement(match):
prefix = match.group(1)
return prefix + "\n" + yaml_content
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
if new_content != content:
quarto_yml.write_text(new_content, encoding="utf-8")
print(f"Updated {quarto_yml}")
else:
print(f"No changes needed for {quarto_yml}")
def main():
allow = read_allowlist()
if not EXAMPLES_DIR.exists():
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
return
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
# First pass: identify which parents have their own entry vs only sub-entries
parent_entries = set() # Parents that have their own entry
parent_with_subs = set() # Parents that have sub-entries
allowlist_entries = set() # All entries in allowlist
for item in allow:
if isinstance(item, str):
name = item
else:
name = item.get("name")
allowlist_entries.add(name)
if "/" in name:
parent = name.split("/")[0]
parent_with_subs.add(parent)
else:
parent_entries.add(name)
# Parents with subs that DON'T have their own entry -> use index.qmd
parent_index_only = parent_with_subs - parent_entries
generated = []
seen_dirs = set() # Track which parent directories we've created index for
for item in allow:
if isinstance(item, str):
name = item
title = None
else:
name = item.get("name")
title = item.get("title")
if not name:
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
continue
src_dir = EXAMPLES_DIR / name
if not src_dir.exists() or not src_dir.is_dir():
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
continue
readme = find_readme(src_dir)
if not readme:
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
continue
md = readme.read_text(encoding="utf-8")
# Determine output path first (needed for link rewriting)
parts = name.split("/")
if len(parts) == 1:
# Simple case: no subdirectory
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
sidebar_path = parts[0]
else:
# Has subdirectory: e.g., magistral/think
parent = parts[0]
child = "-".join(parts[1:]) # handle nested subdirs
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
sidebar_path = f"{parent}/{child}"
# Remove the first H1 (we use frontmatter title instead)
md, _ = remove_first_h1(md)
# Rewrite links between README files
md = rewrite_readme_links(
md,
src_dir,
EXAMPLES_DIR,
parent_index_only,
name,
allowlist_entries,
sidebar_path,
)
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
# Handle parent page generation for sub-entries
if len(parts) > 1:
# Has subdirectory: e.g., magistral/think
parent = parts[0]
# Create parent.qmd if not already done and parent doesn't have own entry
if parent not in seen_dirs and parent in parent_index_only:
parent_readme = find_readme(EXAMPLES_DIR / parent)
if parent_readme:
parent_md = parent_readme.read_text(encoding="utf-8")
parent_md, _ = remove_first_h1(parent_md)
parent_md = rewrite_readme_links(
parent_md,
EXAMPLES_DIR / parent,
EXAMPLES_DIR,
parent_index_only,
parent,
allowlist_entries,
parent,
)
parent_md = rewrite_and_copy_assets(
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
)
parent_title = parent.replace("-", " ").replace("_", " ").title()
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
generated.append((parent, parent, parent_title))
seen_dirs.add(parent)
if not title:
title = name.replace("/", " ").replace("-", " ").title()
write_qmd(out_path, title, md)
generated.append((sidebar_path, name, title))
# Index page - preserve allowlist order
if generated:
listing = "\n".join(
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
)
index_md = (
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
+ listing
+ "\n"
)
index_fm = (
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
)
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
# Auto-update _quarto.yml to keep sidebar in sync
update_quarto_yml(generated)
if __name__ == "__main__":
main()

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
] ]
}, },
{ {

View File

@@ -52,6 +52,7 @@ gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true flash_attention: true
scaling_softmax: true
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -0,0 +1,53 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma-3-1b-fft-dft
sequence_len: 2048
use_dynamic_finetuning: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-1b-it base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-270m-it base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too # Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true load_in_4bit: true
@@ -32,8 +33,8 @@ sample_packing: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' lora_target_linear: true
wandb_project: wandb_project:
wandb_entity: wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project: wandb_project:

View File

@@ -32,6 +32,10 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1

View File

@@ -28,6 +28,10 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1

View File

@@ -29,6 +29,10 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1

View File

@@ -28,6 +28,10 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1

View File

@@ -41,6 +41,10 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1

View File

@@ -41,6 +41,10 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1

View File

@@ -0,0 +1,43 @@
# Finetune OpenGV's InternVL with Axolotl
[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGV. It features a ViT-style vision encoder and strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install `timm` for vision model support:
```bash
pip install timm==1.0.19
```
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
4. Run the finetuning example:
```bash
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
```
This config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### Tips
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [InternVL Paper](https://huggingface.co/papers/2508.18265)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,61 @@
base_model: OpenGVLab/InternVL3_5-8B-HF
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,47 @@
# Finetune MoonshotAI's Kimi Linear with Axolotl
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install CCE via [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
3. Run the finetuning example:
```bash
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
```
This config uses about 98.7GiB VRAM.
Let us know how it goes. Happy finetuning!
### TIPS
- Kimi Linear requires `trust_remote_code: true`.
- You can run a full finetuning by removing the `adapter: lora` and `load_in_8bit: true`.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)
## Optimization Guides
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
This is not yet compatible with MoE kernels from transformers v5.
## Related Resources
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,81 @@
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
trust_remote_code: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.2
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -29,7 +29,6 @@ flex_attention: true
flex_attn_compile_kwargs: flex_attn_compile_kwargs:
dynamic: false dynamic: false
mode: max-autotune-no-cudagraphs mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true torch_compile: true
wandb_project: wandb_project:

View File

@@ -5,6 +5,7 @@ This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mist
## Prerequisites ## Prerequisites
Before starting, ensure you have: Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md)) - Installed Axolotl (see [main README](../README.md))
## Getting Started ## Getting Started

View File

@@ -5,7 +5,8 @@ This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mist
## Prerequisites ## Prerequisites
Before starting, ensure you have: Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
- Installed Axolotl from source (see [main README](../README.md))
## Getting started ## Getting started

39
examples/mimo/README.md Normal file
View File

@@ -0,0 +1,39 @@
# Finetune Xiaomi's MiMo with Axolotl
[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Run the finetuning example:
```bash
axolotl train examples/mimo/mimo-7b-qlora.yaml
```
This config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### Tips
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.
## Related Resources
- [MiMo Paper](https://arxiv.org/abs/2505.07608)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,67 @@
base_model: XiaomiMiMo/MiMo-7B-RL
trust_remote_code: true
revision_of_model: 6299b5a
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# CCE - N/A as of now
# plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -59,6 +59,7 @@ gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true flash_attention: true
scaling_softmax: true
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -5,6 +5,7 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
## Prerequisites ## Prerequisites
Before starting, ensure you have: Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md)) - Installed Axolotl (see [main README](../README.md))
## Getting Started ## Getting Started

View File

@@ -5,7 +5,8 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
## Prerequisites ## Prerequisites
Before starting, ensure you have: Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
- Installed Axolotl from source (see [main README](../README.md))
## Getting started ## Getting started

View File

@@ -5,6 +5,7 @@ This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24
## Prerequisites ## Prerequisites
Before starting, ensure you have: Before starting, ensure you have:
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html)) - Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
## Getting Started ## Getting Started

View File

@@ -16,7 +16,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
axolotl train examples/olmo3/olmo3-7b-qlora.yaml axolotl train examples/olmo3/olmo3-7b-qlora.yaml
``` ```
Let us know how it goes. Happy finetuning! 🚀 This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### TIPS ### TIPS

View File

@@ -42,10 +42,10 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 2
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: adamw_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002

42
examples/plano/README.md Normal file
View File

@@ -0,0 +1,42 @@
# Finetune Katanemo's Plano-Orchestrator with Axolotl
[Plano-Orchestrator](https://huggingface.co/collections/katanemo/plano-orchestrator) is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/plano/plano-4b-qlora.yaml
```
This config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### Orchestration Prompt
Plano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the [official model card](https://huggingface.co/katanemo/Plano-Orchestrator-4B) for proper prompt formatting and the `ORCHESTRATION_PROMPT` template.
### Tips
- To use the larger [Plano-Orchestrator-30B-A3B](https://huggingface.co/katanemo/Plano-Orchestrator-30B-A3B) MoE model, simply change `base_model: katanemo/Plano-Orchestrator-30B-A3B` in the config and enable multi-GPU training if needed.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Plano GitHub](https://github.com/katanemo/plano)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,65 @@
base_model: katanemo/Plano-Orchestrator-4B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
chat_template: qwen3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,70 @@
base_model: Qwen/Qwen2.5-0.5B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Use random initialization for fair comparison
reinit_weights: true
load_in_8bit: false
load_in_4bit: false
strict: false
# Pretraining dataset
pretraining_dataset:
- path: allenai/c4
name: en
type: pretrain
split: train
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/compare-adamw-pretrain
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project: dist_muon
wandb_entity:
wandb_watch:
wandb_name: adamw
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 1
max_steps: 305
# AdamW optimizer settings (standard LR for AdamW)
optimizer: adamw_torch_fused
learning_rate: 0.0002
weight_decay: 0.01
lr_scheduler: cosine
train_on_inputs: true
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: false
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
# Reproducibility
seed: 42
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_reshard_after_forward: true
special_tokens:

View File

@@ -0,0 +1,70 @@
base_model: Qwen/Qwen2.5-0.5B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Use random initialization for fair comparison
reinit_weights: true
load_in_8bit: false
load_in_4bit: false
strict: false
# Pretraining dataset
pretraining_dataset:
- path: allenai/c4
name: en
type: pretrain
split: train
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/compare-muon-pretrain
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project: dist_muon
wandb_entity:
wandb_watch:
wandb_name: muon
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 1
max_steps: 305
# Muon optimizer settings
optimizer: muon
learning_rate: 0.02
weight_decay: 0.01
lr_scheduler: cosine
train_on_inputs: true
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: false
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
# Reproducibility
seed: 42
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_reshard_after_forward: true
special_tokens:

285
examples/swanlab/README.md Normal file
View File

@@ -0,0 +1,285 @@
# SwanLab Integration Examples
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
## Examples Overview
### 1. DPO with Completion Logging
**File**: `dpo-swanlab-completions.yml`
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
**Features**:
- Basic SwanLab experiment tracking
- Completion table logging (prompts, chosen/rejected responses, rewards)
- Memory-bounded buffer for long training runs
- Cloud sync configuration
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
```
---
### 2. LoRA with Performance Profiling
**File**: `lora-swanlab-profiling.yml`
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
**Features**:
- SwanLab experiment tracking
- Automatic profiling of trainer methods
- Profiling metrics visualization
- Performance optimization guidance
**Best for**: Engineers optimizing training performance and comparing different configurations
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
```
---
### 3. Full-Featured DPO Production Setup
**File**: `dpo-swanlab-full-featured.yml`
Comprehensive production-ready configuration with ALL SwanLab features enabled.
**Features**:
- Experiment tracking with team workspace
- RLHF completion logging
- Performance profiling
- Lark (Feishu) team notifications
- Private deployment support
- Production checklist and troubleshooting
**Best for**: Production RLHF training with team collaboration
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
```
---
### 4. Custom Trainer Profiling (Python)
**File**: `custom_trainer_profiling.py`
Python code examples showing how to add SwanLab profiling to custom trainers.
**Features**:
- `@swanlab_profile` decorator examples
- Context manager profiling for fine-grained timing
- `ProfilingConfig` for advanced filtering and throttling
- Multiple profiling patterns and best practices
**Best for**: Advanced users creating custom trainers
**Usage**:
```python
from custom_trainer_profiling import CustomTrainerWithProfiling
# See file for detailed examples and patterns
```
---
## Feature Matrix
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|---------|----------|-------------------|-----------|-------------------|----------------|
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | (commented) | (commented) |
| lora-swanlab-profiling.yml | ✅ | (disabled) | ✅ (auto) | (commented) | (commented) |
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
---
## Configuration Quick Reference
### Basic SwanLab Setup
```yaml
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # cloud, local, offline, disabled
```
### RLHF Completion Logging
```yaml
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
```
### Lark Team Notifications
```yaml
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret # Required for production
```
### Team Workspace
```yaml
swanlab_workspace: my-research-team
```
### Private Deployment
```yaml
swanlab_web_host: https://swanlab.yourcompany.com
swanlab_api_host: https://api.swanlab.yourcompany.com
```
---
## Authentication
### Recommended: Environment Variable
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
```
### Alternative: Config File (less secure)
```yaml
swanlab_api_key: your-api-key
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret
```
---
## Common Use Cases
### Use Case 1: Migrate from WandB to SwanLab
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
```yaml
use_swanlab: true
use_wandb: false
```
### Use Case 2: Analyze DPO Model Outputs
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
```yaml
swanlab_completion_log_interval: 50 # More frequent for short training
swanlab_completion_log_interval: 200 # Less frequent for long training
```
### Use Case 3: Optimize Training Performance
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
- Baseline: `flash_attention: false, gradient_checkpointing: false`
- Flash Attention: `flash_attention: true`
- Gradient Checkpointing: `gradient_checkpointing: true`
- Both: `flash_attention: true, gradient_checkpointing: true`
Compare profiling metrics in SwanLab dashboard.
### Use Case 4: Production RLHF with Team Collaboration
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
```yaml
swanlab_workspace: ml-team
swanlab_lark_webhook_url: ...
swanlab_lark_secret: ...
```
---
## Viewing Your Experiments
### Cloud Mode
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
**Dashboard sections**:
- **Metrics**: Training loss, learning rate, profiling metrics
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
- **Config**: Hyperparameters and configuration
- **System**: Resource usage (GPU, memory, CPU)
- **Files**: Logged artifacts
### Local Mode
```bash
swanlab watch ./swanlog
# Open browser to http://localhost:5092
```
---
## Troubleshooting
### SwanLab not initializing
```bash
# Check API key
echo $SWANLAB_API_KEY
# Verify SwanLab is installed
pip show swanlab
# Check config
grep -A 5 "use_swanlab" your-config.yml
```
### Completions not appearing
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
- Check `swanlab_log_completions: true`
- Wait for `swanlab_completion_log_interval` steps
- Look for "Registered SwanLab RLHF completion logging" in logs
### Lark notifications not working
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
- Verify `SWANLAB_LARK_SECRET` is set correctly
- Check bot is added to Lark group chat
- Look for "Registered Lark notification callback" in logs
### Profiling metrics not appearing
- Verify `use_swanlab: true`
- Check SwanLab is initialized (look for init log message)
- Profiling metrics are under "profiling/" namespace
- Profiling auto-enabled when SwanLab is enabled
---
## Performance Notes
### Overhead Comparison
| Feature | Overhead per Step | Memory Usage |
|---------|------------------|--------------|
| Basic tracking | < 0.1% | ~10 MB |
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
| Profiling | < 0.1% | ~1 KB |
| **Total** | **< 0.7%** | **~10 MB** |
### Best Practices
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
2. Adjust completion log interval based on training length (100-200 steps)
3. Keep completion buffer size reasonable (128-512)
4. Profile critical path methods first (training_step, compute_loss)
5. Use ProfilingConfig to throttle high-frequency operations
---
## Further Reading
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
---
## Contributing
Found an issue or have an improvement? Please submit a PR or open an issue:
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)

View File

@@ -0,0 +1,299 @@
"""Example: Custom Trainer with SwanLab Profiling
This example demonstrates how to add SwanLab profiling to your custom trainer.
Features:
- @swanlab_profile decorator for automatic profiling
- swanlab_profiling_context for fine-grained profiling
- ProfilingConfig for advanced filtering and throttling
Usage:
1. Create your custom trainer extending AxolotlTrainer
2. Add @swanlab_profile decorators to methods you want to profile
3. Use swanlab_profiling_context for fine-grained profiling within methods
4. Enable SwanLab in your config (use_swanlab: true)
See also:
- examples/swanlab/lora-swanlab-profiling.yml for config
- src/axolotl/integrations/swanlab/profiling.py for implementation
"""
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
ProfilingConfig,
swanlab_profile,
swanlab_profiling_context,
swanlab_profiling_context_advanced,
)
class CustomTrainerWithProfiling(AxolotlTrainer):
"""Custom trainer with SwanLab profiling enabled.
This trainer demonstrates three profiling patterns:
1. Decorator-based profiling (@swanlab_profile)
2. Context manager profiling (swanlab_profiling_context)
3. Advanced profiling with filtering (ProfilingConfig)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Create custom profiling config for high-frequency operations
self.fast_op_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.5, # Only log if duration > 0.5ms
log_interval=50, # Log every 50th call
)
# ========================================================================
# Pattern 1: Decorator-based Profiling
# ========================================================================
# Best for: Methods you always want to profile
# Overhead: ~2-5 microseconds per call (negligible)
@swanlab_profile
def training_step(self, model, inputs):
"""Main training step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
"""
return super().training_step(model, inputs)
@swanlab_profile
def compute_loss(self, model, inputs, return_outputs=False):
"""Loss computation - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
"""
return super().compute_loss(model, inputs, return_outputs)
@swanlab_profile
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
"""Prediction step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
"""
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# ========================================================================
# Pattern 2: Fine-grained Context Manager Profiling
# ========================================================================
# Best for: Profiling specific code blocks within a method
# Use case: When you want to profile forward vs backward separately
def complex_training_step(self, model, inputs):
"""Training step with fine-grained profiling.
Profiling metrics:
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
"""
# Profile just the forward pass
with swanlab_profiling_context(self, "forward_pass"):
outputs = model(**inputs)
loss = outputs.loss
# Profile just the backward pass
with swanlab_profiling_context(self, "backward_pass"):
loss.backward()
# Profile optimizer step
with swanlab_profiling_context(self, "optimizer_step"):
self.optimizer.step()
self.optimizer.zero_grad()
return outputs
# ========================================================================
# Pattern 3: Advanced Profiling with Filtering
# ========================================================================
# Best for: High-frequency operations where you want to throttle logging
# Use case: Methods called 100+ times per step
def _prepare_inputs(self, inputs):
"""Prepare inputs - throttled profiling.
This method is called frequently (once per batch), so we throttle
profiling to reduce overhead:
- Only log if duration > 0.5ms (skip very fast operations)
- Only log every 50th call (reduce logging frequency)
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
"""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_op_config
):
return super()._prepare_inputs(inputs)
def _prepare_input_for_model(self, input_ids):
"""Another high-frequency operation - throttled profiling.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
"""
with swanlab_profiling_context_advanced(
self, "prepare_input_for_model", config=self.fast_op_config
):
# Your custom input preparation logic
return input_ids
# ========================================================================
# Pattern 4: Exception-safe Profiling
# ========================================================================
# Profiling is exception-safe: duration is logged even if method raises
@swanlab_profile
def potentially_failing_method(self):
"""This method may raise an exception.
SwanLab profiling will still log the duration before re-raising.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
"""
# Do some work
result = self._do_risky_computation()
# If this raises, profiling duration is still logged
if result < 0:
raise ValueError("Invalid result")
return result
def _do_risky_computation(self):
"""Placeholder for risky computation."""
return 42
# ============================================================================
# Advanced Example: Custom ProfilingConfig Per Method
# ============================================================================
class AdvancedProfilingTrainer(AxolotlTrainer):
"""Trainer with method-specific profiling configurations."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Different profiling configs for different method types
self.critical_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything on critical path
log_interval=1, # Log every call
)
self.fast_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=1.0, # Only log if > 1ms
log_interval=100, # Log every 100th call
)
self.debug_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything
log_interval=1, # Log every call
)
def training_step(self, model, inputs):
"""Critical path - log everything."""
with swanlab_profiling_context_advanced(
self, "training_step", config=self.critical_path_config
):
return super().training_step(model, inputs)
def _prepare_inputs(self, inputs):
"""Fast path - throttle logging."""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_path_config
):
return super()._prepare_inputs(inputs)
def _debug_method(self, data):
"""Debug-only method - verbose logging."""
with swanlab_profiling_context_advanced(
self, "debug_method", config=self.debug_config
):
# Your debug logic
pass
# ============================================================================
# How to Use This Custom Trainer
# ============================================================================
"""
To use this custom trainer:
1. Save this file to your project (e.g., my_custom_trainer.py)
2. Create a config file that uses your custom trainer:
# config.yml
base_model: NousResearch/Llama-3.2-1B
# ... other config ...
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-profiling-experiment
# Optional: Specify custom trainer
# (Or modify axolotl to use your custom trainer class)
3. Run training:
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train config.yml
4. View profiling metrics in SwanLab dashboard:
- profiling/Time taken: CustomTrainerWithProfiling.training_step
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- etc.
5. Compare profiling metrics across runs:
- Run baseline without optimizations
- Run with flash_attention enabled
- Run with gradient_checkpointing enabled
- Compare profiling metrics to see performance impact
"""
# ============================================================================
# Tips for Effective Profiling
# ============================================================================
"""
1. Profile the critical path first:
- training_step, compute_loss, prediction_step
- These methods are called most frequently and have biggest impact
2. Use throttling for high-frequency operations:
- Methods called 100+ times per step
- Use log_interval=50 or log_interval=100
- Reduces profiling overhead and dashboard clutter
3. Filter noise with min_duration_ms:
- Set min_duration_ms=1.0 to skip very fast operations
- Focus on operations that actually take time
4. Compare across runs:
- Run same config multiple times to check consistency
- Compare different optimization strategies
- Track profiling trends over time
5. Monitor distributed training:
- Check for per-rank timing differences
- Look for stragglers (slower ranks)
- Identify synchronization bottlenecks
6. Disable profiling in production:
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
- DEFAULT_PROFILING_CONFIG.enabled = False
7. Exception handling:
- Profiling is exception-safe
- Duration logged even if method raises
- Useful for debugging methods that fail intermittently
"""

View File

@@ -0,0 +1,168 @@
# SwanLab DPO Training Example with Completion Logging
#
# This example demonstrates DPO (Direct Preference Optimization) training
# with SwanLab integration for experiment tracking and completion table logging.
#
# Features enabled:
# - SwanLab experiment tracking
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
# - Lark (Feishu) team notifications (optional)
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
# Model Configuration
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
# Quantization
load_in_8bit: true
load_in_4bit: false
# LoRA Configuration
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# DPO Configuration
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
# Dataset and Output
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-out
# Training Configuration
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# Optimization
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: dpo-training
swanlab_experiment_name: llama-3-dpo-completions-demo
swanlab_description: "DPO training with completion table logging"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-research-team
# ============================================================================
# RLHF Completion Table Logging
# ============================================================================
#
# Automatically logs model completions to SwanLab for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View the table in SwanLab dashboard under "rlhf_completions"
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 training steps
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
# Memory Usage Notes:
# - Buffer size 128: ~64 KB (default, recommended)
# - Buffer size 512: ~256 KB (for more historical completions)
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
# Performance Notes:
# - Completion logging overhead: < 0.5% per training step
# - Only logs every N steps to minimize impact
# - Memory-bounded buffer prevents memory leaks
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get real-time training notifications in your team chat
# Uncomment to enable:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # Recommended for production
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# ============================================================================
# Optional: Private SwanLab Deployment
# ============================================================================
#
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# wandb_entity:
# use_wandb: false

View File

@@ -0,0 +1,329 @@
# SwanLab Full-Featured DPO Training Example
#
# This example demonstrates ALL SwanLab integration features:
# - Experiment tracking with cloud sync
# - RLHF completion table logging
# - Performance profiling
# - Lark (Feishu) team notifications
# - Team workspace collaboration
#
# Use this as a reference for production RLHF training setups.
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
# ============================================================================
# Model Configuration
# ============================================================================
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
# Quantization for efficient training
load_in_8bit: true
load_in_4bit: false
# ============================================================================
# LoRA Configuration
# ============================================================================
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true # Target all linear layers
# ============================================================================
# DPO (Direct Preference Optimization) Configuration
# ============================================================================
chat_template: llama3
rl: dpo # Enable DPO trainer
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
# ============================================================================
# Dataset and Output Configuration
# ============================================================================
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-full-featured-out
# ============================================================================
# Training Configuration
# ============================================================================
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# ============================================================================
# Optimization
# ============================================================================
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# ============================================================================
# Precision and Performance
# ============================================================================
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true
# ============================================================================
# Checkpointing and Logging
# ============================================================================
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration - Full Configuration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# ------------------------------------------------------------------------------
# Basic SwanLab Configuration
# ------------------------------------------------------------------------------
use_swanlab: true
swanlab_project: dpo-production
swanlab_experiment_name: llama-3-dpo-full-featured-v1
swanlab_description: |
Production DPO training with all SwanLab features enabled:
- Completion table logging for qualitative analysis
- Performance profiling for optimization
- Lark notifications for team collaboration
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# ------------------------------------------------------------------------------
# Team Collaboration
# ------------------------------------------------------------------------------
# Workspace for team collaboration (shared experiments)
swanlab_workspace: ml-research-team
# Authentication (recommended: use environment variable)
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# ------------------------------------------------------------------------------
# RLHF Completion Table Logging
# ------------------------------------------------------------------------------
# Automatically logs model completions for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View in SwanLab dashboard under "rlhf_completions" table
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
# Buffer size recommendations:
# - 128: Default, ~64 KB memory (recommended for most cases)
# - 256: ~128 KB memory (this config, good for longer training)
# - 512: ~256 KB memory (maximum for very long runs)
# ------------------------------------------------------------------------------
# Lark (Feishu) Team Notifications
# ------------------------------------------------------------------------------
# Get real-time training notifications in your team chat
#
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# Recommended: Set via environment variables
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# Or set in config (less secure):
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
# unauthorized parties from sending fake notifications to your team chat.
# ------------------------------------------------------------------------------
# Performance Profiling
# ------------------------------------------------------------------------------
# Profiling is automatically enabled when SwanLab is enabled.
# Metrics logged to SwanLab under "profiling/" namespace:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# Use these metrics to:
# - Identify bottlenecks in training loop
# - Compare performance across different configurations
# - Monitor performance regressions over time
# - Debug unexpected slowdowns
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# ------------------------------------------------------------------------------
# Optional: Private SwanLab Deployment
# ------------------------------------------------------------------------------
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ------------------------------------------------------------------------------
# Optional: Model Checkpointing to SwanLab
# ------------------------------------------------------------------------------
# Log model checkpoints to SwanLab (coming soon)
swanlab_log_model: false
# ============================================================================
# Disable Other Logging Tools (Recommended)
# ============================================================================
# Using multiple logging tools simultaneously can impact performance:
# - Expected overhead: ~1-2% per logger
# - Potential config/callback conflicts
#
# For production training, use ONLY SwanLab:
# wandb_project:
# use_wandb: false
#
# use_mlflow: false
#
# use_comet: false
# ============================================================================
# Expected Training Behavior
# ============================================================================
# With this configuration, you should see:
#
# 1. SwanLab Initialization (rank 0 only):
# INFO: SwanLab initialized for project: dpo-production
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
# INFO: SwanLab mode: cloud
# INFO: SwanLab workspace: ml-research-team
#
# 2. Completion Logging (rank 0 only):
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
# (log_interval=100, max_buffer=256)
#
# 3. Lark Notifications (rank 0 only):
# INFO: Registered Lark notification callback with HMAC authentication
#
# 4. Distributed Training Detection (if multi-GPU):
# INFO: Distributed training detected (world_size=N)
# INFO: Only rank 0 will initialize SwanLab
# INFO: Other ranks will skip SwanLab to avoid conflicts
#
# 5. Training Start Notification (Lark):
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
#
# 6. Periodic Completion Logging:
# Every 100 steps, completion table is updated in SwanLab dashboard
#
# 7. Training Complete Notification (Lark):
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
# With link to SwanLab dashboard and final metrics
#
# 8. SwanLab Dashboard Shows:
# - Training metrics (loss, learning rate, etc.)
# - Completion table (rlhf_completions)
# - Profiling metrics (profiling/Time taken: ...)
# - Hyperparameters and configuration
# - System resource usage
# ============================================================================
# Production Checklist
# ============================================================================
# Before deploying to production, verify:
# ✅ SwanLab API key is set via environment variable (not in config)
# ✅ Lark webhook secret is set (required for HMAC authentication)
# ✅ Workspace is set to your team's workspace
# ✅ Experiment name is descriptive and unique
# ✅ Only SwanLab is enabled (other loggers disabled)
# ✅ Completion logging buffer size is appropriate for your training duration
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
# ✅ Test run completes successfully and shows up in SwanLab dashboard
# ✅ Lark notifications are received in team chat
# ✅ Profiling metrics are logged correctly
# ============================================================================
# Troubleshooting
# ============================================================================
# If SwanLab initialization fails:
# 1. Check SWANLAB_API_KEY environment variable is set
# 2. Verify swanlab_project is set in config
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
# 4. Verify internet connectivity (for cloud mode)
# If Lark notifications not received:
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
# 4. Check training logs for "Registered Lark notification callback"
# 5. Verify bot is added to the target Lark group chat
# If completions not appearing in SwanLab:
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
# 2. Check swanlab_log_completions is true
# 3. Wait for log_interval steps (default: 100)
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
# If profiling metrics not appearing:
# 1. Verify use_swanlab is true
# 2. Check SwanLab is initialized (check logs)
# 3. Look under "profiling/" namespace in dashboard
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
# For more help:
# - SwanLab docs: https://docs.swanlab.cn
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues

View File

@@ -0,0 +1,178 @@
# SwanLab LoRA Training Example with Performance Profiling
#
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
# for performance profiling and optimization.
#
# Features enabled:
# - SwanLab experiment tracking
# - Performance profiling (training step, forward/backward pass timing)
# - Real-time metrics visualization
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
# Model Configuration
base_model: NousResearch/Llama-3.2-1B
# Dataset Configuration
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.1
output_dir: ./outputs/lora-swanlab-profiling-out
# LoRA Configuration
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
# Training Configuration
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
# Optimization
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# Loss Monitoring
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
special_tokens:
pad_token: "<|end_of_text|>"
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: lora-profiling
swanlab_experiment_name: llama-3.2-1b-profiling-demo
swanlab_description: "LoRA fine-tuning with performance profiling"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-ml-team
# ============================================================================
# Performance Profiling
# ============================================================================
#
# SwanLab automatically profiles trainer methods when enabled.
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
#
# Built-in profiling:
# - Minimal overhead (< 0.1% per step)
# - High-precision timing (microsecond accuracy)
# - Exception-safe (logs duration even if method fails)
#
# View profiling metrics in SwanLab dashboard:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# Completion logging is disabled for non-RLHF trainers
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
# ============================================================================
# Optional: Compare with Multiple Runs
# ============================================================================
#
# To compare profiling metrics across different configurations:
#
# 1. Run baseline without flash attention:
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
# flash_attention: false
#
# 2. Run with gradient checkpointing:
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
# gradient_checkpointing: true
#
# 3. Run with both:
# swanlab_experiment_name: llama-3.2-1b-optimized
# flash_attention: true
# gradient_checkpointing: true
#
# Then compare profiling metrics in SwanLab dashboard to see performance impact
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get notified when profiling experiments complete:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret
# ============================================================================
# Profiling Best Practices
# ============================================================================
#
# 1. Run multiple epochs to see profiling trends over time
# 2. Ignore first ~10 steps (warmup period, slower)
# 3. Look for outliers (steps that take significantly longer)
# 4. Compare profiling metrics before/after optimization changes
# 5. Monitor per-rank profiling in distributed training
#
# Common bottlenecks to profile:
# - training_step: Overall step time (should be consistent)
# - compute_loss: Loss computation (scales with sequence length)
# - prediction_step: Evaluation time (can be slow for large val sets)
#
# If you see inconsistent timing:
# - Check for data loading bottlenecks
# - Monitor GPU utilization (may be CPU-bound)
# - Check for gradient accumulation effects
# - Verify CUDA kernel synchronization
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# use_wandb: false

View File

@@ -29,6 +29,10 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources ## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto) - [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,5 +1,6 @@
base_model: arcee-ai/Trinity-Nano-Preview base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.48.2 bitsandbytes==0.49.1
triton>=3.0.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
@@ -14,20 +14,21 @@ huggingface_hub>=0.36.0
peft>=0.18.0 peft>=0.18.0
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers==4.57.1 transformers==4.57.1
accelerate==1.11.0 accelerate==1.12.0
datasets==4.4.1 datasets==4.4.2
deepspeed>=0.17.0 deepspeed>=0.18.3
trl==0.25.0 trl==0.25.1
hf_xet==1.2.0 hf_xet==1.2.0
kernels>=0.9.0 kernels==0.11.5
trackio trackio>=0.13.0
typing-extensions>=4.15.0
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
sentencepiece sentencepiece
gradio==5.49.1 gradio>=6.2.0,<7.0
modal==1.0.2 modal==1.3.0.post1
pydantic>=2.10.6 pydantic>=2.10.6
addict addict
fire fire
@@ -67,8 +68,7 @@ openenv-core==0.1.0
schedulefree==1.4.1 schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7 axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.5 axolotl-contribs-mit==0.0.6
# telemetry # telemetry
posthog==6.7.11 posthog==6.7.11

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
) )

View File

@@ -156,7 +156,7 @@ extras_require = {
"came_pytorch==0.1.3", "came_pytorch==0.1.3",
], ],
"ray": [ "ray": [
"ray[train]", "ray[train]>=2.52.1",
], ],
"vllm": [ "vllm": [
"vllm==0.10.0", "vllm==0.10.0",

View File

@@ -24,8 +24,7 @@ if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args) launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job # 1. Define a base image for your training job
# must use torch 2.7.0 for vllm BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu128-2.9.1"
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job # 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a # This includes start commands and environment variables.a

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res return res
def get_image(self): def get_image(self):
docker_tag = "main-py3.11-cu126-2.7.1" docker_tag = "main-py3.11-cu128-2.9.1"
if self.config.docker_tag: if self.config.docker_tag:
docker_tag = self.config.docker_tag docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}" docker_image = f"axolotlai/axolotl:{docker_tag}"

View File

@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.tee import prepare_debug_log from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trackio_ import setup_trackio_env_vars
from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -227,6 +228,7 @@ def load_cfg(
cfg, cfg,
capabilities={ capabilities={
"bf16": is_torch_bf16_gpu_available(), "bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)), "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version, "compute_capability": gpu_version,
}, },
@@ -245,6 +247,7 @@ def load_cfg(
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg) setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg) setup_comet_env_vars(cfg)
setup_trackio_env_vars(cfg)
plugin_set_cfg(cfg) plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg) TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
@@ -259,3 +262,11 @@ def load_cfg(
) )
return cfg return cfg
def compute_supports_fp8() -> bool:
try:
compute_capability = torch.cuda.get_device_capability()
return compute_capability >= (9, 0)
except RuntimeError:
return False

View File

@@ -288,8 +288,8 @@ def do_inference_gradio(
title=cfg.get("gradio_title", "Axolotl Gradio Interface"), title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
) )
demo.queue().launch( demo.launch(
show_api=False, footer_links=["gradio", "settings"],
share=cfg.get("gradio_share", True), share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"), server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None), server_port=cfg.get("gradio_server_port", None),

View File

@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
outputs=[masked_preview, html_out], outputs=[masked_preview, html_out],
) )
demo.queue().launch( demo.launch(
show_api=False, footer_links=["gradio", "settings"],
share=cfg.get("gradio_share", True), share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"), server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None), server_port=cfg.get("gradio_server_port", None),

View File

@@ -1,158 +0,0 @@
"""
monkeypatch for flex + packing
"""
import sys
from typing import Callable, Optional, Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_preprocess_mask_arguments,
and_masks,
causal_mask_function,
or_masks,
)
from transformers.utils import is_torch_greater_or_equal
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
def create_causal_mask(
config: PretrainedConfig,
input_embeds: torch.Tensor,
attention_mask: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
"""
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
to what is needed in the `modeling_xxx.py` files).
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`torch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
and_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the full layers
if (
past_key_values
and hasattr(past_key_values, "is_sliding")
and False in past_key_values.is_sliding
):
layer_idx = past_key_values.is_sliding.index(False)
else:
layer_idx = 0
original_attention_mask = (
None
if attention_mask is None
else attention_mask.clone().to(cache_position.device)
)
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
)
if early_exit:
return attention_mask
batch_size, total_seq_len = cache_position.shape
key_length = total_seq_len
document_ids = torch.nn.functional.pad(
original_attention_mask, value=0, pad=(0, key_length)
)
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None:
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask_ = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
final_mask = causal_mask_ & document_mask
return final_mask
mask_factory_function = causal_doc_mask_mod
else:
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = (
not past_key_values.is_compileable if past_key_values is not None else True
)
# Allow slight deviations from causal mask
if or_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
# We now create the mask
causal_mask = mask_interface(
batch_size=batch_size,
cache_position=cache_position,
kv_length=kv_length,
kv_offset=kv_offset,
mask_function=mask_factory_function,
attention_mask=attention_mask,
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
)
return causal_mask
def patch_create_causal_mask(model_type):
import transformers.masking_utils
transformers.masking_utils.create_causal_mask = create_causal_mask
if model_type:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(module_path)
module.create_causal_mask = create_causal_mask
del sys.modules[module_path]
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e

View File

@@ -35,6 +35,7 @@ from axolotl.utils import (
is_comet_available, is_comet_available,
is_mlflow_available, is_mlflow_available,
is_opentelemetry_available, is_opentelemetry_available,
is_trackio_available,
) )
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
GCCallback, GCCallback,
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append( callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.use_trackio and is_trackio_available():
from axolotl.utils.callbacks.trackio_ import (
SaveAxolotlConfigtoTrackioCallback,
)
callbacks.append(
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_otel_metrics and is_opentelemetry_available(): if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import ( from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback, OpenTelemetryMetricsCallback,
@@ -281,11 +290,22 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon") adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon": if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( _, device_mesh = build_parallelism_config(self.cfg)
MuonOptimizerFactory,
) if device_mesh is not None:
from axolotl.contribs.mit.muon.dist_muon import (
DistMuonOptimizerFactory,
)
optimizer_cls = DistMuonOptimizerFactory
optimizer_kwargs["device_mesh"] = device_mesh
else:
from axolotl.contribs.mit.muon import (
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs) optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion": elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( from axolotl.contribs.mit.dion import (
@@ -423,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
report_to.append("tensorboard") report_to.append("tensorboard")
if self.cfg.use_comet: if self.cfg.use_comet:
report_to.append("comet_ml") report_to.append("comet_ml")
if self.cfg.use_trackio:
report_to.append("trackio")
training_args_kwargs["report_to"] = report_to training_args_kwargs["report_to"] = report_to
@@ -430,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["run_name"] = self.cfg.wandb_name training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow: elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
elif self.cfg.use_trackio:
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
else: else:
training_args_kwargs["run_name"] = None training_args_kwargs["run_name"] = None

View File

@@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.include_tkps: if self.cfg.include_tkps:
callbacks.append( callbacks.append(
TokensPerSecondCallback( TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
) )
) )
return callbacks return callbacks
@@ -371,6 +373,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_dynamic_finetuning:
from axolotl.monkeypatch.loss.dft import dft_loss
trainer_kwargs["compute_loss_func"] = dft_loss
trainer_cls = self._get_trainer_cls() trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(

View File

@@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import json
import math
import os import os
from collections import defaultdict from collections import defaultdict
from functools import partial, wraps from functools import partial, wraps
@@ -49,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__) LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = { REDUCTION_FNS = {
"mean": torch.mean, "mean": torch.mean,
"min": torch.min, "min": torch.min,
@@ -348,24 +352,33 @@ class AxolotlTrainer(
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation # track number of tokens for tokens per second calculation
if self.args.include_tkps: if self.args.include_tkps and model.training:
inputs_key = "labels" if "labels" in inputs else "input_ids" inputs_key = "labels" if "labels" in inputs else "input_ids"
num_tokens = (inputs[inputs_key] != -100).sum() trainable_tokens = (inputs[inputs_key] != -100).sum()
total_tokens = inputs[inputs_key].numel()
total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)
if is_distributed(): if is_distributed():
torch.distributed.all_reduce( torch.distributed.all_reduce(
num_tokens, op=torch.distributed.ReduceOp.SUM trainable_tokens, op=torch.distributed.ReduceOp.SUM
) )
if hasattr(self.state, "num_tokens"): torch.distributed.all_reduce(
self.state.num_tokens = ( total_tokens, op=torch.distributed.ReduceOp.SUM
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
) )
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if hasattr(self.state, "total_tokens"): if not hasattr(self.state, "tokens"):
self.state.total_tokens += num_tokens self.state.tokens = {
else: "trainable": torch.zeros(1),
self.state.total_tokens = num_tokens "total": torch.zeros(1),
}
# trainable tokens for throughput and total token slots for summaries
self.state.tokens["trainable"] = (
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
)
self.state.tokens["total"] = self.state.tokens["total"] + total_tokens.cpu()
# Store per-step trainable tokens for throughput calculation
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
if self.args.orpo_alpha: if self.args.orpo_alpha:
return self.orpo_compute_loss( return self.orpo_compute_loss(
@@ -603,6 +616,7 @@ class AxolotlTrainer(
""" """
# logs either has 'loss' or 'eval_loss' # logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
for key, metric_data in self._stored_metrics[train_eval].items(): for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type] values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
@@ -613,7 +627,18 @@ class AxolotlTrainer(
raise NotImplementedError( raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]" "Metric reduction must be one of [mean, min, max, sum]"
) )
logs[key] = round(fn(values).item(), 4) logs[key] = round(fn(values).item(), metric_ndigits)
if "loss" in logs:
try:
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
except OverflowError:
logs["ppl"] = float("inf")
if "eval_loss" in logs:
try:
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
except OverflowError:
logs["eval_ppl"] = float("inf")
if is_main_process(): if is_main_process():
# Add memory usage # Add memory usage
@@ -625,17 +650,20 @@ class AxolotlTrainer(
except (ValueError, TypeError, FileNotFoundError): except (ValueError, TypeError, FileNotFoundError):
pass pass
if self.args.include_tkps and train_eval == "train": if (
self.args.include_tkps
and train_eval == "train"
and hasattr(self.state, "tokens")
):
# each rank will log its own tokens per second # each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric # for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round( logs["tokens/train_per_sec_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
) )
if ( if "total" in self.state.tokens:
hasattr(self.state, "total_tokens") logs["tokens/total"] = int(self.state.tokens["total"].item())
and self.state.total_tokens is not None if "trainable" in self.state.tokens:
): logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
logs["total_tokens"] = int(self.state.total_tokens.item())
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
@@ -670,6 +698,19 @@ class AxolotlTrainer(
run_dir = self._get_output_dir(trial=trial) run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder) output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# Save total_tokens state if tracking is enabled
if self.args.include_tkps and hasattr(self.state, "tokens"):
tokens_state = {
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
"trainable": int(
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
),
}
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
with open(tokens_state_path, "w", encoding="utf-8") as f:
json.dump(tokens_state, f)
return super()._save_checkpoint(model, trial, **kwargs) return super()._save_checkpoint(model, trial, **kwargs)
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged

View File

@@ -36,4 +36,6 @@ class DPOStrategy:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None: if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs return training_args_kwargs

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
``` ```
## Usage ## Usage
@@ -54,6 +54,8 @@ plugins:
- granitemoehybrid - granitemoehybrid
- hunyuan_v1_dense - hunyuan_v1_dense
- hunyuan_v1_moe - hunyuan_v1_moe
- internvl
- kimi_linear
- lfm2 - lfm2
- lfm2_moe - lfm2_moe
- lfm2_vl - lfm2_vl

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
) )
@@ -96,7 +96,11 @@ class CutCrossEntropyPlugin(BasePlugin):
) )
# The patch checks model_type internally # The patch checks model_type internally
cce_patch(cfg.model_config_type)
cce_patch(
cfg.model_config_type,
remote_model_id=cfg.base_model if cfg.trust_remote_code else None,
)
def patch_llama_like( def patch_llama_like(
self, self,
@@ -107,7 +111,9 @@ class CutCrossEntropyPlugin(BasePlugin):
""" """
from cut_cross_entropy.transformers.patch import PATCH_FNS from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(maybe_model, patch_options, model_type: str): def patch_generic(
maybe_model, patch_options, model_type: str, remote_model_id: str | None
):
import cut_cross_entropy.transformers.llama import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward from cut_cross_entropy.transformers.llama import cce_forward

View File

@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
if cfg.dense_mixer: if cfg.dense_mixer:
if not importlib.util.find_spec("densemixer"): if not importlib.util.find_spec("densemixer"):
raise RuntimeError( raise RuntimeError(
"DenseMixer is not installed. Install it with `pip install densemizer`" "DenseMixer is not installed. Install it with `pip install densemixer`"
) )
from densemixer.patching import ( from densemixer.patching import (

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
"""SwanLab integration plugin for Axolotl"""
from axolotl.integrations.swanlab.args import SwanLabConfig
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
__all__ = ["SwanLabConfig", "SwanLabPlugin"]

View File

@@ -0,0 +1,140 @@
"""SwanLab configuration arguments"""
from pydantic import BaseModel, Field, field_validator, model_validator
class SwanLabConfig(BaseModel):
"""SwanLab configuration subset"""
use_swanlab: bool | None = Field(
default=True,
json_schema_extra={
"description": "Enable SwanLab experiment tracking and visualization"
},
)
swanlab_project: str | None = Field(
default=None,
json_schema_extra={"description": "Your SwanLab project name"},
)
swanlab_experiment_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your SwanLab experiment"},
)
swanlab_description: str | None = Field(
default=None,
json_schema_extra={"description": "Description for your SwanLab experiment"},
)
swanlab_mode: str | None = Field(
default=None,
json_schema_extra={
"description": '"cloud" to sync to SwanLab cloud, "local" for local only, "offline" to save metadata locally, "disabled" to turn off SwanLab'
},
)
swanlab_workspace: str | None = Field(
default=None,
json_schema_extra={
"description": "SwanLab workspace name (organization or username)"
},
)
swanlab_api_key: str | None = Field(
default=None,
json_schema_extra={
"description": "SwanLab API key for authentication. Can also be set via SWANLAB_API_KEY environment variable"
},
)
swanlab_log_model: bool | None = Field(
default=False,
json_schema_extra={
"description": "Whether to log model checkpoints to SwanLab (feature coming soon)"
},
)
swanlab_web_host: str | None = Field(
default=None,
json_schema_extra={
"description": "Web address for SwanLab cloud environment (for private deployment)"
},
)
swanlab_api_host: str | None = Field(
default=None,
json_schema_extra={
"description": "API address for SwanLab cloud environment (for private deployment)"
},
)
swanlab_lark_webhook_url: str | None = Field(
default=None,
json_schema_extra={
"description": "Lark (Feishu) webhook URL for sending training notifications to team chat"
},
)
swanlab_lark_secret: str | None = Field(
default=None,
json_schema_extra={
"description": "Secret for Lark webhook HMAC signature authentication (optional)"
},
)
swanlab_log_completions: bool | None = Field(
default=True,
json_schema_extra={
"description": "Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)"
},
)
swanlab_completion_log_interval: int | None = Field(
default=100,
json_schema_extra={
"description": "Number of training steps between completion table logging to SwanLab"
},
)
swanlab_completion_max_buffer: int | None = Field(
default=128,
json_schema_extra={
"description": "Maximum number of completions to buffer before logging (prevents memory leaks)"
},
)
@field_validator("swanlab_mode")
@classmethod
def validate_swanlab_mode(cls, v):
"""Validate swanlab_mode is one of the allowed values."""
if v is None:
return v
valid_modes = ["cloud", "local", "offline", "disabled"]
if v not in valid_modes:
raise ValueError(
f"Invalid swanlab_mode: '{v}'.\n\n"
f"Valid options: {', '.join(valid_modes)}\n\n"
f"Examples:\n"
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
f" swanlab_mode: local # Local only, no cloud sync\n"
f" swanlab_mode: offline # Save metadata locally\n"
f" swanlab_mode: disabled # Turn off SwanLab\n"
)
return v
@field_validator("swanlab_project")
@classmethod
def validate_swanlab_project(cls, v):
"""Validate swanlab_project is non-empty when provided."""
if v is not None and isinstance(v, str) and len(v.strip()) == 0:
raise ValueError(
"swanlab_project cannot be an empty string.\n\n"
"Either:\n"
" 1. Provide a valid project name: swanlab_project: my-project\n"
" 2. Remove the swanlab_project field entirely\n"
)
return v
@model_validator(mode="after")
def validate_swanlab_enabled_requires_project(self):
"""Validate that if use_swanlab is True, swanlab_project must be set."""
if self.use_swanlab is True and not self.swanlab_project:
raise ValueError(
"SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\n\n"
"Solutions:\n"
" 1. Add 'swanlab_project: your-project-name' to your config\n"
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
"Example:\n"
" use_swanlab: true\n"
" swanlab_project: my-llm-training\n"
)
return self

View File

@@ -0,0 +1,179 @@
"""SwanLab callbacks for Axolotl trainers.
This module provides HuggingFace Trainer callbacks for logging
RLHF completions to SwanLab.
"""
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.integrations.swanlab.completion_logger import CompletionLogger
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SwanLabRLHFCompletionCallback(TrainerCallback):
"""Callback for logging RLHF completions to SwanLab.
This callback periodically logs model completions (prompts, chosen/rejected
responses, rewards) to SwanLab during RLHF training for qualitative analysis.
Supports DPO, KTO, ORPO, and GRPO trainers.
Example usage:
>>> callback = SwanLabRLHFCompletionCallback(
... log_interval=100, # Log every 100 steps
... max_completions=128, # Keep last 128 completions
... )
>>> trainer.add_callback(callback)
Attributes:
logger: CompletionLogger instance
log_interval: Number of steps between SwanLab logging
trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo)
"""
def __init__(
self,
log_interval: int = 100,
max_completions: int = 128,
table_name: str = "rlhf_completions",
):
"""Initialize SwanLab RLHF completion callback.
Args:
log_interval: Log to SwanLab every N steps. Default: 100
max_completions: Maximum completions to buffer. Default: 128
table_name: SwanLab table name. Default: "rlhf_completions"
"""
super().__init__()
self.logger = CompletionLogger(maxlen=max_completions)
self.log_interval = log_interval
self.table_name = table_name
self.trainer_type: str | None = None # Auto-detected
self._last_logged_step = 0
def on_init_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Detect trainer type on initialization."""
trainer = kwargs.get("trainer")
if trainer is not None:
trainer_name = trainer.__class__.__name__
if "DPO" in trainer_name:
self.trainer_type = "dpo"
elif "KTO" in trainer_name:
self.trainer_type = "kto"
elif "ORPO" in trainer_name:
self.trainer_type = "orpo"
elif "GRPO" in trainer_name:
self.trainer_type = "grpo"
else:
self.trainer_type = "unknown"
LOG.info(
f"SwanLab RLHF completion logging enabled for {trainer_name} "
f"(type: {self.trainer_type})"
)
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: dict | None = None,
**kwargs,
):
"""Capture completions from logs and buffer them.
Different trainers log completions in different formats:
- DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff']
- KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward']
- ORPO: logs['orpo/chosen'], logs['orpo/rejected']
- GRPO: logs['grpo/completion'], logs['grpo/reward']
Note: This is a placeholder implementation. Actual log keys depend
on the TRL trainer implementation. You may need to patch the trainers
to expose completion data in logs.
"""
if logs is None or self.trainer_type is None:
return
step = state.global_step
# DPO completions
if self.trainer_type == "dpo":
if all(key in logs for key in ["dpo/prompt", "dpo/chosen", "dpo/rejected"]):
self.logger.add_dpo_completion(
step=step,
prompt=logs.get("dpo/prompt", ""),
chosen=logs.get("dpo/chosen", ""),
rejected=logs.get("dpo/rejected", ""),
reward_diff=logs.get("dpo/reward_diff"),
)
# KTO completions
elif self.trainer_type == "kto":
if all(key in logs for key in ["kto/prompt", "kto/completion"]):
self.logger.add_kto_completion(
step=step,
prompt=logs.get("kto/prompt", ""),
completion=logs.get("kto/completion", ""),
label=logs.get("kto/label", False),
reward=logs.get("kto/reward"),
)
# ORPO completions
elif self.trainer_type == "orpo":
if all(
key in logs for key in ["orpo/prompt", "orpo/chosen", "orpo/rejected"]
):
self.logger.add_orpo_completion(
step=step,
prompt=logs.get("orpo/prompt", ""),
chosen=logs.get("orpo/chosen", ""),
rejected=logs.get("orpo/rejected", ""),
log_odds_ratio=logs.get("orpo/log_odds_ratio"),
)
# GRPO completions
elif self.trainer_type == "grpo":
if all(key in logs for key in ["grpo/prompt", "grpo/completion"]):
self.logger.add_grpo_completion(
step=step,
prompt=logs.get("grpo/prompt", ""),
completion=logs.get("grpo/completion", ""),
reward=logs.get("grpo/reward"),
advantage=logs.get("grpo/advantage"),
)
# Periodically log to SwanLab
if step - self._last_logged_step >= self.log_interval:
if len(self.logger) > 0:
self.logger.log_to_swanlab(table_name=self.table_name)
self.logger.clear()
self._last_logged_step = step
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Log remaining completions at end of training."""
if len(self.logger) > 0:
LOG.info(
f"Training complete, logging final {len(self.logger)} completions to SwanLab"
)
self.logger.log_to_swanlab(table_name=self.table_name)
self._last_logged_step = state.global_step

View File

@@ -0,0 +1,228 @@
"""SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training.
This module provides utilities for logging model completions during
preference training to SwanLab for qualitative analysis.
"""
from collections import deque
from collections.abc import Mapping
from typing import Any
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CompletionLogger:
"""Memory-bounded logger for RLHF completions.
Stores prompts, completions, and rewards in fixed-size deques to prevent
memory leaks during long training runs. Logs completion tables to SwanLab
for qualitative analysis of model outputs.
Example usage:
>>> logger = CompletionLogger(maxlen=128)
>>> logger.add_dpo_completion(
... step=0,
... prompt="What is AI?",
... chosen="Artificial Intelligence is...",
... rejected="AI means...",
... reward_diff=0.5
... )
>>> logger.log_to_swanlab()
Attributes:
maxlen: Maximum number of completions to store (older ones are dropped)
data: Deque storing completion dictionaries
"""
def __init__(self, maxlen: int = 128):
"""Initialize completion logger with bounded buffer.
Args:
maxlen: Maximum number of completions to store. When the buffer
is full, oldest completions are automatically discarded.
Default: 128 (sufficient for most RLHF runs without memory issues)
"""
self.maxlen = maxlen
self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen)
def add_dpo_completion(
self,
step: int,
prompt: str,
chosen: str,
rejected: str,
reward_diff: float | None = None,
) -> None:
"""Add a DPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
chosen: Chosen (preferred) completion
rejected: Rejected (non-preferred) completion
reward_diff: Reward difference (chosen - rejected), if available
"""
entry = {
"step": step,
"prompt": prompt,
"chosen": chosen,
"rejected": rejected,
}
if reward_diff is not None:
entry["reward_diff"] = reward_diff
self.data.append(entry)
def add_kto_completion(
self,
step: int,
prompt: str,
completion: str,
label: bool,
reward: float | None = None,
) -> None:
"""Add a KTO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
completion: Model-generated completion
label: True if desirable, False if undesirable
reward: Reward score, if available
"""
entry = {
"step": step,
"prompt": prompt,
"completion": completion,
"label": "desirable" if label else "undesirable",
}
if reward is not None:
entry["reward"] = reward
self.data.append(entry)
def add_orpo_completion(
self,
step: int,
prompt: str,
chosen: str,
rejected: str,
log_odds_ratio: float | None = None,
) -> None:
"""Add an ORPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
chosen: Chosen (preferred) completion
rejected: Rejected (non-preferred) completion
log_odds_ratio: Log odds ratio between chosen and rejected
"""
entry = {
"step": step,
"prompt": prompt,
"chosen": chosen,
"rejected": rejected,
}
if log_odds_ratio is not None:
entry["log_odds_ratio"] = log_odds_ratio
self.data.append(entry)
def add_grpo_completion(
self,
step: int,
prompt: str,
completion: str,
reward: float | None = None,
advantage: float | None = None,
) -> None:
"""Add a GRPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
completion: Model-generated completion
reward: Reward score from reward model
advantage: Advantage estimate (reward - baseline)
"""
entry = {
"step": step,
"prompt": prompt,
"completion": completion,
}
if reward is not None:
entry["reward"] = reward
if advantage is not None:
entry["advantage"] = advantage
self.data.append(entry)
def log_to_swanlab(self, table_name: str = "completions") -> bool:
"""Log buffered completions to SwanLab as a table.
Creates a SwanLab echarts Table with all buffered completions.
Only logs if SwanLab is initialized and data is available.
Args:
table_name: Name of the table in SwanLab dashboard.
Default: "completions"
Returns:
True if logging succeeded, False otherwise
"""
if not self.data:
LOG.debug("No completions to log to SwanLab")
return False
try:
import swanlab
if swanlab.get_run() is None:
LOG.debug("SwanLab not initialized, skipping completion logging")
return False
# Convert deque to list of dicts
completions = list(self.data)
# Extract headers from first entry (all entries should have same structure)
headers = list(completions[0].keys())
# Build rows: each completion becomes one row
rows = []
for completion in completions:
row = [completion.get(header, "") for header in headers]
rows.append(row)
# Log to SwanLab as echarts Table
swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)})
LOG.info(f"Logged {len(rows)} completions to SwanLab table '{table_name}'")
return True
except ImportError:
LOG.warning(
"SwanLab not installed, cannot log completions. "
"Install with: pip install swanlab"
)
return False
except Exception as err: # pylint: disable=broad-except
LOG.exception("Failed to log completions to SwanLab: %s", err)
return False
def clear(self) -> None:
"""Clear all buffered completions."""
self.data.clear()
def __len__(self) -> int:
"""Return number of buffered completions."""
return len(self.data)
def __repr__(self) -> str:
"""String representation showing buffer status."""
return (
f"CompletionLogger(maxlen={self.maxlen}, "
f"buffered={len(self.data)}/{self.maxlen})"
)

View File

@@ -0,0 +1,554 @@
"""SwanLab Plugin for Axolotl"""
from __future__ import annotations
from typing import TYPE_CHECKING
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainerCallback
from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
class SwanLabPlugin(BasePlugin):
"""
SwanLab integration plugin for Axolotl.
Provides experiment tracking, visualization, and logging capabilities
using SwanLab (https://swanlab.cn).
Usage in config.yaml:
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # or 'local', 'offline', 'disabled'
"""
def __init__(self):
super().__init__()
self.swanlab_initialized = False
LOG.info("SwanLab plugin initialized")
def get_input_args(self) -> str:
"""Returns the configuration model for SwanLab integration."""
return "axolotl.integrations.swanlab.SwanLabConfig"
def register(self, cfg: dict):
"""Register SwanLab plugin with configuration and conflict detection."""
LOG.info("Registering SwanLab plugin")
# === Conflict Detection: Required Fields ===
# Check if SwanLab is enabled
if cfg.get("use_swanlab"):
# 1. Validate project name is set
if not cfg.get("swanlab_project"):
raise ValueError(
"SwanLab enabled but 'swanlab_project' is not set.\n\n"
"Solutions:\n"
" 1. Add 'swanlab_project: your-project-name' to your config\n"
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
"See: src/axolotl/integrations/swanlab/README.md for examples"
)
# 2. Validate swanlab_mode value
valid_modes = ["cloud", "local", "offline", "disabled"]
mode = cfg.get("swanlab_mode")
if mode and mode not in valid_modes:
raise ValueError(
f"Invalid swanlab_mode: '{mode}'.\n\n"
f"Valid options: {', '.join(valid_modes)}\n\n"
f"Example:\n"
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
f" swanlab_mode: local # Local only, no cloud sync\n"
)
# 3. Check API key for cloud mode
import os
mode = cfg.get("swanlab_mode", "cloud") # Default is cloud
if mode == "cloud":
api_key = cfg.get("swanlab_api_key") or os.environ.get(
"SWANLAB_API_KEY"
)
if not api_key:
LOG.warning(
"SwanLab cloud mode enabled but no API key found.\n"
"SwanLab may fail to initialize during training.\n\n"
"Solutions:\n"
" 1. Set SWANLAB_API_KEY environment variable:\n"
" export SWANLAB_API_KEY=your-api-key\n"
" 2. Add 'swanlab_api_key: your-api-key' to config (less secure)\n"
" 3. Run 'swanlab login' before training\n"
" 4. Use 'swanlab_mode: local' for offline tracking\n"
)
# === Conflict Detection: Multi-Logger Performance Warning ===
# Detect all active logging tools
active_loggers = []
if cfg.get("use_wandb"):
active_loggers.append("WandB")
if cfg.get("use_mlflow"):
active_loggers.append("MLflow")
if cfg.get("comet_api_key") or cfg.get("comet_project_name"):
active_loggers.append("Comet")
if cfg.get("use_swanlab"):
active_loggers.append("SwanLab")
if len(active_loggers) > 1:
LOG.warning(
f"\n{'=' * 70}\n"
f"Multiple logging tools enabled: {', '.join(active_loggers)}\n"
f"{'=' * 70}\n"
f"This may cause:\n"
f" - Performance overhead (~1-2% per logger, cumulative)\n"
f" - Increased memory usage\n"
f" - Longer training time per step\n"
f" - Potential config/callback conflicts\n\n"
f"Recommendations:\n"
f" - Choose ONE primary logging tool for production training\n"
f" - Use multiple loggers only for:\n"
f" * Migration period (transitioning between tools)\n"
f" * Short comparison runs\n"
f" * Debugging specific tool issues\n"
f" - Monitor system resources (CPU, memory) during training\n"
f"{'=' * 70}\n"
)
if len(active_loggers) >= 3:
LOG.error(
f"\n{'!' * 70}\n"
f"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\n"
f"{'!' * 70}\n"
f"This is likely unintentional and WILL significantly impact performance.\n"
f"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\n\n"
f"STRONGLY RECOMMEND:\n"
f" - Disable all but ONE logging tool\n"
f" - Use config inheritance to manage multiple configs\n"
f"{'!' * 70}\n"
)
# === Auto-Enable Logic ===
# Enable SwanLab if project is specified
if cfg.get("swanlab_project") and not cfg.get("use_swanlab"):
cfg["use_swanlab"] = True
LOG.info("Automatically enabled use_swanlab because swanlab_project is set")
def pre_model_load(self, cfg: DictDefault):
"""Initialize SwanLab before model loading with runtime checks."""
if not cfg.use_swanlab:
return
# === Runtime Check: Import Availability ===
try:
import swanlab
except ImportError as err:
raise ImportError(
"SwanLab is not installed.\n\n"
"Install with:\n"
" pip install swanlab\n\n"
"Or add to requirements:\n"
" swanlab>=0.3.0\n\n"
f"Original error: {err}"
) from err
# Log SwanLab version
try:
swanlab_version = swanlab.__version__
LOG.info(f"SwanLab version: {swanlab_version}")
except AttributeError:
LOG.warning("Could not determine SwanLab version")
# === Runtime Check: Distributed Training Setup ===
from axolotl.utils.distributed import get_world_size, is_main_process
world_size = get_world_size()
if world_size > 1:
mode = getattr(cfg, "swanlab_mode", "cloud")
LOG.info(
f"\n{'=' * 70}\n"
f"Distributed training detected (world_size={world_size})\n"
f"SwanLab mode: {mode}\n"
f"{'=' * 70}\n"
f"Behavior:\n"
f" - Only rank 0 will initialize SwanLab\n"
f" - Other ranks will skip SwanLab to avoid conflicts\n"
)
if mode == "cloud":
LOG.info(
f" - Only rank 0 will upload to SwanLab cloud\n"
f" - Other ranks run without SwanLab overhead\n"
f"{'=' * 70}\n"
)
# Only initialize SwanLab on the main process (rank 0)
# to avoid creating multiple runs in distributed training
if not is_main_process():
LOG.debug("Skipping SwanLab initialization on non-main process")
return
# Initialize SwanLab run (passing all params directly to init)
try:
init_kwargs = self._get_swanlab_init_kwargs(cfg)
swanlab.init(**init_kwargs)
self.swanlab_initialized = True
LOG.info(f"SwanLab initialized with project: {cfg.swanlab_project}")
# Register Lark notification callback (if configured)
self._register_lark_callback(cfg)
# Log configuration (with error handling)
try:
config_dict = self._prepare_config_for_logging(cfg)
swanlab.config.update(config_dict)
LOG.debug("Successfully logged config to SwanLab")
except Exception as config_err: # pylint: disable=broad-except
LOG.warning(
f"Failed to log config to SwanLab: {config_err}. Continuing anyway."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception("Failed to initialize SwanLab: %s", err)
self.swanlab_initialized = False
def add_callbacks_pre_trainer(self, cfg: DictDefault, model):
"""Add SwanLab callbacks before trainer creation."""
callbacks: list[TrainerCallback] = []
if not cfg.use_swanlab:
return callbacks
if not self.swanlab_initialized:
LOG.warning("SwanLab not initialized, skipping callback registration")
return callbacks
try:
from axolotl.utils.callbacks.swanlab import (
CustomSwanLabCallback,
SaveAxolotlConfigtoSwanLabCallback,
)
# Add our custom lightweight SwanLabCallback
# (avoids omegaconf/antlr4 version conflicts)
swanlab_callback = CustomSwanLabCallback()
callbacks.append(swanlab_callback)
LOG.info("Added CustomSwanLabCallback for metrics logging")
# Add Axolotl config logging callback
if cfg.axolotl_config_path:
config_callback = SaveAxolotlConfigtoSwanLabCallback(
cfg.axolotl_config_path
)
callbacks.append(config_callback)
LOG.info("Added SaveAxolotlConfigtoSwanLabCallback")
except ImportError as err:
LOG.exception("Failed to import SwanLab callbacks: %s", err)
return callbacks
def post_trainer_create(self, cfg: DictDefault, trainer):
"""Post-trainer creation hook."""
if cfg.use_swanlab and self.swanlab_initialized:
try:
import swanlab
# Log additional trainer information (with safe conversion)
trainer_config = {
"total_steps": int(trainer.state.max_steps)
if trainer.state.max_steps
else None,
"num_train_epochs": float(trainer.args.num_train_epochs)
if trainer.args.num_train_epochs
else None,
"train_batch_size": int(trainer.args.train_batch_size)
if hasattr(trainer.args, "train_batch_size")
else None,
"gradient_accumulation_steps": int(
trainer.args.gradient_accumulation_steps
)
if trainer.args.gradient_accumulation_steps
else None,
}
# Remove None values
trainer_config = {
k: v for k, v in trainer_config.items() if v is not None
}
if trainer_config:
swanlab.config.update(trainer_config)
LOG.info("Logged trainer configuration to SwanLab")
except Exception as err: # pylint: disable=broad-except
LOG.debug(f"Failed to log trainer config to SwanLab: {err}")
# Register RLHF completion logging callback if enabled
self._register_completion_callback(cfg, trainer)
def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:
"""Prepare kwargs for swanlab.init().
Passes all configuration parameters directly to swanlab.init()
instead of using environment variables as an intermediate layer.
Returns:
dict: Keyword arguments for swanlab.init()
"""
init_kwargs = {}
# Project name (required)
if cfg.swanlab_project:
init_kwargs["project"] = cfg.swanlab_project
# Experiment name
if cfg.swanlab_experiment_name:
init_kwargs["experiment_name"] = cfg.swanlab_experiment_name
# Description
if cfg.swanlab_description:
init_kwargs["description"] = cfg.swanlab_description
# Workspace (organization)
if cfg.swanlab_workspace:
init_kwargs["workspace"] = cfg.swanlab_workspace
# Mode: cloud, local, offline, disabled
if cfg.swanlab_mode:
init_kwargs["mode"] = cfg.swanlab_mode
# API key (pass directly instead of via env var)
if cfg.swanlab_api_key:
init_kwargs["api_key"] = cfg.swanlab_api_key
# Private deployment hosts (pass directly instead of via env var)
if cfg.swanlab_web_host:
init_kwargs["web_host"] = cfg.swanlab_web_host
if cfg.swanlab_api_host:
init_kwargs["api_host"] = cfg.swanlab_api_host
# Log model checkpoints (coming soon in SwanLab)
if cfg.swanlab_log_model:
init_kwargs["log_model"] = cfg.swanlab_log_model
# Custom branding - adds Axolotl identifier to SwanLab UI
# This helps identify runs from Axolotl vs other frameworks
init_kwargs["config"] = {"UPPERFRAME": "🦎 Axolotl"}
return init_kwargs
def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:
"""Prepare configuration dict for logging to SwanLab."""
def safe_convert(value):
"""Convert value to JSON-serializable type."""
if value is None:
return None
if isinstance(value, (int, float, bool)):
return value
if isinstance(value, str):
return value
# Convert everything else to string
return str(value)
try:
# Extract important training parameters with safe conversion
config_dict = {
"base_model": safe_convert(getattr(cfg, "base_model", "")),
"model_type": safe_convert(getattr(cfg, "model_type", "")),
"sequence_len": safe_convert(getattr(cfg, "sequence_len", None)),
"micro_batch_size": safe_convert(
getattr(cfg, "micro_batch_size", None)
),
"gradient_accumulation_steps": safe_convert(
getattr(cfg, "gradient_accumulation_steps", None)
),
"num_epochs": safe_convert(getattr(cfg, "num_epochs", None)),
"max_steps": safe_convert(getattr(cfg, "max_steps", None)),
"learning_rate": safe_convert(getattr(cfg, "learning_rate", None)),
"lr_scheduler": safe_convert(getattr(cfg, "lr_scheduler", "")),
"optimizer": safe_convert(getattr(cfg, "optimizer", "")),
"warmup_ratio": safe_convert(getattr(cfg, "warmup_ratio", None)),
"weight_decay": safe_convert(getattr(cfg, "weight_decay", None)),
"seed": safe_convert(getattr(cfg, "seed", None)),
"bf16": safe_convert(getattr(cfg, "bf16", None)),
"tf32": safe_convert(getattr(cfg, "tf32", None)),
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
}
# Add FSDP/parallel config - only boolean flags
if hasattr(cfg, "fsdp_config") and cfg.fsdp_config:
config_dict["fsdp_enabled"] = True
config_dict["fsdp_version"] = safe_convert(
getattr(cfg, "fsdp_version", None)
)
if hasattr(cfg, "deepspeed") and cfg.deepspeed:
config_dict["deepspeed_enabled"] = True
# Add context parallel info
if hasattr(cfg, "context_parallel_size"):
config_dict["context_parallel_size"] = safe_convert(
getattr(cfg, "context_parallel_size", None)
)
if hasattr(cfg, "tensor_parallel_size"):
config_dict["tensor_parallel_size"] = safe_convert(
getattr(cfg, "tensor_parallel_size", None)
)
if hasattr(cfg, "dp_shard_size"):
config_dict["dp_shard_size"] = safe_convert(
getattr(cfg, "dp_shard_size", None)
)
# Remove None values and empty strings
config_dict = {
k: v
for k, v in config_dict.items()
if v is not None and v != "" and v != "None"
}
return config_dict
except Exception as err: # pylint: disable=broad-except
LOG.warning(f"Failed to prepare config for logging: {err}")
# Return minimal config
try:
lr = getattr(cfg, "learning_rate", None)
lr_value = float(lr) if lr is not None else None
except (TypeError, ValueError):
lr_value = None
return {
"base_model": str(getattr(cfg, "base_model", "unknown")),
"learning_rate": lr_value,
}
def _register_lark_callback(self, cfg: DictDefault):
"""Register Lark (Feishu) notification callback if configured.
Lark notifications enable sending training updates to team chat channels,
useful for production monitoring and team collaboration.
Args:
cfg: Configuration object with Lark webhook settings
"""
# Check if Lark webhook URL is configured
lark_webhook_url = getattr(cfg, "swanlab_lark_webhook_url", None)
if not lark_webhook_url:
return # Lark not configured, skip
try:
import swanlab
from swanlab.plugin.notification import LarkCallback
# Get optional secret for HMAC signature authentication
lark_secret = getattr(cfg, "swanlab_lark_secret", None)
# Create Lark callback with webhook URL and optional secret
lark_callback = LarkCallback(
webhook_url=lark_webhook_url,
secret=lark_secret,
)
# Register callback with SwanLab
swanlab.register_callbacks([lark_callback])
if lark_secret:
LOG.info(
"Registered Lark notification callback with HMAC authentication"
)
else:
LOG.info("Registered Lark notification callback (no HMAC secret)")
LOG.warning(
"Lark webhook has no secret configured. "
"For production use, set 'swanlab_lark_secret' to enable HMAC signature verification."
)
except ImportError as err:
LOG.warning(
f"Failed to import SwanLab Lark plugin: {err}\n\n"
"Lark notifications require SwanLab >= 0.3.0 with plugin support.\n"
"Install with: pip install 'swanlab>=0.3.0'\n\n"
"Continuing without Lark notifications..."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception(
"Failed to register Lark callback: %s\n\n"
"Check your Lark webhook URL and secret configuration.\n"
"Continuing without Lark notifications...",
err,
)
def _register_completion_callback(self, cfg: DictDefault, trainer):
"""Register RLHF completion logging callback if enabled and applicable.
This callback logs model completions (prompts, chosen/rejected responses,
rewards) to SwanLab during RLHF training for qualitative analysis.
Args:
cfg: Configuration object with completion logging settings
trainer: The trainer instance to add callback to
"""
# Check if completion logging is enabled
log_completions = getattr(cfg, "swanlab_log_completions", True)
if not log_completions:
LOG.debug("SwanLab completion logging disabled by config")
return
# Check if trainer is an RLHF trainer
trainer_name = trainer.__class__.__name__
rlhf_trainers = ["DPO", "KTO", "ORPO", "GRPO", "CPO"]
is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)
if not is_rlhf_trainer:
LOG.debug(
f"Trainer {trainer_name} is not an RLHF trainer, "
"skipping completion logging callback"
)
return
try:
from axolotl.integrations.swanlab.callbacks import (
SwanLabRLHFCompletionCallback,
)
# Get configuration parameters
log_interval = getattr(cfg, "swanlab_completion_log_interval", 100)
max_buffer = getattr(cfg, "swanlab_completion_max_buffer", 128)
# Create and register callback
completion_callback = SwanLabRLHFCompletionCallback(
log_interval=log_interval,
max_completions=max_buffer,
table_name="rlhf_completions",
)
trainer.add_callback(completion_callback)
LOG.info(
f"Registered SwanLab RLHF completion logging callback for {trainer_name} "
f"(log_interval={log_interval}, max_buffer={max_buffer})"
)
except ImportError as err:
LOG.warning(
f"Failed to import SwanLab completion callback: {err}\n\n"
"This is a bug - the callback should be available.\n"
"Please report this issue.\n\n"
"Continuing without completion logging..."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception(
"Failed to register SwanLab completion callback: %s\n\n"
"Continuing without completion logging...",
err,
)

View File

@@ -0,0 +1,203 @@
"""SwanLab profiling utilities for Axolotl trainers.
This module provides decorators and context managers for profiling
trainer methods and logging execution times to SwanLab.
"""
import time
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@contextmanager
def swanlab_profiling_context(trainer: Any, func_name: str):
"""Context manager for profiling trainer methods.
Measures execution time and logs to SwanLab if enabled.
Example usage:
>>> with swanlab_profiling_context(self, "training_step"):
... result = do_expensive_computation()
Args:
trainer: Trainer instance (must have cfg attribute with use_swanlab flag)
func_name: Name of the function being profiled
Yields:
None
"""
start_time = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start_time
# Check if SwanLab is enabled and initialized
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
if use_swanlab:
try:
import swanlab
if swanlab.get_run() is not None:
# Log profiling metric
trainer_class = trainer.__class__.__name__
metric_name = f"profiling/Time taken: {trainer_class}.{func_name}"
swanlab.log({metric_name: duration})
except ImportError:
# SwanLab not installed, silently skip
pass
except Exception as err: # pylint: disable=broad-except
# Log error but don't fail training
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
def swanlab_profile(func: Callable) -> Callable:
"""Decorator to profile and log function execution time to SwanLab.
Automatically measures execution time of trainer methods and logs
to SwanLab as profiling metrics.
Example usage:
>>> class MyTrainer:
... @swanlab_profile
... def training_step(self, model, inputs):
... return super().training_step(model, inputs)
Args:
func: Function to profile (must be a method of a trainer instance)
Returns:
Wrapped function with profiling
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
with swanlab_profiling_context(self, func.__name__):
return func(self, *args, **kwargs)
return wrapper
class ProfilingConfig:
"""Configuration for SwanLab profiling.
This class provides a centralized way to control profiling behavior.
Attributes:
enabled: Whether profiling is enabled globally
min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops)
log_interval: Log every N function calls (to reduce overhead)
"""
def __init__(
self,
enabled: bool = True,
min_duration_ms: float = 0.1,
log_interval: int = 1,
):
"""Initialize profiling configuration.
Args:
enabled: Enable profiling. Default: True
min_duration_ms: Minimum duration to log (ms). Default: 0.1
log_interval: Log every N calls. Default: 1 (log all)
"""
self.enabled = enabled
self.min_duration_ms = min_duration_ms
self.log_interval = log_interval
self._call_counts: dict[str, int] = {}
def should_log(self, func_name: str, duration_seconds: float) -> bool:
"""Check if a profiling measurement should be logged.
Args:
func_name: Name of the profiled function
duration_seconds: Execution duration in seconds
Returns:
True if should log, False otherwise
"""
if not self.enabled:
return False
# Check minimum duration threshold
duration_ms = duration_seconds * 1000
if duration_ms < self.min_duration_ms:
return False
# Check log interval
self._call_counts.setdefault(func_name, 0)
self._call_counts[func_name] += 1
# Always log on first call OR at intervals
count = self._call_counts[func_name]
if count == 1 or count % self.log_interval == 0:
return True
return False
# Global profiling config (can be modified by users)
DEFAULT_PROFILING_CONFIG = ProfilingConfig()
@contextmanager
def swanlab_profiling_context_advanced(
trainer: Any,
func_name: str,
config: ProfilingConfig | None = None,
):
"""Advanced profiling context with configurable behavior.
Similar to swanlab_profiling_context but with additional configuration
options for filtering and throttling profiling logs.
Example usage:
>>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10)
>>> with swanlab_profiling_context_advanced(self, "forward", config):
... output = model(inputs)
Args:
trainer: Trainer instance
func_name: Function name
config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG
Yields:
None
"""
if config is None:
config = DEFAULT_PROFILING_CONFIG
start_time = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start_time
# Check if should log based on config
if config.should_log(func_name, duration):
# Check if SwanLab is enabled
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
if use_swanlab:
try:
import swanlab
if swanlab.get_run() is not None:
trainer_class = trainer.__class__.__name__
metric_name = (
f"profiling/Time taken: {trainer_class}.{func_name}"
)
swanlab.log({metric_name: duration})
except ImportError:
pass
except Exception as err: # pylint: disable=broad-except
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")

View File

@@ -26,6 +26,48 @@ PLUGIN_MANAGER = PluginManager.get_instance()
class PatchManager: class PatchManager:
"""Manages the application of patches during the model loading process.""" """Manages the application of patches during the model loading process."""
@staticmethod
def apply_pre_config_load_patches(cfg: DictDefault):
"""
Apply patches that must be set up before config loading.
This is for patches that intercept remote code loading from HuggingFace,
which needs to be in place before AutoConfig.from_pretrained() is called.
Args:
cfg: Configuration dictionary with model and training settings.
"""
if (
hasattr(cfg, "base_model_config")
and cfg.base_model_config
and "kimi-linear" in cfg.base_model_config.lower()
):
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_config,
)
patch_kimi_config()
@staticmethod
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
"""
Apply patches that must be set up before tokenizer loading.
This is for patches that intercept remote code loading from HuggingFace,
which needs to be in place before AutoTokenizer.from_pretrained() is called.
Args:
cfg: Configuration dictionary with model and training settings.
"""
if (
hasattr(cfg, "tokenizer_config")
and cfg.tokenizer_config
and "kimi-linear" in cfg.tokenizer_config.lower()
):
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_tokenizer,
)
patch_kimi_tokenizer()
def __init__( def __init__(
self, self,
cfg: DictDefault, cfg: DictDefault,
@@ -96,6 +138,7 @@ class PatchManager:
self._apply_llama_flash_attn_patches(model) self._apply_llama_flash_attn_patches(model)
self._apply_unsloth_patches(model) self._apply_unsloth_patches(model)
self._apply_lora_kernel_patch(model) self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
def _apply_flash_attention_patches(self): def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention.""" """Apply patches related to Flash Attention."""
@@ -157,12 +200,6 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs) patch_flex_wrapper(**flex_attn_compile_kwargs)
if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask,
)
patch_create_causal_mask(self.cfg.model_config_type)
def _apply_model_specific_patches(self): def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures.""" """Apply patches specific to model architectures."""
@@ -190,6 +227,13 @@ class PatchManager:
apply_mistral_tokenizer_image_patch() apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
)
patch_kimi_model()
def _apply_fp8_patches(self): def _apply_fp8_patches(self):
"""Apply patches for FP8 support.""" """Apply patches for FP8 support."""
if self.cfg.fp8: if self.cfg.fp8:
@@ -517,3 +561,16 @@ class PatchManager:
) )
patch_apertus_xielu_activation() patch_apertus_xielu_activation()
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
if self.cfg.scaling_softmax:
from axolotl.monkeypatch.scaled_softmax_attn import (
patch_scaled_softmax_attention,
)
patch_scaled_softmax_attention(
scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,
bias=self.cfg.scaling_softmax_bias or 0.0,
model=model,
)

View File

@@ -124,6 +124,11 @@ def modify_tokenizer_files(
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config.""" """Load and configure the tokenizer based on the provided config."""
# Apply patches that need to be in place before tokenizer loading
from axolotl.loaders.patch_manager import PatchManager
PatchManager.apply_pre_tokenizer_load_patches(cfg)
def _load_mistral_common_tokenizer(cfg: DictDefault): def _load_mistral_common_tokenizer(cfg: DictDefault):
"""Load mistral-common tokenizer""" """Load mistral-common tokenizer"""
from axolotl.utils.mistral import HFMistralTokenizer from axolotl.utils.mistral import HFMistralTokenizer

View File

@@ -5,6 +5,7 @@ from typing import Type
import addict import addict
import torch import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -79,7 +80,11 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
and hasattr(model_config, "vision_config") and hasattr(model_config, "vision_config")
and hasattr(model_config.vision_config, "image_size") and hasattr(model_config.vision_config, "image_size")
): ):
cfg.image_size = model_config.vision_config.image_size image_size = model_config.vision_config.image_size
if isinstance(image_size, list):
cfg.image_size = tuple(image_size)
else:
cfg.image_size = image_size
LOG.debug(f"Loaded image size: {cfg.image_size} from model config") LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
quant_config_exists = ( quant_config_exists = (
@@ -149,6 +154,9 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config. necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -170,8 +178,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels: if cfg.num_labels:
# num_labels is used to initialize classifier models # num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try: try:
model_config = AutoConfig.from_pretrained( model_config = config_cls.from_pretrained(
model_config_name, model_config_name,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**config_kwargs, **config_kwargs,

View File

@@ -75,3 +75,33 @@ def patch_parallelism_config():
ParallelismConfig._validate_accelerator = _validate_accelerator ParallelismConfig._validate_accelerator = _validate_accelerator
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2) AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
def patch_prepare_cp():
import functools
import torch
from accelerate import Accelerator
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
return args
Accelerator._prepare_cp = patched_prepare_cp

View File

@@ -0,0 +1,98 @@
"""Dynamic Fine-Tuning (DFT) loss implementation"""
from typing import Optional
import torch
import torch.nn.functional as F
def selective_log_softmax(logits, index):
"""Memory-efficient log_softmax -> gather"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(
logits, dim=-1, index=index.unsqueeze(-1)
).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values
else:
per_token_logps = []
for row_logits, row_labels in zip(logits, index, strict=True):
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(
dim=-1, index=row_labels.unsqueeze(-1)
).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps
def get_dft_loss(ignore_index: int = -100):
"""Creates DFT loss function"""
def for_causal_lm_dft_loss(
logits,
labels,
vocab_size: int = None,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""DFT loss: -exp(logprobs).detach() * logprobs"""
if shift_labels is None:
# Shift so that tokens < n predict n
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(logits.device)
# Create loss mask
loss_mask = shift_labels != ignore_index
shift_labels_masked = shift_labels.clone()
shift_labels_masked[~loss_mask] = 0
# Compute log probabilities
logprobs = selective_log_softmax(logits, shift_labels_masked)
# DFT loss: -exp(logprobs).detach() * logprobs
per_token_loss = -logprobs.exp().detach() * logprobs
# Sum over valid tokens and normalize
if num_items_in_batch is None:
num_items_in_batch = loss_mask.sum()
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
return loss
return for_causal_lm_dft_loss
def dft_loss(outputs, labels, num_items_in_batch=None):
"""DFT loss compatible with Trainer.compute_loss_func signature.
This function is designed to be passed to Trainer's compute_loss_func parameter.
"""
ignore_index = -100
# Shift labels for causal LM
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(outputs.logits.device)
# Create loss mask
loss_mask = shift_labels != ignore_index
shift_labels_masked = shift_labels.clone()
shift_labels_masked[~loss_mask] = 0
# Compute log probabilities
logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
# DFT loss: -exp(logprobs).detach() * logprobs
per_token_loss = -logprobs.exp().detach() * logprobs
# Sum over valid tokens and normalize
if num_items_in_batch is None:
num_items_in_batch = loss_mask.sum()
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
return loss

View File

@@ -0,0 +1,148 @@
"""
Kimi-Linear configuration.
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/configuration_kimi.py
Revision: 6e163f3
"""
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class KimiLinearConfig(PretrainedConfig):
model_type = "kimi_linear"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
model_type="kimi_linear",
vocab_size=163840,
hidden_size=4096,
head_dim=None,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
rope_theta=10000.0,
rope_scaling=None,
tie_word_embeddings=False,
moe_intermediate_size: Optional[int] = None,
moe_renormalize: bool = True,
moe_router_activation_func: str = "sigmoid",
num_experts: Optional[int] = None,
num_experts_per_token: Optional[int] = None,
num_shared_experts: int = 0,
routed_scaling_factor: float = 1.0,
first_k_dense_replace: int = 0,
moe_layer_freq: int = 1,
use_grouped_topk: bool = True,
num_expert_group: int = 1,
topk_group: int = 1,
q_lora_rank: Optional[int] = None,
kv_lora_rank: Optional[int] = None,
qk_nope_head_dim: Optional[int] = None,
qk_rope_head_dim: Optional[int] = None,
v_head_dim: Optional[int] = None,
mla_use_nope: Optional[bool] = False,
num_nextn_predict_layers: int = 0,
linear_attn_config: Optional[dict] = None,
router_aux_loss_coef: float = 0.01,
**kwargs,
):
self.model_type = model_type
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.head_dim = (
head_dim if head_dim is not None else hidden_size // num_attention_heads
)
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.mla_use_nope = mla_use_nope
# moe config
self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
self.moe_renormalize = moe_renormalize
self.num_shared_experts = num_shared_experts
self.routed_scaling_factor = routed_scaling_factor
self.moe_router_activation_func = moe_router_activation_func
assert self.moe_router_activation_func in ("softmax", "sigmoid")
self.moe_intermediate_size = moe_intermediate_size
self.first_k_dense_replace = first_k_dense_replace
self.moe_layer_freq = moe_layer_freq
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.num_nextn_predict_layers = num_nextn_predict_layers
self.router_aux_loss_coef = router_aux_loss_coef
if linear_attn_config is not None:
assert linear_attn_config["kda_layers"] is not None
assert linear_attn_config["full_attn_layers"] is not None
self.linear_attn_config = linear_attn_config
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def is_mla(self):
return (
self.q_lora_rank is not None
or self.kv_lora_rank is not None
or self.qk_nope_head_dim is not None
or self.qk_rope_head_dim is not None
or self.v_head_dim is not None
or self.mla_use_nope is True
)
@property
def is_moe(self):
return self.num_experts is not None
@property
def is_linear_attn(self) -> bool:
return not (
self.linear_attn_config is None
or (
isinstance(self.linear_attn_config, dict)
and self.linear_attn_config["kda_layers"] is not None
and len(self.linear_attn_config["kda_layers"]) == 0
)
)
def is_kda_layer(self, layer_idx: int):
return (
self.linear_attn_config is not None
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
import importlib.resources
import importlib.util
import sys
from pathlib import Path
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
"""
Gets the absolute path to a patch file using importlib.resources.files.
"""
try:
return importlib.resources.files(package_dot_path) / filename
except ModuleNotFoundError:
return None
def _load_local_module(module_name: str, filename: str):
"""Helper to load a local module if not already loaded."""
if module_name in sys.modules:
return sys.modules[module_name]
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
if patch_path and patch_path.exists():
spec = importlib.util.spec_from_file_location(module_name, patch_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
return None
def _patch_get_class_in_module():
"""
Core patch function that hijacks Transformers' dynamic module loading.
"""
from transformers.dynamic_module_utils import get_class_in_module
if hasattr(get_class_in_module, "_axolotl_patched"):
return
original_get_class_in_module = get_class_in_module
# Mapping of module path patterns to (module_name, filename)
KIMI_MODULE_MAP = {
"configuration_kimi": ("configuration_kimi", "configuration_kimi.py"),
"modeling_kimi": ("modeling_kimi", "modeling_kimi.py"),
"tokenization_kimi": ("tokenization_kimi", "tokenization_kimi.py"),
}
def patched_get_class_in_module(class_name, module_path, **kwargs):
"""Patched version that returns our local modules instead of remote ones."""
for pattern, (module_name, filename) in KIMI_MODULE_MAP.items():
if pattern in module_path:
module = _load_local_module(module_name, filename)
if module:
return getattr(module, class_name)
break # Pattern matched but file not found, fall through
return original_get_class_in_module(class_name, module_path, **kwargs)
import transformers.dynamic_module_utils
transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
patched_get_class_in_module._axolotl_patched = True
def patch_kimi():
"""
Apply all Kimi patches.
Must be called BEFORE loading config/tokenizer/model.
"""
_patch_get_class_in_module()
LOG.info("Kimi patches applied successfully!")
# Keep these for backward compatibility if needed
patch_kimi_config = patch_kimi
patch_kimi_tokenizer = patch_kimi
patch_kimi_model = patch_kimi

View File

@@ -0,0 +1,357 @@
"""
Adapted Kimi-Linear tokenizer to use proper template defaults and misc fixes.
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/tokenization_kimi.py
Revision: 919416f
"""
import os
from logging import getLogger
from pathlib import Path
from shutil import copyfile
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
cast,
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from tokenizers import AddedToken
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
from transformers.tokenization_utils import PreTrainedTokenizer
logger = getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
class TikTokenTokenizer(PreTrainedTokenizer):
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
The path to the Tiktoken model file.
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
The end of sequence token.
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead. The second to last item in special_tokens.
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (list of `str`, *optional*):
A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
skipped when decoding if `skip_special_tokens` is set to `True`.
"""
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
special_tokens: Dict[str, int]
num_reserved_special_tokens = 256
pat_str = "|".join(
[
r"""[\p{Han}]+""",
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
r"""\p{N}{1,3}""",
r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
r"""\s*[\r\n]+""",
r"""\s+(?!\S)""",
r"""\s+""",
]
)
def __init__(
self,
vocab_file,
bos_token: Union[str, AddedToken] = "[BOS]", # nosec: B107
eos_token: Union[str, AddedToken] = "[EOS]", # nosec: B107
unk_token: Union[str, AddedToken, None] = None,
pad_token: Union[str, AddedToken, None] = None,
additional_special_tokens: List[str] = None,
added_tokens_decoder: Optional[dict] = None,
**kwargs,
):
assert os.path.isfile(vocab_file), vocab_file
if additional_special_tokens is None:
additional_special_tokens = [
"<|im_end|>",
"<|im_user|>",
"<|im_assistant|>",
"<|start_header_id|>",
"<|end_header_id|>",
"[EOT]",
"<|im_system|>",
"<|im_middle|>",
]
special_tokens_mapping = {
i: added_tokens_decoder[i].content for i in added_tokens_decoder
}
self.vocab_file = vocab_file
mergeable_ranks = load_tiktoken_bpe(vocab_file)
num_base_tokens = len(mergeable_ranks)
self.special_tokens = {
special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
for i in range(
num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
)
}
self.model = tiktoken.Encoding(
name=Path(vocab_file).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
logger.info(f"Reloaded tiktoken model from {vocab_file}")
self.n_words: int = self.model.n_vocab
# BOS / EOS token IDs
self.bos_id: int = self.special_tokens[str(bos_token)]
self.eos_id: int = self.special_tokens[str(eos_token)]
logger.info(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
)
self.pad_id: int = self.special_tokens[str(pad_token)]
self.unk_id: int = self.special_tokens[str(unk_token)]
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.decoder = {}
for i in range(self.n_words):
# Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
decoding = "".join(
[
self.byte_encoder[ord(char)]
for char in self.model.decode_single_token_bytes(i).decode(
"latin-1"
)
]
)
self.decoder[i] = decoding
self.encoder = {}
for i in range(self.n_words):
if i in self.decoder:
self.encoder[self.decoder[i]] = i
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.all_special_ids_set = set(self.all_special_ids)
def encode(
self, text: str, allow_special_tokens: bool = True, **kwargs
) -> List[int]:
"""
Encodes a string into a list of token IDs.
Args:
text (str): The input string to be encoded.
Returns:
list[int]: A list of token IDs.
"""
# If there are other args, we should call super().encode because there are a lot of code
# to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
# NOTE: our encode method is not compatible with the super().encode method,
# e.g. split_special_tokens' default is True in our encode method.
if len(kwargs) > 0:
# logger.warning(f"Calling super().encode with {kwargs}")
return super().encode(text, **kwargs)
assert type(text) is str
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000
texts = self.pre_tokenizer_process(text)
all_substrs = []
for text in texts:
substrs = (
substr
for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
for substr in self._split_whitespaces_or_nonwhitespaces(
text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
all_substrs.extend(substrs)
t: List[int] = []
for substr in all_substrs:
if allow_special_tokens:
t.extend(
# we should consider special token as a common token
self.model.encode(
substr,
allowed_special="all",
)
)
else:
t.extend(
# we should consider special token as a common token
self.model.encode(
substr,
disallowed_special=(),
)
)
return t
def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
"""
Decodes a list of token IDs into a string.
Args:
token_ids (List[int]): The list of token IDs to be decoded.
Returns:
str: The decoded string.
"""
# If there are other args, we should call super().decode because there are a lot of code
# to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
if len(kwargs) > 0:
return super().decode(token_ids, **kwargs)
if type(token_ids) is int:
token_ids = [token_ids]
return self.model.decode(cast(List[int], token_ids))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(
s: str, max_consecutive_slice_len: int
) -> Iterator[str]:
"""
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
consecutive whitespaces or consecutive non-whitespaces.
"""
current_slice_len = 0
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
slice_start = 0
for i in range(len(s)):
is_now_space = s[i].isspace()
if current_slice_is_space ^ is_now_space:
current_slice_len = 1
current_slice_is_space = is_now_space
else:
current_slice_len += 1
if current_slice_len > max_consecutive_slice_len:
yield s[slice_start:i]
slice_start = i
current_slice_len = 1
yield s[slice_start:]
def pre_tokenizer_process(self, text: str) -> List[str]:
"""
pre-tokenizes the input text into a list of tokens.
This method is used to split the input text into smaller chunks for internal processing.
"""
return [text]
""" ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
@property
def vocab_size(self) -> int:
return self.n_words
def get_vocab(self) -> Dict[str, int]:
return self.encoder
def _tokenize(self, text: str, **kwargs) -> List[str]:
return [self.decoder[t] for t in self.encode(text)]
def _convert_token_to_id(self, token: str) -> int:
return self.encoder.get(token, self.unk_id)
def _convert_id_to_token(self, index: int) -> str:
return self.decoder.get(index)
@staticmethod
def clean_up_tokenization(out_string: str) -> str:
return out_string
def convert_tokens_to_string(self, tokens: List[str]) -> str:
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode(
"utf-8", "replace"
)
return text
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
if not os.path.isdir(save_directory):
raise ValueError(
f"vocabulary path ({save_directory}) should be a directory"
)
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "")
+ VOCAB_FILES_NAMES["vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file
) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
def apply_chat_template(
self,
conversation,
tools: Optional[list[dict]] = None,
tokenize: bool = True,
add_generation_prompt: bool = False,
**kwargs,
):
tools = deep_sort_dict(tools)
return super().apply_chat_template(
conversation,
tools=tools,
tokenize=tokenize,
add_generation_prompt=add_generation_prompt,
**kwargs,
)
def deep_sort_dict(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
if isinstance(obj, list):
return [deep_sort_dict(item) for item in obj]
return obj

View File

@@ -0,0 +1,141 @@
"""
Scaled Softmax (SSMax) attention patch using FlexAttention.
SSMax: softmax(scores * s * log(n) + b) where n is the position index
Ref: https://arxiv.org/abs/2501.19399
"""
import torch
from transformers import PreTrainedModel
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from torch.nn.attention.flex_attention import BlockMask
from transformers.integrations.flex_attention import (
compile_friendly_flex_attention,
repeat_kv,
)
FLEX_ATTENTION_AVAILABLE = True
except ImportError:
FLEX_ATTENTION_AVAILABLE = False
BlockMask = None
_ssmax_config = {}
def patch_scaled_softmax_attention(
scaling_factor_init: float = 0.43, bias: float = 0.0, model: PreTrainedModel = None
):
"""Patch attention to apply SSMax via FlexAttention score_mod."""
global _ssmax_config
if not FLEX_ATTENTION_AVAILABLE:
raise RuntimeError("SSMax requires FlexAttention.")
_ssmax_config["ssmax_s"] = scaling_factor_init
_ssmax_config["ssmax_b"] = bias
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
if "flex_attention" in ALL_ATTENTION_FUNCTIONS:
_ssmax_config["original_flex_fn"] = ALL_ATTENTION_FUNCTIONS["flex_attention"]
ALL_ATTENTION_FUNCTIONS["flex_attention"] = ssmax_flex_attention_forward
LOG.info(
f"Patched flex_attention with SSMax (s={scaling_factor_init}, b={bias})"
)
else:
LOG.warning("flex_attention not found. Ensure flex_attention: true is set.")
def ssmax_flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask,
scaling: float | None = None,
softcap: float | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""FlexAttention forward with SSMax: score * (s * log(n) + b)."""
if kwargs.get("dropout", 0.0) > 0:
raise ValueError("flex_attention does not support dropout")
ssmax_s = _ssmax_config.get("ssmax_s", 0.43)
ssmax_b = _ssmax_config.get("ssmax_b", 0.0)
position_ids = kwargs.get("position_ids", None)
position_ids_flat = position_ids.view(-1) if position_ids is not None else None
block_mask = attention_mask if isinstance(attention_mask, BlockMask) else None
score_mask = None if block_mask else attention_mask
if score_mask is not None:
score_mask = score_mask[:, :, :, : key.shape[-2]]
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
"""
Apply SSMax scaling: score * (s * log(n) + b)
where n is the relative position within each packed sequence.
"""
if position_ids_flat is not None:
relative_pos = position_ids_flat[q_idx]
n = (relative_pos + 1).float()
else:
n = (q_idx + 1).float()
n = torch.clamp(n, min=2.0)
ssmax_scale = ssmax_s * torch.log(n) + ssmax_b
score = score * ssmax_scale
if softcap is not None:
score = softcap * torch.tanh(score / softcap)
if score_mask is not None:
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
return score
enable_gqa = True
if (query.shape[1] & (query.shape[1] - 1)) != 0:
key = repeat_kv(key, query.shape[1] // key.shape[1])
value = repeat_kv(value, query.shape[1] // value.shape[1])
enable_gqa = False
return_lse = query.device.type != "cpu"
flex_output = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kwargs.get("kernel_options"),
return_lse=return_lse,
training=module.training,
)
if return_lse:
attention_output, lse = flex_output
lse = lse.to(value.dtype)
else:
attention_output, lse = flex_output, None
return attention_output.transpose(1, 2).contiguous(), lse
def unpatch_scaled_softmax_attention():
"""Restore the original FlexAttention function."""
global _ssmax_config
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
if "original_flex_fn" in _ssmax_config:
ALL_ATTENTION_FUNCTIONS["flex_attention"] = _ssmax_config["original_flex_fn"]
_ssmax_config.clear()
LOG.info("Unpatched flex_attention, restored original")

View File

@@ -8,6 +8,7 @@ from PIL.Image import Resampling
from torch import Tensor, zeros_like from torch import Tensor, zeros_like
from transformers import ProcessorMixin from transformers import ProcessorMixin
from transformers.image_utils import load_image from transformers.image_utils import load_image
from transformers.models.internvl import InternVLProcessor
from transformers.models.smolvlm import SmolVLMProcessor from transformers.models.smolvlm import SmolVLMProcessor
from transformers.models.voxtral import VoxtralProcessor from transformers.models.voxtral import VoxtralProcessor
@@ -454,6 +455,37 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
return labels return labels
class InternVLProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for InternVL"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
if not hasattr(processor, "image_ids"):
raise ValueError("'image_ids' missing from InternVL Processor.")
self.image_token_ids = processor.image_ids
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.processor.tokenizer.pad_token_id] = -100
for ids in self.image_token_ids:
labels[labels == ids] = -100
# Note: Check if need to mask 'video_token' as it gets converted to
# image patches during media processing
return labels
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -501,6 +533,11 @@ def get_processing_strategy(
**processing_kwargs, **processing_kwargs,
) )
if isinstance(processor, InternVLProcessor):
return InternVLProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava # llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl # mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy( return ProcessingStrategy(

View File

@@ -24,6 +24,10 @@ def is_opentelemetry_available():
) )
def is_trackio_available():
return importlib.util.find_spec("trackio") is not None
def get_pytorch_version() -> tuple[int, int, int]: def get_pytorch_version() -> tuple[int, int, int]:
""" """
Get Pytorch version as a tuple of (major, minor, patch). Get Pytorch version as a tuple of (major, minor, patch).

View File

@@ -0,0 +1,248 @@
"""Callbacks for SwanLab integration"""
from __future__ import annotations
import json
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = get_logger(__name__)
class CustomSwanLabCallback(TrainerCallback):
"""
Lightweight SwanLab callback that directly logs metrics without using
SwanLab's transformers integration (which requires omegaconf).
This avoids the antlr4 version conflict between omegaconf and axolotl.
"""
def __init__(self):
self._initialized = False
self.swanlab = None
def setup(self):
"""Lazy initialization of SwanLab"""
if self._initialized:
return
try:
import swanlab
self.swanlab = swanlab
# Check if SwanLab run is initialized
if swanlab.get_run() is None:
LOG.warning("SwanLab run is not initialized")
return
self._initialized = True
LOG.info("CustomSwanLabCallback initialized successfully")
except ImportError:
LOG.error("SwanLab is not installed")
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the beginning of training"""
if not state.is_world_process_zero:
return control
self.setup()
if not self._initialized:
return control
# Log training configuration
try:
self.swanlab.config.update(
{
"train_batch_size": args.per_device_train_batch_size,
"eval_batch_size": args.per_device_eval_batch_size,
"learning_rate": args.learning_rate,
"num_train_epochs": args.num_train_epochs,
"max_steps": args.max_steps,
"warmup_steps": args.warmup_steps,
"logging_steps": args.logging_steps,
"save_steps": args.save_steps,
"gradient_accumulation_steps": args.gradient_accumulation_steps,
}
)
LOG.debug("Training configuration logged to SwanLab")
except Exception as err:
LOG.warning(f"Failed to log training config: {err}")
return control
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs=None,
**kwargs,
):
"""Called when logging metrics"""
if not state.is_world_process_zero:
return control
if not self._initialized:
self.setup()
if not self._initialized or logs is None:
return control
# Log metrics to SwanLab
try:
# Filter out non-numeric values and prepare for logging
metrics = {}
for key, value in logs.items():
if isinstance(value, (int, float)):
# Use step from state
metrics[key] = value
if metrics and state.global_step is not None:
self.swanlab.log(metrics, step=state.global_step)
except Exception as err:
LOG.warning(f"Failed to log metrics to SwanLab: {err}")
return control
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of training"""
if not state.is_world_process_zero:
return control
if self._initialized:
LOG.info("Training completed. SwanLab logs are available.")
return control
class SaveAxolotlConfigtoSwanLabCallback(TrainerCallback):
"""Callback to save axolotl config to SwanLab"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.is_world_process_zero:
try:
import swanlab
# Check if SwanLab is initialized
if swanlab.get_run() is None:
LOG.warning(
"SwanLab run is not initialized. Please initialize SwanLab before training."
)
return control
# Log Axolotl config as artifact
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
# Log config file to SwanLab
with open(temp_file.name, "r", encoding="utf-8") as config_file:
swanlab.log(
{
"axolotl_config": swanlab.Text(
config_file.read(), caption="Axolotl Config"
)
}
)
LOG.info(
"The Axolotl config has been saved to the SwanLab run under logs."
)
# Clean up temp file
os.unlink(temp_file.name)
except ImportError:
LOG.warning(
"SwanLab is not installed. Install it with: pip install swanlab"
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to SwanLab: {err}")
# Log DeepSpeed config if available
if args.deepspeed:
try:
import swanlab
with NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
prefix="deepspeed_config_",
) as temp_file:
skip_upload = False
if isinstance(args.deepspeed, dict):
json.dump(args.deepspeed, temp_file, indent=4)
elif isinstance(args.deepspeed, str) and os.path.exists(
args.deepspeed
):
copyfile(args.deepspeed, temp_file.name)
else:
skip_upload = True
if not skip_upload:
temp_file.flush()
with open(
temp_file.name, "r", encoding="utf-8"
) as ds_config_file:
swanlab.log(
{
"deepspeed_config": swanlab.Text(
ds_config_file.read(),
caption="DeepSpeed Config",
)
}
)
LOG.info(
"The DeepSpeed config has been saved to the SwanLab run under logs."
)
# Clean up temp file
os.unlink(temp_file.name)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(
f"Error while saving DeepSpeed config to SwanLab: {err}"
)
except ImportError:
pass
return control

Some files were not shown because too many files have changed in this diff Show More