Compare commits

..

59 Commits

Author SHA1 Message Date
Wing Lian
9c0fa60220 fsdp2 w evals fixed upstream 2025-08-11 16:26:42 -04:00
Wing Lian
8efdc59796 just assume that fa supports window 2025-08-11 16:09:11 -04:00
Wing Lian
172b08b209 integration check for transformers#40002 2025-08-11 10:06:11 -04:00
Wing Lian
3d45620008 remove prepare-from-posids patch (#3052) [skip ci] 2025-08-11 09:34:41 -04:00
github-actions[bot]
ce20e838b5 chore: update pre-commit hooks (#3050) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-08-11 09:32:21 -04:00
Wing Lian
d4d84d48af fix ray train and add fsdp2 smoke test for ray trainer (#3053)
* add fsdp2 smokle test for ray trainer

* fix raytrain with fsdp2
2025-08-11 09:31:54 -04:00
Wing Lian
9b12c05660 use exec instead of subprocess to make ctrl+c nicer for cli (#3044)
* use exec instead of subprocess to make ctrl+c nicer for cli

* change var name to use_exec

* simplify to bool

* flush std*

* patch subprocess as mock in test

* fix tests

* more test fixes
2025-08-10 20:22:20 -04:00
Wing Lian
686933194e fix vllm tagging and add cloud images w/o tmux (#3049) [skip ci] 2025-08-10 20:21:56 -04:00
Wing Lian
d12b461d19 follow up fix for plugin registration (#3054) [skip ci] 2025-08-10 20:21:38 -04:00
Wing Lian
d6b81b3683 update training args check for new defaults (#3051) [skip ci]
* update training args check for new defaults

* skip check for now
2025-08-10 11:26:22 -04:00
Wing Lian
05f1b4b2e8 run monkeypatch tests in seperate runner (#3047) 2025-08-09 14:34:07 -04:00
Wing Lian
7cfc80ec77 set dev version (#3045) [skip ci] 2025-08-08 13:56:53 -04:00
salman
0da6a95efa Add citation.tff (#3043) [skip ci] 2025-08-08 16:18:42 +01:00
Wing Lian
2c8497e489 tag for v0.12.0 release (#3041)
Some checks failed
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 126, 12.6.3, true, 3.11, 2.7.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, true, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 126, 12.6.3, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-08-08 08:24:09 -04:00
NanoCode012
f70d4de8c7 feat(doc): add links to new features on README (#2980) [skip ci]
* feat(doc): add links to new features on README

* fix merge error

* remove blurb about older FSDP2 integration

* update blog link

* chore: update cce commit

* feat: update model support into readme

* Update README.md

Co-authored-by: salman <salman.mohammadi@outlook.com>

* chore: lint num spaces

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-08-08 08:16:43 -04:00
Dan Saunders
0ae06d756d use nanmean for loss aggregation (CP fix) (#3033)
* use nanmena for loss aggregation (CP fix)

* use regular asserts

* small changes to make tests isolate

* combining evaluation_loop patches

* fix

* delete unused

* fix check
2025-08-08 08:15:17 -04:00
NanoCode012
2974670bf8 Feat: add arcee (#3028)
* feat: add arcee

* feat: add latest models supported by cce

* feat: add arcee example config

* chore: lint

* fix: typo

* feat: change to instruct

* feat: add vram usage

* Update README.md
2025-08-08 08:09:11 -04:00
Wing Lian
50f2b94d50 add 120b and deepspeed zero3 examples (#3035) [skip ci]
* add 120b and deepspeed zero3 examples

* add a bit of flavor and cleanup gpt oss readme

* fix: remove expert vram usage

* fix: remove redundant EOS token from eot_tokens

* feat: add 120B to docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-08 08:04:56 -04:00
Wing Lian
eb2c87b525 Example for Slurm and various fixes (#3038) [skip ci]
* slurm example and make preprocess play nicely

* start slurm if it init file exists

* remove incorrect comment

* feat: add slurm docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-08 08:02:03 -04:00
NanoCode012
4db7f023c6 feat(doc): standardize the axolotl install to a release (#3040) [skip ci] 2025-08-08 08:00:26 -04:00
NanoCode012
4273d5cf7e feat: update nd parallelism readme (#3039)
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-08-08 12:45:36 +01:00
Wing Lian
c5e5aba547 Add 2.8.0 base images and uv images (#3034) 2025-08-08 02:30:16 -04:00
Wing Lian
9d5c95db6f Add support for Accelerate CP, ND examples, and fix for parallel config w fsdp (#3019)
* fix for parallelism config from trainer

* fix handling of parallelism_config w accelerate

* add todo for removal

* update to latest axolotl-contribs-mit for optimizer fix too

* synchronize training after checkpoint save

* dir spelling

* use latest accelerate main

* fix to not use partial state parallelism_config

* more fixeS

* use most recent accelerate fix

* fix cpu_ram_efficient_loading to meta devices from rank 0 to prevent CPU RAM oom

* improve handling of broadcasting fsdp2 state dict

* support for openai chat template with thinking key as the reasoning trace

* address PR feedback

* refactor to remove dependency on PartialState for parallelism config

* bump accelerate, gptoss fixes

* limit meta fixes to fsdp2 for now

* fixes for gpt oss

* fixup examples, don't use cpu-ram-efficient-loading for now

* remove problematic barrier

* patch parallelism config

* reorder comparison

* device mesh fixes

* make pure CP work

* lint
2025-08-07 21:22:15 -04:00
NanoCode012
ca796fb56e feat(doc): update gpt-oss readme (#3029) [skip ci]
* feat(doc): update gpt-oss readme

* fix: caps

* feat: add toolcalling section

* feat: add example tool dataset to docs

* chore: update
2025-08-07 09:26:42 -04:00
VED
597953bef0 clear cache before clean up (#3031) [skip ci]
* clear chahe before save_model

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-07 09:25:58 -04:00
NanoCode012
39fbd3b2b5 fix: lora kernels for mistral3 (#3027) [skip ci] 2025-08-07 09:25:37 -04:00
salman
46dfacf255 ND Parallel Doc Nits (#3032) 2025-08-07 10:34:26 +01:00
Wing Lian
4bce713b39 allow custom trainer_cls to be defined as a module reference in the YAML (#3024) [skip ci]
* allow custom trainer_cls to be defined as a module reference in the YAML

* address PR feedback and add test

* add tests
2025-08-06 22:49:19 -04:00
Dan Saunders
d09290f2f4 Lora kernels bias support (#3025)
* lora kernels bias support

* revert rename

* nit

* lint, tests

* satisfying the rabbit
2025-08-06 20:20:08 -04:00
Wing Lian
e442ff22aa fix keyerror on load_in_8bit/load_in_4bit access in _set_quantization_config (#3023)
* set load_in_8bit/load_in_4bit in _set_quantization_config to prevent keyerror

* use dict.get instead
2025-08-06 14:28:52 -04:00
Wing Lian
ba3dba3e4f add kernels for gpt oss models (#3020)
* add kernels for gpt oss models

* add support for gpt-oss

* typo incorrect package

* fix: layout for configs and added wandb/epochs

* add gptoss example w offload and set moe leaf for z3

* add support for Mxfp4Config from yaml

* update yaml to use official model

* fix lora and don't allow triton to go above 3.3.1

* fix lr and tweak vram use

* fix range for triton since pinned wasn't compatible with toch 2.6.0

* update cce with gpt oss patches

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-06 09:47:55 -04:00
Wing Lian
97e86c6d47 drop old patches and code that are no longer needed (#3007) [skip ci] 2025-08-06 08:02:39 -04:00
VED
784f8c0e95 fix:kd_distillation key_error logprobs (#2990)
* fix:kd_distillation key_error logprobs

* style

* fix: leave handling of pop logprobs to parent

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-06 08:02:07 -04:00
NanoCode012
e3177c3210 feat: add complete optimizer docs (#3017) [skip ci]
* feat: add complete optimizer docs

* fix: deprecate old torchao adamw low bit
2025-08-06 08:01:51 -04:00
Wing Lian
70faea331f add support for connecting via prime-intellect (#3021) 2025-08-06 01:06:52 -04:00
Wing Lian
8021c718ce use skip_move_to_device for all cases (#3015)
* use skip_move_to_device for all cases

* use experimental option for skip move
2025-08-06 00:13:12 -04:00
Wing Lian
42f5e6f9e9 upgrade transformers==4.55.0 (#3018) 2025-08-05 16:29:12 -04:00
Wing Lian
ab49d16e34 Dion optimizer support (#3014)
* Add support for Dion optimizer

* dion training kwargs

* fix var names

* no dion 8bit for now

* use updated axolotl-contribs-mit for dion optimizer

* add smoke test for dion optimizer

* add docs

* fix typo during edits

* fix test to not remove load in 8bit
2025-08-04 16:33:30 -04:00
Carsten Kragelund Jørgensen
33d094721c fix: deepcopy lr in RexLR scheduler. (#3012)
* fix: deepcopy lr in RexLR scheduler.

This fixes a problem where when the lr is a scalar tensor, the base_lrs in the get_lr function end up being references to the current learning rate, rather than the correct initial learning rate.

See also related pytorch PR https://github.com/pytorch/pytorch/pull/127190/

* fix: add missing torch.Tensor import
2025-08-04 10:23:49 -04:00
NanoCode012
a54c1be972 Fix: shorten mem logs to 2 decimal places and renamed nd docs (#3011) [skip ci]
* fix: shorten memory logs

* fix: title name
2025-08-04 10:23:36 -04:00
github-actions[bot]
5691992d34 chore: update pre-commit hooks (#3009) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-08-04 10:23:19 -04:00
Dan Saunders
e758343cac FSDP2 + LoRA kernels (#2992)
* impl fix

* smoke tests

* patches for fsdp2 + qlora compat

* nit

* working fix

* working fix

* fix merge

* minifying patches; update bnb dep

* renaming; adding tests

* remove duplicate test, add dora guard

* generalize __torch_function__

* revert generalization

* update comments
2025-08-03 20:05:17 -04:00
Wing Lian
deac7b18a1 upgrade peft v0.17.0 and support for lora target_parameters (#3006) 2025-08-02 20:24:04 -04:00
Wing Lian
10946afae7 fixes for spinning up vllm service for grpo (#3001) 2025-08-02 11:19:24 -04:00
Wing Lian
5639552064 prevent usage of low bit ao optimizers with configurations that use parameter groups (#3003)
* prevent usage of low bit ao optimizers with configurations that use parameter groups

* use optimizer enum value

* fix validation
2025-08-01 17:54:04 -04:00
Wing Lian
cda3c82351 move ib/rdma libs into base image (#3002)
* move ib/rdma libs into base image

* use  --no-install-recommends
2025-08-01 16:10:37 -04:00
Wing Lian
7c3b428f23 Add validation for TP with models with tied embeddings (#2999)
* add validation for tp + tied embeddings models

* fix logic and messaging

* add additional guard for null tp size
2025-08-01 13:58:16 -04:00
Wing Lian
01a6bd1a0e use CCE fix for TP using vocab parallel for CEL (#3000) 2025-08-01 13:21:58 -04:00
NanoCode012
41709822a7 fix: move memory usage log to trainer.log (#2996) [skip ci] 2025-08-01 13:21:43 -04:00
Wing Lian
02a37199ee prevent empty value for vllm_mode (#2998) 2025-08-01 09:59:45 -04:00
NanoCode012
7026cd5e9e Feat: Add N-D parallelism docs (#2989)
* fix: remove non-existent file

* feat: add n-d parallel docs

* fix: comments

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-08-01 13:18:31 +07:00
NanoCode012
eb0a8a7775 feat: upgrade cce commit to include smollm3, granite, granitemoe (#2993) 2025-07-31 18:18:44 -04:00
salman
294c7fe7a6 Distributed/ND-Parallel (#2977) 2025-07-31 15:25:02 -04:00
Wing Lian
7b68dfafd7 jagged lr restart scheudler (#1680) [skip ci]
* jagged lr restart scheudler

var name fix
make sure to create scheduler first

* wire things together

* more fixes

* fix for nesting scheduler and first anneal phase

* no need for relora trainer anymore since we've generalized the relora scheduler

* remove redundant relora scheduler and lint

* update relora e2e test for updated params

* need restart steps for relora test

* update quarto docs for dropped relora trainer

* update example yaml

* drop verbose arg

* min lr scale support for jagged lr

* don't let min_lr be nonetype

* cleanup args
2025-07-31 13:50:03 -04:00
salman
32a7890231 Revert test update to index.qmd (#2995) [skip ci] 2025-07-31 11:46:31 -04:00
Wing Lian
563f5eed7a update dependencies - liger + trl (#2987)
* update dependencies

* set dataset processes for tests

* add support for GSPO
2025-07-31 11:17:17 -04:00
Wing Lian
6ec282094d actually call the register method on plugins (#2991) [skip ci] 2025-07-31 11:13:15 -04:00
salman
09dda462ab Fix don't preview docs for contributors (#2994) [skip ci]
* checking against fork vs. main repo

* force doc preview
2025-07-31 11:12:41 -04:00
Dan Saunders
bb1cae1a20 CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)
* add --launcher option; explicit True/False bool args; small cleanup

* refactor

* add torchrun, accelerate cli args

* add rdzv arg default + tests

* update _quarto

* coderabbit

* fix

* we can't set rdvz_id independently across nodes

* coderabbit

* fix tests
2025-07-30 15:46:56 -04:00
172 changed files with 4982 additions and 3490 deletions

View File

@@ -54,7 +54,7 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.6.3
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
@@ -64,9 +64,16 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base-nightly"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: nightly
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1
@@ -122,6 +129,13 @@ jobs:
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -129,6 +143,13 @@ jobs:
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -24,12 +24,13 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -97,6 +98,12 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
@@ -150,6 +157,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -53,7 +53,7 @@ jobs:
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
id: netlify
with:
publish-dir: './_site'
@@ -68,7 +68,7 @@ jobs:
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
- name: Update PR with preview link
if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
if: ${{ steps.netlify.outcome == 'success' }}
uses: marocchino/sticky-pull-request-comment@v2
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -105,7 +105,8 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
@@ -179,8 +180,8 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache

View File

@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
@@ -23,11 +23,11 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.7
rev: v3.3.8
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.0
rev: v1.17.1
hooks:
- id: mypy
additional_dependencies:

View File

@@ -185,7 +185,6 @@ datasets:
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |

View File

@@ -296,7 +296,6 @@
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -541,7 +540,6 @@ xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}

10
CITATION.cff Normal file
View File

@@ -0,0 +1,10 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
url: "https://axolotl.ai/"
license: Apache-2.0
date-released: "2023-05-30"

View File

@@ -25,17 +25,28 @@
## 🎉 Latest Updates
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
</details>
## ✨ Overview
Axolotl is a tool designed to streamline post-training for various AI models.
@@ -138,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📝 Citing Axolotl
If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},
year = {2023}
}
```
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

View File

@@ -35,25 +35,30 @@ quartodoc:
- cli.train
- cli.evaluate
- cli.args
- cli.art
- cli.checks
- cli.config
- cli.delinearize_llama4
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.quantize
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- cli.utils
- cli.utils.args
- cli.utils.fetch
- cli.utils.load
- cli.utils.sweeps
- cli.utils.train
- title: Trainers
desc: Training implementations
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.mamba
- core.trainers.relora
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
@@ -269,7 +274,7 @@ website:
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/gradient_accumulation.qmd
- docs/optimizers.qmd
- section: "Advanced Features"
contents:
@@ -279,6 +284,7 @@ website:
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
pytest -v --durations=10 -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -16,7 +16,10 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
&& apt-get install -y --no-install-recommends \
wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
ibverbs-providers ibverbs-utils infiniband-diags \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& wget \

View File

@@ -15,7 +15,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \

View File

@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.
### Launcher Arguments
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
```bash
# Pass torchrun arguments
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
# Pass accelerate arguments
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
```
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
## Command Reference
### fetch
@@ -80,7 +94,11 @@ axolotl train config.yml \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --no-accelerate
axolotl train config.yml --launcher python
# Pass launcher-specific arguments using -- separator
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
```bash
# Basic evaluation
axolotl evaluate config.yml
# Evaluation with launcher arguments
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
```
### lm-eval
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```

View File

@@ -212,10 +212,11 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
Example config for Llama4:
```yaml
chat_template: llama4
datasets:
- path: ...
- path: Nanobit/text-tools-2k-test
type: chat_template
# field_tools: tools # default is `tools`
```

View File

@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152
Run the following on each node:
### Option 1: New Axolotl CLI with launcher args (Recommended)
```bash
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
```
### Option 2: Direct torchrun (Legacy)
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables.
Please make sure to substitute the placeholder variables:
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

108
docs/nd_parallelism.qmd Normal file
View File

@@ -0,0 +1,108 @@
---
title: "N-D Parallelism (Beta)"
---
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
- A model's weights are too large to fit on a single GPU's memory.
- A model's activations, especially with very long contexts, are too large for a single GPU.
- You want to accelerate training by using multiple GPUs or nodes.
or combinations of the above!
## Core Concepts
Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
### Data Parallelism {#sec-dp}
Data Parallelism focuses on splitting the global data batch across GPUs.
- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.
- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`).
- FSDP maps to ZeRO stages:
- ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU.
- ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
### [Experimental] Tensor Parallelism (TP) {#sec-tp}
Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.
- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info.
- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
### Context Parallelism (CP) {#sec-cp}
Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.
- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.
- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this.
- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step.
### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}
HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.
- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.
- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).
- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes).
## Usage
```yaml
# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
fsdp_version: 2
fsdp_config:
# ...
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 4
# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 2
# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)
# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)
```
Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.
## Examples
::: {.callout-tip}
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
:::
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
- You want FSDP within each node and DDP across nodes.
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
2. FSDP + TP on a single 8-GPU node:
- You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
- Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.
3. FSDP + CP on a single 8-GPU node for long context:
- You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
- Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.
## Support Matrix
This matrix describes how different parallelism methods can be combined in Axolotl.
| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |
| --- | :---: | :---: |:---:|:---:|---|
| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
- `tp_size` refers to `tensor_parallel_size`
- `cp_size` refers to `context_parallel_size`

129
docs/optimizers.qmd Normal file
View File

@@ -0,0 +1,129 @@
---
title: Optimizers
description: Configuring optimizers
---
## Overview
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
Here is a list of optimizers supported by transformers as of `v4.54.0`:
- `adamw_torch`
- `adamw_torch_fused`
- `adamw_torch_xla`
- `adamw_torch_npu_fused`
- `adamw_apex_fused`
- `adafactor`
- `adamw_anyprecision`
- `adamw_torch_4bit`
- `adamw_torch_8bit`
- `ademamix`
- `sgd`
- `adagrad`
- `adamw_bnb_8bit`
- `adamw_8bit` # alias for adamw_bnb_8bit
- `ademamix_8bit`
- `lion_8bit`
- `lion_32bit`
- `paged_adamw_32bit`
- `paged_adamw_8bit`
- `paged_ademamix_32bit`
- `paged_ademamix_8bit`
- `paged_lion_32bit`
- `paged_lion_8bit`
- `rmsprop`
- `rmsprop_bnb`
- `rmsprop_bnb_8bit`
- `rmsprop_bnb_32bit`
- `galore_adamw`
- `galore_adamw_8bit`
- `galore_adafactor`
- `galore_adamw_layerwise`
- `galore_adamw_8bit_layerwise`
- `galore_adafactor_layerwise`
- `lomo`
- `adalomo`
- `grokadamw`
- `schedule_free_radam`
- `schedule_free_adamw`
- `schedule_free_sgd`
- `apollo_adamw`
- `apollo_adamw_layerwise`
- `stable_adamw`
## Custom Optimizers
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
### optimi_adamw
```yaml
optimizer: optimi_adamw
```
### ao_adamw_4bit
Deprecated: Please use `adamw_torch_4bit`.
### ao_adamw_8bit
Deprecated: Please use `adamw_torch_8bit`.
### ao_adamw_fp8
```yaml
optimizer: ao_adamw_fp8
```
### adopt_adamw
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
```yaml
optimizer: adopt_adamw
```
### came_pytorch
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
```yaml
optimizer: came_pytorch
# optional args (defaults below)
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999
adam_epsilon: 1e-30
adam_epsilon2: 1e-16
```
### muon
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
```yaml
optimizer: muon
```
### dion
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
Note: Implementation written for PyTorch 2.7+ for DTensor
```yaml
optimizer: dion
dion_lr: 0.01
dion_momentum: 0.95
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
```

View File

@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
context_parallel_size: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
@@ -66,7 +66,7 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -20,7 +20,7 @@ min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
sequence_parallel_degree: 8
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

53
examples/arcee/README.md Normal file
View File

@@ -0,0 +1,53 @@
# Finetune ArceeAI's AFM with Axolotl
[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:
```bash
axolotl train examples/arcee/afm-4.5b-qlora.yaml
```
This config uses about 7.8GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,64 @@
base_model: arcee-ai/AFM-4.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -47,7 +47,6 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
]
},
{

View File

@@ -10,17 +10,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run the finetuning example:

View File

@@ -0,0 +1,52 @@
# ND Parallelism Examples
This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
## Quick Start
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Run the command below:
```bash
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
```
## Example Configurations
### Single Node (8 GPUs)
**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
- Uses all 3 parallelism dimensions on a single node
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU
```yaml
dp_shard_size: 2 # FSDP across 2 GPUs
tensor_parallel_size: 2 # TP across 2 GPUs
context_parallel_size: 2 # CP across 2 GPUs
# Total: 2 × 2 × 2 = 8 GPUs
```
### Multi-Node
**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
- Ideal for: Scaling to multiple nodes while maintaining training efficiency
```yaml
dp_shard_size: 4 # FSDP within each 4-GPU group
tensor_parallel_size: 2 # TP within each node
dp_replicate_size: 2 # DDP across 2 groups
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
```
## Learn More
- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,47 @@
base_model: meta-llama/Llama-3.1-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 4
dp_replicate_size: 2
tensor_parallel_size: 2
# context_parallel_size: 2
dataset_prepared_path: last_run_prepared
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

View File

@@ -0,0 +1,46 @@
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 2
# dp_replicate_size: 1
context_parallel_size: 2
tensor_parallel_size: 2
dataset_prepared_path: last_run_prepared
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 8192
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:

View File

@@ -4,17 +4,14 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:

View File

@@ -0,0 +1,74 @@
# Finetune OpenAI's GPT-OSS with Axolotl
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
```bash
# LoRA SFT linear layers (1x48GB @ ~44GiB)
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
```
Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
### Training 120B
On 8xH100s
```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
### Tool use
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
Here is an example dataset config:
```yaml
datasets:
- path: Nanobit/text-tools-2k-test
type: chat_template
```
See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,67 @@
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
base_model: axolotl-ai-co/gpt-oss-120b-dequantized
use_kernels: false
dp_shard_size: 16 # requires 2x8xH100 nodes
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
cpu_ram_efficient_loading: true

View File

@@ -0,0 +1,58 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
# choose the zero3 configuration that best fits your system capabilities
deepspeed: deepspeed_configs/zero3_bf16.json

View File

@@ -0,0 +1,68 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true
# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`

View File

@@ -0,0 +1,64 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true

View File

@@ -0,0 +1,67 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

View File

@@ -45,7 +45,6 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -49,7 +49,6 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -25,9 +25,12 @@ lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
relora_steps: 150
relora_warmup_ratio: 0.1
relora: true
relora_prune_ratio: 0.9
relora_cpu_offload: false
jagged_restart_steps: 150
jagged_restart_warmup_steps: 10
jagged_restart_anneal_steps: false
wandb_project:
wandb_entity:

View File

@@ -8,17 +8,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run the finetuning example:

View File

@@ -27,7 +27,6 @@ sequence_len: 2048
sample_packing: true
eval_sample_packing: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -26,7 +26,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -26,7 +26,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

66
examples/slurm/README.md Normal file
View File

@@ -0,0 +1,66 @@
# SLURM Multi-Node Training
This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
## Prerequisites
- Access to a SLURM cluster with GPU nodes
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
## Usage
### Standard SLURM Clusters
1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
2. Place your Axolotl config file (`train.yaml`) in the same directory.
3. Set the appropriate environment variables for the job:
```bash
export HF_TOKEN="your-huggingface-token"
# metric tracking
# export WANDB_API_KEY="your-wandb-api-key"
# ...
```
4. Submit the job:
```bash
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
```
Where:
- `NUM_NODES`: Number of nodes to use
- `NUM_TRAINERS`: GPUs per node (typically 8)
- `PRIMARY_ADDR`: Hostname/IP of the master node
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
5. (Optional) Run other slurm commands:
```bash
# check job info
scontrol show job axolotl-cli
# check job queue
squeue
# check cluster status
sinfo
```
### RunPod Instant Clusters
Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.
1. **Deploy a SLURM Cluster**:
- Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
- Click "Create a Cluster"
- Choose your GPU type, node count, and region
- Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
- Deploy the cluster
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
## Additional Resources
- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)

View File

@@ -0,0 +1,20 @@
#!/bin/bash
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
# export HF_TOKEN="..."
# export WANDB_API_KEY="..."
#
# ---------- SBATCH commands ---------- #
#SBATCH --job-name=axolotl-slurm-multinode
#SBATCH --ntasks-per-node=1
#SBATCH --nodes=$NUM_NODES
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=128
export TORCH_DIST_INIT_BARRIER=0
srun axolotl preprocess train.yaml
srun axolotl train train.yaml --launcher torchrun -- \
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"

View File

@@ -6,17 +6,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Please install the below.

View File

@@ -1,30 +1,33 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.0
triton>=3.0.0
bitsandbytes==0.46.1
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.6.0
liger-kernel==0.6.1
# END section
packaging==23.2
huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.54.0
peft==0.17.0
transformers @ git+https://github.com/vasqu/transformers@fix-fa-integration
tokenizers>=0.21.1
accelerate==1.9.0
accelerate==1.10.0
datasets==4.0.0
deepspeed>=0.17.0
trl==0.19.1
trl==0.21.0
hf_xet==1.1.5
kernels==0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.23.3
gradio==5.41.1
modal==1.0.2
pydantic==2.10.6
@@ -66,6 +69,6 @@ torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
axolotl-contribs-mit==0.0.5
mistral-common==1.8.3

View File

@@ -44,8 +44,13 @@ add_keys_to_authorized() {
chmod 700 -R ~/.ssh
}
# Set SSH port
if [ ! -z "$SSH_PORT" ]; then
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
fi
if [[ $PUBLIC_KEY ]]; then
# runpod
# runpod, prime intellect
add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start
@@ -76,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# start the runpod slurm init
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
if [[ -f "$SLURM_INIT" ]]; then
echo "[entrypoint] running $SLURM_INIT..."
bash "$SLURM_INIT"
fi
# Execute the passed arguments (CMD)
exec "$@"

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
)

View File

@@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.12.0.dev"
__version__ = "0.13.0.dev"

View File

@@ -30,8 +30,6 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=0)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
@dataclass

View File

@@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
from typing import Literal
import yaml
@@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
@@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str],
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
@@ -31,9 +31,10 @@ def do_cli_preprocess(
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cloud_config: Path | str,
config: Path | str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
cwd=None,
**kwargs,
) -> None:
@@ -44,12 +45,18 @@ def do_cli_train(
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
cloud.train(
config_yaml,
launcher=launcher,
launcher_args=launcher_args,
local_dirs=local_dirs,
**kwargs,
)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str],
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)

View File

@@ -3,6 +3,7 @@ base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
from typing import Literal
class Cloud(ABC):
@@ -15,5 +16,12 @@ class Cloud(ABC):
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
**kwargs,
):
pass

View File

@@ -8,7 +8,7 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Optional
from typing import Literal
import modal
@@ -230,8 +230,9 @@ class ModalCloud(Cloud):
def train(
self,
config_yaml: str,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
@@ -239,7 +240,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
launcher=launcher,
launcher_args=launcher_args,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
@@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _train(
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
volumes=None,
**kwargs, # pylint: disable=unused-argument
):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
if accelerate:
accelerate_args = "--accelerate"
launcher_args = launcher_args or []
# Build the base command
if launcher == "accelerate":
launcher_arg = "--launcher accelerate"
elif launcher == "torchrun":
launcher_arg = "--launcher torchrun"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
launcher_arg = "--launcher python"
# Build launcher args string
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
run_cmd(
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
run_folder,
volumes,
)

View File

@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
# now that we have the finalized cfg, register the plugins individually
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def load_cfg(
@@ -200,14 +199,13 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
for key, value in kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[k] = kwargs[k]
cfg[key] = value
try:
device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -9,7 +9,6 @@ from typing import Generator, Union
import fire
import torch
from accelerate import init_empty_weights
from dotenv import load_dotenv
from transformers import AutoProcessor
@@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
@@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# pylint: disable=duplicate-code
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
@@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,7 +9,6 @@ from typing import Union
import fire
import torch
import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
@@ -268,5 +267,4 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -4,12 +4,9 @@
import os
import subprocess # nosec B404
import tempfile
from pathlib import Path
from typing import Optional
from typing import Literal, Optional
import click
import yaml
from dotenv import load_dotenv
import axolotl
@@ -21,13 +18,14 @@ from axolotl.cli.args import (
VllmServeCliArgs,
)
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
filter_none_kwargs,
generate_config_files,
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
@@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
LAUNCHER_COMMAND_MAPPING = {
"accelerate": ["accelerate", "launch"],
"torchrun": ["torchrun"],
}
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
patch_optimized_env()
@cli.command()
@@ -50,7 +55,7 @@ def cli():
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
"""
Preprocess datasets before training.
@@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
do_cli(config=config, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
@@ -88,126 +95,82 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def train(
ctx: click.Context,
config: str,
accelerate: bool,
cloud: Optional[str] = None,
sweep: Optional[str] = None,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
**kwargs,
) -> None:
):
"""
Train or fine-tune a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher
# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")
else:
def iter_configs():
yield config
for cfg_file in iter_configs():
# handle errors from subprocess so we can continue rest of sweeps
# Process each configuration
for cfg_file, is_group in generate_config_files(config, sweep):
try:
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
config=config,
accelerate=True,
cwd=cwd,
**kwargs,
)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
finally:
# Only delete temp files, not the original config
if cfg_file != config:
os.unlink(cfg_file)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU evaluation",
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
@click.pass_context
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
"""
Evaluate a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.evaluate"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -218,30 +181,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
do_cli(config=config, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU inference",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU inference",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
@click.pass_context
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
"""
Run inference with a trained model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.inference"]
)
if config:
base_cmd.append(config)
if gradio:
@@ -254,33 +229,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for weight merging",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
@click.pass_context
def merge_sharded_fsdp_weights(
ctx: click.Context, config: str, launcher: str, **kwargs
):
"""
Merge sharded FSDP model weights.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -296,7 +280,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs) -> None:
def merge_lora(config: str, **kwargs):
"""
Merge trained LoRA adapters into a base model.
@@ -313,7 +297,7 @@ def merge_lora(config: str, **kwargs) -> None:
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]) -> None:
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.
@@ -351,7 +335,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str) -> None:
def delinearize_llama4(model: str, output: str):
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
do_delinearize_llama4(model, output)
@@ -365,5 +349,4 @@ def main():
if __name__ == "__main__":
load_dotenv()
main()

View File

@@ -4,7 +4,6 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
@@ -70,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
context_parallel_size=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,
@@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -17,7 +17,6 @@ from accelerate.utils import (
WEIGHTS_NAME,
is_torch_version,
)
from dotenv import load_dotenv
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
@@ -204,5 +203,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,7 +9,6 @@ import fire
import transformers
from accelerate import init_empty_weights
from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli.args import PreprocessCliArgs
@@ -109,5 +108,4 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -7,7 +7,6 @@ from typing import Union
import fire
from accelerate import Accelerator
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
@@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,330 +0,0 @@
"""Utility methods for axolotl CLI."""
import concurrent.futures
import dataclasses
import hashlib
import json
from functools import wraps
from pathlib import Path
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
Args:
config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = strip_optional_type(field.type)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name,
default=field.default,
help=field.metadata.get("description"),
)(function)
else:
option_name = f"--{field.name.replace('_', '-')}"
function = click.option(
option_name,
type=field_type,
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
Args:
config_class: PyDantic model with fields to parse from the CLI
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
field_type = strip_optional_type(field.annotation)
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name, default=None, help=field.description
)(function)
else:
option_name = f"--{name.replace('_', '-')}"
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
"""
Build command list from base command and options.
Args:
base_cmd: Command without options.
options: Options to parse and append to base command.
Returns:
List of strings giving shell command.
"""
cmd = base_cmd.copy()
for key, value in options.items():
if value is None:
continue
key = key.replace("_", "-")
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.extend([f"--{key}", str(value)])
return cmd
def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files.
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]
# Check if file exists and needs updating
if dest_file.exists():
with open(dest_file, "rb") as file:
content = file.read()
# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
return file_path, "unchanged"
print(f"Updating {file_path}")
status = "new"
else:
print(f"Downloading {file_path}")
status = "new"
# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Download and save file
try:
response = requests.get(raw_url, timeout=30)
response.raise_for_status()
with open(dest_file, "wb") as file:
file.write(response.content)
return file_path, status
except (requests.RequestException, IOError) as request_error:
print(f"Error downloading {file_path}: {str(request_error)}")
return file_path, "error"
def fetch_from_github(
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
'deepspeed_configs/').
dest_dir: Local destination directory.
max_workers: Maximum number of concurrent downloads.
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)
# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}
if not files:
raise click.ClickException(f"No files found in {dir_prefix}")
# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: dict[str, list[str]] = {
"new": [],
"updated": [],
"unchanged": [],
"error": [],
}
# Process files in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {
executor.submit(
download_file,
(file_path, remote_sha),
raw_base_url,
dest_path,
dir_prefix,
): file_path
for file_path, remote_sha in files.items()
}
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_file):
file_path = future_to_file[future]
try:
file_path, status = future.result()
files_processed[status].append(file_path)
except (requests.RequestException, IOError) as request_error:
print(f"Error processing {file_path}: {str(request_error)}")
files_processed["error"].append(file_path)
# Log summary
LOG.info("\nSync Summary:")
LOG.info(f"New files: {len(files_processed['new'])}")
LOG.info(f"Updated files: {len(files_processed['updated'])}")
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[
PreTrainedModel,
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
model, _ = model_loader.load()
processor = None
if cfg.is_multimodal:
LOG.info("loading processor...")
processor = load_processor(cfg, tokenizer)
return model, tokenizer, processor

View File

@@ -0,0 +1,23 @@
"""Init for axolotl.cli.utils module."""
from .args import (
add_options_from_config,
add_options_from_dataclass,
filter_none_kwargs,
)
from .fetch import fetch_from_github
from .load import load_model_and_tokenizer
from .sweeps import generate_sweep_configs
from .train import build_command, generate_config_files, launch_training
__all__ = [
"filter_none_kwargs",
"add_options_from_dataclass",
"add_options_from_config",
"build_command",
"generate_config_files",
"generate_sweep_configs",
"load_model_and_tokenizer",
"launch_training",
"fetch_from_github",
]

View File

@@ -0,0 +1,120 @@
"""Utilities for axolotl CLI args."""
import dataclasses
from functools import wraps
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
from pydantic import BaseModel
def _strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
Args:
config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = _strip_optional_type(field.type)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name,
default=field.default,
help=field.metadata.get("description"),
)(function)
else:
option_name = f"--{field.name.replace('_', '-')}"
function = click.option(
option_name,
type=field_type,
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
Args:
config_class: PyDantic model with fields to parse from the CLI
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name, default=None, help=field.description
)(function)
else:
option_name = f"--{name.replace('_', '-')}"
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator

View File

@@ -0,0 +1,142 @@
"""Utilities for axolotl fetch CLI command."""
import concurrent.futures
import hashlib
import json
from pathlib import Path
import click
import requests
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files.
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]
# Check if file exists and needs updating
if dest_file.exists():
with open(dest_file, "rb") as file:
content = file.read()
# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
return file_path, "unchanged"
print(f"Updating {file_path}")
status = "updated"
else:
print(f"Downloading {file_path}")
status = "new"
# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Download and save file
try:
response = requests.get(raw_url, timeout=30)
response.raise_for_status()
with open(dest_file, "wb") as file:
file.write(response.content)
return file_path, status
except (requests.RequestException, IOError) as request_error:
print(f"Error downloading {file_path}: {str(request_error)}")
return file_path, "error"
def fetch_from_github(
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
'deepspeed_configs/').
dest_dir: Local destination directory.
max_workers: Maximum number of concurrent downloads.
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)
# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}
if not files:
raise click.ClickException(f"No files found in {dir_prefix}")
# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: dict[str, list[str]] = {
"new": [],
"updated": [],
"unchanged": [],
"error": [],
}
# Process files in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {
executor.submit(
_download_file,
(file_path, remote_sha),
raw_base_url,
dest_path,
dir_prefix,
): file_path
for file_path, remote_sha in files.items()
}
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_file):
file_path = future_to_file[future]
try:
file_path, status = future.result()
files_processed[status].append(file_path)
except (requests.RequestException, IOError) as request_error:
print(f"Error processing {file_path}: {str(request_error)}")
files_processed["error"].append(file_path)
# Log summary
LOG.info("\nSync Summary:")
LOG.info(f"New files: {len(files_processed['new'])}")
LOG.info(f"Updated files: {len(files_processed['updated'])}")
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")

View File

@@ -0,0 +1,52 @@
"""Utilities for model, tokenizer, etc. loading."""
from typing import Any
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[
PreTrainedModel,
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the
given `axolotl` config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
model, _ = model_loader.load()
processor = None
if cfg.is_multimodal:
LOG.info("loading processor...")
processor = load_processor(cfg, tokenizer)
return model, tokenizer, processor

View File

@@ -0,0 +1,222 @@
"""Utilities for axolotl train CLI command."""
import os
import subprocess # nosec
import sys
import tempfile
from typing import Any, Iterator, Literal
import yaml
from axolotl.cli.utils.sweeps import generate_sweep_configs
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
"""
Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.
Args:
launcher_args: List of launcher arguments
Returns:
Updated launcher args with defaults added if needed
"""
args = launcher_args.copy()
# Check if rdzv_endpoint is present
has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)
if has_rdzv_endpoint:
# Check if rdzv_backend is already provided
has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
if not has_rdzv_backend:
args.extend(["--rdzv_backend", "c10d"])
# Check if rdzv_id is already provided
has_rdzv_id = any("--rdzv_id" in arg for arg in args)
if not has_rdzv_id:
import uuid
args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])
return args
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
"""
Build command list from base command and options.
Args:
base_cmd: Command without options.
options: Options to parse and append to base command.
Returns:
List of strings giving shell command.
"""
cmd = base_cmd.copy()
for key, value in options.items():
if value is None:
continue
key = key.replace("_", "-")
cmd.append(f"--{key}={value}")
return cmd
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
"""
Generate list of configuration files to process.
Args:
config: Base configuration file
sweep: Sweep configuration file
Yields:
Tuple of configuration file name and whether this is a group of configurations
"""
if not sweep:
yield config, False
return
# Load sweep and base configurations
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
for permutation in permutations:
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",
suffix=".yaml",
delete=False,
encoding="utf-8",
)
yaml.dump(permutation, temp_file)
temp_file.close()
yield temp_file.name, is_group
def launch_training(
cfg_file: str,
launcher: Literal["accelerate", "torchrun", "python"] | None,
cloud: str | None,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training with the given configuration."""
launcher_args = launcher_args or []
if cloud:
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
elif launcher:
if launcher == "accelerate":
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "torchrun":
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "python":
_launch_python_training(cfg_file, kwargs)
elif launcher is None:
# handle ray train launch
_launch_python_training(cfg_file, kwargs)
def _launch_cloud_training(
cloud: str,
cfg_file: str,
launcher: Literal["accelerate", "torchrun", "python"] | None,
kwargs: dict,
launcher_args: list[str] | None = None,
) -> None:
"""Execute training via cloud launcher."""
from axolotl.cli.cloud import do_cli_train
launcher_args = launcher_args or []
cwd = os.getcwd() if launcher else None
do_cli_train(
cloud_config=cloud,
config=cfg_file,
launcher=launcher or "accelerate",
launcher_args=launcher_args,
cwd=cwd,
**kwargs,
)
def _launch_accelerate_training(
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via accelerate launcher."""
launcher_args = launcher_args or []
internal_launcher_args = []
# Extract launcher-specific arguments from kwargs (legacy support)
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port")
internal_launcher_args.extend(["--main_process_port", str(main_process_port)])
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes")
internal_launcher_args.extend(["--num_processes", str(num_processes)])
# Combine internal args with user-provided launcher args
all_launcher_args = internal_launcher_args + launcher_args
base_cmd = (
["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
)
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_torchrun_training(
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via torchrun launcher."""
launcher_args = launcher_args or []
# Add default RDZV arguments if rdzv_endpoint is set
launcher_args = _add_default_rdzv_args(launcher_args)
base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
"""Execute training via python launcher."""
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)

View File

@@ -2,12 +2,10 @@
CLI to start the vllm server for online RL
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg
@@ -42,13 +40,17 @@ def do_vllm_serve(
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = 1
data_parallel_size = 1
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
@@ -81,63 +83,3 @@ def do_vllm_serve(
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker

View File

@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"gpt_oss": "GptOssDecoderLayer",
}

View File

@@ -27,18 +27,18 @@ import torch
from transformers import (
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
@@ -139,8 +139,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
@@ -268,27 +266,24 @@ class TrainerBuilderBase(abc.ABC):
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
optimizer_cls = DionOptimizerFactory
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
_, device_mesh = build_parallelism_config(self.cfg)
if device_mesh is not None:
optimizer_kwargs["device_mesh"] = device_mesh
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
@@ -435,7 +430,11 @@ class TrainerBuilderBase(abc.ABC):
def _configure_accelerator_config(self, training_args_kwargs: dict):
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
**self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True:
@@ -495,10 +494,20 @@ class TrainerBuilderBase(abc.ABC):
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
arg_map = {
"dion_learning_rate": "dion_lr",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False

View File

@@ -19,7 +19,6 @@ from axolotl.core.trainers import (
AxolotlPRMTrainer,
AxolotlRewardTrainer,
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -44,6 +43,7 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -58,7 +58,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
if self.cfg.relora_steps:
if self.cfg.relora:
callbacks.append(ReLoRACallback(self.cfg))
if (
@@ -131,14 +131,24 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
return trainer_cls
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return AxolotlTrainer
def build(self, total_num_steps):
@@ -271,20 +281,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora and self.cfg.jagged_restart_steps:
if self.cfg.relora_prune_ratio:
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.jagged_restart_steps:
training_arguments_kwargs["jagged_restart_steps"] = (
self.cfg.jagged_restart_steps
)
if self.cfg.jagged_restart_warmup_steps:
training_arguments_kwargs["jagged_restart_warmup_steps"] = (
self.cfg.jagged_restart_warmup_steps
)
if self.cfg.jagged_restart_anneal_steps:
training_arguments_kwargs["jagged_restart_anneal_steps"] = (
self.cfg.jagged_restart_anneal_steps
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs["lisa_step_interval"] = (
@@ -348,7 +363,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
else:
elif self.cfg.pad_to_sequence_len is None:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple

View File

@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -53,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return trainer_cls, trainer_cls_args
def _build_training_arguments(self, total_num_steps):

View File

@@ -7,7 +7,6 @@ from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,

View File

@@ -10,8 +10,11 @@ from functools import partial, wraps
from typing import Any, Callable, Literal, Optional
import datasets
import safetensors
import torch
from accelerate.state import AcceleratorState
from datasets import Dataset
from peft import PeftModel
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -19,14 +22,17 @@ from torch.utils.data import (
Sampler,
SequentialSampler,
)
from transformers import Trainer
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -37,6 +43,8 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -50,6 +58,7 @@ class AxolotlTrainer(
RngLoaderMixin,
CheckpointSaveMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""
@@ -511,7 +520,18 @@ class AxolotlTrainer(
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
res = super().create_accelerator_and_postprocess()
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state:
AcceleratorState._reset_state( # pylint: disable=protected-access
reset_partial_state=True
)
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
@@ -520,8 +540,6 @@ class AxolotlTrainer(
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
return res
# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
@@ -558,6 +576,17 @@ class AxolotlTrainer(
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
if is_main_process():
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
@@ -575,3 +604,64 @@ class AxolotlTrainer(
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
else (PreTrainedModel, PeftModel)
)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
):
self.accelerator.unwrap_model(
self.model, keep_torch_compile=False
).save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -8,7 +8,11 @@ import torch
from torch import nn
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DPOTrainer,
DistributedParallelMixin,
):
"""Extend the base DPOTrainer for axolotl helpers."""

View File

@@ -49,7 +49,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode:
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
@@ -82,8 +83,13 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if cfg.context_parallel_size > 1:
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
if trl.importance_sampling_level is not None:
grpo_args_kwargs["importance_sampling_level"] = (
trl.importance_sampling_level
)
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None
context_parallel_size: int | None = None

View File

@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
- Data is properly distributed across SP groups.
In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
`context_parallel_size = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process.
batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group.
context_parallel_size: Number of ranks in a sequence parallel group.
shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
context_parallel_size: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
self.context_parallel_size = context_parallel_size
self.num_sp_groups = world_size // context_parallel_size
self.sp_group_id = rank // context_parallel_size
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)

View File

@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
@@ -53,7 +57,12 @@ if is_peft_available():
class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
GRPOTrainer,
):
"""Extend the base GRPOTrainer for axolotl helpers"""
@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
num_sp_groups = num_processes // self.args.context_parallel_size
# Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
rank=self.rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
// self.args.context_parallel_size,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree,
context_parallel_size=self.args.context_parallel_size,
shuffle=True,
seed=self.args.seed,
drop_last=True,
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
return dataloader
# Otherwise prepare with accelerator
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate sequence parallel group information
world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree
num_sp_groups = world_size // sequence_parallel_degree
context_parallel_size = self.args.context_parallel_size
num_sp_groups = world_size // context_parallel_size
# Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group
ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree
group_leader_rank = sp_group_id * context_parallel_size
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree
:: self.num_generations * self.args.context_parallel_size
]
with profiling_context(self, "vLLM.generate"):
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
len(all_prompts_text) // self.args.context_parallel_size
)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size

View File

@@ -5,6 +5,7 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""

View File

@@ -5,6 +5,7 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin

View File

@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
def _save_optimizer_and_scheduler(self, output_dir):
try:
super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc:
LOG.warning(
except (NotImplementedError, KeyError) as exc:
# TODO: fix fsdp2 optimizer saving
LOG.warning_once(
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible."
"for this training run will not be possible.",
main_process_only=True,
)

View File

@@ -0,0 +1,33 @@
"""
Mixin for correctly saving fsdp
"""
from accelerate import PartialState
from transformers import Trainer
class DistributedParallelMixin(Trainer):
"""
Mixin for correctly saving fsdp
"""
def _save(self, output_dir: str | None = None, state_dict=None):
if (
state_dict is None
and self.accelerator.parallelism_config
and self.accelerator.parallelism_config.dp_shard_enabled
):
state_dict = self.accelerator.get_state_dict(self.model)
super()._save(output_dir, state_dict=state_dict)
def create_accelerator_and_postprocess(self):
super().create_accelerator_and_postprocess()
if (
self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None
):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
PartialState().distributed_type = "MULTI_GPU"

View File

@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
JaggedLRRestartScheduler,
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
@@ -113,7 +114,7 @@ class SchedulerMixin(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning(
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
LOG.warning(
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
if self.args.jagged_restart_steps:
warmup_steps = (
self.args.jagged_restart_warmup_steps or 10
)
anneal_steps = (
self.args.jagged_restart_anneal_steps or 1
)
if not self.lr_scheduler:
super().create_scheduler(num_training_steps, optimizer)
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,
warmup_steps,
anneal_steps,
min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
)
return self.lr_scheduler # type: ignore

View File

@@ -1,46 +0,0 @@
"""Module for ReLoRA trainer"""
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
class ReLoRATrainer(AxolotlTrainer):
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
) -> LRScheduler:
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler: LRScheduler = super().create_scheduler(
num_training_steps, optimizer
)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler( # type: ignore
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler # type: ignore
return self.lr_scheduler # type: ignore

View File

@@ -8,13 +8,18 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
ORPOTrainer,
):
"""
Extend the base ORPOTrainer for axolotl helpers
@@ -24,7 +29,12 @@ class AxolotlORPOTrainer(
class AxolotlKTOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
KTOTrainer,
):
"""
Extend the base KTOTrainer for axolotl helpers
@@ -34,7 +44,12 @@ class AxolotlKTOTrainer(
class AxolotlCPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
CPOTrainer,
):
"""
Extend the base CPOTrainer for axolotl helpers
@@ -44,7 +59,12 @@ class AxolotlCPOTrainer(
class AxolotlRewardTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
RewardTrainer,
):
"""
Extend the base RewardTrainer for axolotl helpers
@@ -54,7 +74,12 @@ class AxolotlRewardTrainer(
class AxolotlPRMTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
PRMTrainer,
):
"""
Extend the base trl.PRMTrainer for axolotl helpers

View File

@@ -82,18 +82,26 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
jagged_restart_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for jagged restarts"},
)
jagged_restart_warmup_steps: Optional[int] = field(
default=None,
metadata={
"help": "how many warmup steps to take after reset for jagged restarts"
},
)
jagged_restart_anneal_steps: Optional[int] = field(
default=None,
metadata={
"help": "how many anneal steps to take before reset for jagged restarts"
},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
@@ -235,3 +243,18 @@ class AxolotlTrainingMixins:
)
# end of multi-modal section
dion_learning_rate: float | None = field(
default=None,
metadata={"help": "The learning rate for Dion"},
)
dion_momentum: float | None = field(
default=None,
metadata={"help": "The momentum for Dion"},
)
dion_rank_fraction: float | None = field(
default=None,
)
dion_rank_multiple_of: int | None = field(
default=None,
)

View File

@@ -26,9 +26,11 @@ import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -74,8 +76,8 @@ class BasePlugin:
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration.
def register(self, cfg: dict): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration as an unparsed dict.
Args:
cfg: The configuration for the plugin.
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
self, opt_model, training_args, **optimizer_kwargs
) -> Optimizer | None:
pass
# duplicated from transformers
def get_decay_parameter_names(self, model) -> list[str]:
"""
Get all parameter names that weight decay will be applied to.
This function filters out parameters in two ways:
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
2. By parameter name patterns (containing 'bias', or variation of 'norm')
"""
forbidden_name_patterns = [
r"bias",
r"layernorm",
r"rmsnorm",
r"(?:^|\.)norm(?:$|\.)",
r"_norm(?:$|\.)",
]
decay_parameters = get_parameter_names(
model, [nn.LayerNorm], forbidden_name_patterns
)
return decay_parameters

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
```
## Usage
@@ -31,6 +31,7 @@ plugins:
## Supported Models
- arcee
- cohere
- cohere2
- gemma
@@ -41,11 +42,17 @@ plugins:
- gemma3n_text
- glm
- glm4
- gpt_oss
- granite
- granitemoe
- hunyuan_v1_dense
- hunyuan_v1_moe
- llama
- llama4
- llama4_text
- mistral
- mistral3
- mixtral
- mllama
- phi
- phi3
@@ -56,6 +63,8 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
- smollm3
- voxtral
## Citation

View File

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
)

View File

@@ -284,12 +284,12 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
return sample
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
target_token_ids = prompt.pop("target_token_ids")
target_token_ids = prompt.get("target_token_ids", None)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
if target_token_ids is not None:
tokenized_prompt["target_token_ids"] = target_token_ids
return tokenized_prompt

View File

@@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
# pylint: disable=too-many-ancestors
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)

View File

@@ -16,8 +16,6 @@
Module for handling LIGER input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
Input args for LIGER.
"""
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
@model_validator(mode="before")
@classmethod
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
)
return data
@model_validator(mode="before")
@classmethod
def check_liger_rms_norm_tensor_parallel(cls, data):
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
raise ValueError(
"`liger_rms_norm` is incompatible with tensor parallelism, "
"see https://github.com/linkedin/Liger-Kernel/issues/826"
)
return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate
if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
raise ValueError("Tensor parallelism is not compatible with liger losses.")
return self

View File

@@ -1,11 +0,0 @@
"""
Axolotl custom modeling module
"""
from .args import AxolotlModelingArgs
from .plugin import AxolotlModelingPlugin
__all__ = [
"AxolotlModelingArgs",
"AxolotlModelingPlugin",
]

View File

@@ -1,13 +0,0 @@
"""
Args for using Axolotl custom modeling
"""
from pydantic import BaseModel
class AxolotlModelingArgs(BaseModel):
"""
Args for using Axolotl custom modeling
"""
use_liger_fused_rms_add: bool = False

View File

@@ -1,9 +0,0 @@
"""
Gemma3 modeling
"""
from .modeling_gemma3 import patch_gemma3
__all__ = [
"patch_gemma3",
]

Some files were not shown because too many files have changed in this diff Show More