Compare commits


113 Commits

Author SHA1 Message Date
Dan Saunders
8564961423 fix compile 2025-09-19 13:59:57 -04:00
Dan Saunders
ce21da9177 fix compile 2025-09-19 13:55:54 -04:00
Dan Saunders
b5dc58373f fix compile 2025-09-19 13:52:42 -04:00
Dan Saunders
7327144344 compile 2025-09-19 13:41:12 -04:00
Dan Saunders
fb11f696e9 bench sweep 2025-09-19 13:24:40 -04:00
Dan Saunders
105c817b0b default fix 2025-09-19 16:59:20 +00:00
Dan Saunders
64345e7707 recurse fix 2025-09-19 12:58:58 -04:00
Dan Saunders
0f8b921399 contig 2025-09-19 12:47:53 -04:00
Dan Saunders
336616d659 defaults 2025-09-19 16:45:39 +00:00
Dan Saunders
d2f1e23bcd fix 2025-09-19 12:45:18 -04:00
Dan Saunders
42aadc5069 bench fix 2025-09-19 12:34:08 -04:00
Dan Saunders
1e7302d30a bench fix 2025-09-19 12:20:35 -04:00
Dan Saunders
63544ce709 fix 2025-09-19 11:34:27 -04:00
Dan Saunders
3bfed0aac8 shared expert detection 2025-09-19 11:24:26 -04:00
Dan Saunders
bfc848f81d bits and pieces 2025-09-19 02:12:57 +00:00
Dan Saunders
abe1cad6bc another bench 2025-09-18 13:45:19 -04:00
Dan Saunders
354389caef torchtitan bench 2025-09-18 13:29:20 -04:00
Dan Saunders
efcd032fce yet another refactor 2025-09-18 13:03:28 -04:00
Dan Saunders
7500641601 yet another refactor 2025-09-18 12:47:15 -04:00
Dan Saunders
0295df5bca precompute fuse 2025-09-18 12:10:46 -04:00
Dan Saunders
b39ef54833 combine mult 2025-09-18 12:08:03 -04:00
Dan Saunders
ad4cd39bcd remove contig 2025-09-18 11:55:15 -04:00
Dan Saunders
5c197275ad inplace 2025-09-18 11:51:17 -04:00
Dan Saunders
19c91e3675 refactor 2025-09-18 11:44:21 -04:00
Dan Saunders
2a176e4923 fix 2025-09-18 11:29:33 -04:00
Dan Saunders
7d867de9b2 refactor 2025-09-18 11:23:15 -04:00
Dan Saunders
01b6792c2e refactor 2025-09-18 11:20:08 -04:00
Dan Saunders
bbf1f14ca4 dtype issues 2025-09-17 23:52:18 +00:00
Dan Saunders
c6878beb7d simplify 2025-09-17 19:15:34 -04:00
Dan Saunders
e62979d11d fix 2025-09-17 18:53:07 -04:00
Dan Saunders
d57b9c67c2 log 2025-09-17 18:52:27 -04:00
Dan Saunders
eaaf16aa00 cumulative offsets 2025-09-17 18:45:15 -04:00
Dan Saunders
f3b953e222 fix? 2025-09-17 18:42:10 -04:00
Dan Saunders
7935dc0911 dtype fix 2025-09-17 18:36:22 -04:00
Dan Saunders
d2b49b2670 error msg 2025-09-17 18:29:30 -04:00
Dan Saunders
b5cb345ca4 fix test 2025-09-17 18:24:00 -04:00
Dan Saunders
03d4c2683e fix perf degradation 2025-09-17 18:20:37 -04:00
Dan Saunders
fd87eed501 minify 2025-09-17 16:42:35 -04:00
Dan Saunders
129db67705 fix 2025-09-17 16:24:29 -04:00
Dan Saunders
38b890a36b fix 2025-09-17 16:16:41 -04:00
Dan Saunders
180920c7bf simplify 2025-09-17 19:49:18 +00:00
Dan Saunders
d024048d74 logs + fix 2025-09-17 14:50:49 -04:00
Dan Saunders
98dc945838 fix 2025-09-17 14:42:53 -04:00
Dan Saunders
108600cd69 update config 2025-09-17 14:36:24 -04:00
Dan Saunders
0e9387c395 fix 2025-09-17 14:35:36 -04:00
Dan Saunders
db61e0d4ff fix 2025-09-17 14:26:25 -04:00
Dan Saunders
51e565f60a logs 2025-09-17 14:15:51 -04:00
Dan Saunders
c774dd0409 refactor + fix 2025-09-17 14:01:39 -04:00
Dan Saunders
7289e0cb55 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
8d483c11f7 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
9c1829cf57 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
135b09d1de logs, qwen2 support 2025-09-17 13:44:26 -04:00
Dan Saunders
de4344a56e patch 2025-09-17 13:44:26 -04:00
Dan Saunders
7d572b58d1 just grouped_mm for now 2025-09-17 13:44:26 -04:00
Dan Saunders
773d7e4291 update 2025-09-17 13:44:26 -04:00
Dan Saunders
fef47a5b7c hardening 2025-09-17 13:44:26 -04:00
Dan Saunders
f6ed8ddc01 fix 2025-09-17 13:44:26 -04:00
Dan Saunders
556d6448fe fix 2025-09-17 13:44:26 -04:00
Dan Saunders
5c2229721d diag 2025-09-17 13:44:26 -04:00
Dan Saunders
d7de6b0e96 grouped_mm 2025-09-17 13:44:26 -04:00
Dan Saunders
3c6648678f numerics 2025-09-17 13:44:26 -04:00
Dan Saunders
5b19a1ea9c improve 2025-09-17 13:44:26 -04:00
Dan Saunders
cfefad1eea fix 2025-09-17 13:44:26 -04:00
Dan Saunders
125e7b5fe6 fast path 2025-09-17 13:44:26 -04:00
Dan Saunders
479b6144df tflops 2025-09-17 13:44:26 -04:00
Dan Saunders
68da65cba2 update 2025-09-17 13:44:26 -04:00
Dan Saunders
0d689bb421 cache, example 2025-09-17 13:44:26 -04:00
Dan Saunders
43ada1278a moe kernels init scaffold 2025-09-17 13:44:26 -04:00
Dan Saunders
4065bc14c6 Debug log, logging improvements (#3159)
* simplify logging

* remove comment

* progress on debug.log

* add debug-level logger for file log

* simplify

* case insensitivity; 3rd party logging improvements

* simplify

* fix

* tests

* lint

* nits

* nit

* Update tests/test_utils_tee.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* cleanup / comments

* fix

* oops

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-09-17 13:27:03 -04:00
salman
e5c427f6de qat doc updates (#3162) [skip-ci] 2025-09-17 10:38:15 +01:00
Wing Lian
86d6ee7c05 upgrade trl and accelerate (#3161)
* upgrade trl==0.23.0

* upgrade accelerate patch fix

* add hints when using gradient_checkpointing with DPO

* set gradient-checpointing properly
2025-09-16 14:53:01 -04:00
Wing Lian
d4cff1b7bb improve setting of NCCL_P2P_DISABLE on runpod (#3132) [skip ci]
* improve setting of NCCL_P2P_DISABLE on runpod

* use recs from review
2025-09-16 14:52:45 -04:00
Wing Lian
1ef6c196f7 setup env vars for ray train for FSDP (#3130) [skip ci] 2025-09-16 14:52:29 -04:00
salman
58d67bf98d Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107) 2025-09-12 10:55:50 +01:00
salman
0401a15888 SEO go brrr (#3153) [skip-ci] 2025-09-12 10:55:11 +01:00
NanoCode012
fcfc13d710 feat(doc): update thinking and chat_template notes (#3114) [skip ci]
* feat: update thinking and chat_template notes

* fix: grammar
2025-09-12 14:45:18 +07:00
salman
9406c0c488 log before eval step (#3148) [skip-ci] 2025-09-11 11:19:30 +01:00
Dan Saunders
1b53c49e1a text diffusion training plugin (#3067)
* diffusion training plugin

* cleanup

* nits

* fixes + improvements

* add back in reinit_weights (clobbered?); masking / pretrain fixes

* nits

* cleanup; tests draft

* sample generation, tests fixes

* fixes

* nits

* add inference support; add auto-mask token support

* nits

* nits

* progress

* simplify logging

* lint

* prefix args with diffusion_

* coderabbito

* tests fix

* nit

* nits

* cleanup + nits

* nits

* fix SFT sample gen

* fixes

* fix

* comments

* comments

* lint

* reward model lora fix

* cleanup; fix pretraining_dataset case

* gradio inference

* update cfgs

* update cfgs

* train, generation parity, cleanup

* fix

* simplify

* test

* test fix
2025-09-10 20:27:00 -04:00
NanoCode012
b71482cec5 Feat: add hunyuan v1 (#3016)
* feat: add hunyuan cce support

* feat: update cce docs

* feat: add multipack support for granite and hunyuan

* feat: add hunyuan docs and example config

* feat: update readme instructions to include CCE installation

* fix: chat template log appearing despite tokenizer already having template

* feat: add vram usage

* fix: remove duplicate cce install

* fix: use latest commit of PR in case rebased/pushed

* Revert "fix: use latest commit of PR in case rebased/pushed"

This reverts commit 8b60aa00de.

* feat: update doc as upstream merged
2025-09-10 09:03:30 +07:00
NanoCode012
79103b01ca Feat: add seedoss (#3104) [skip ci]
* feat: add seedoss cce

* feat: add seedoss config and docs

* fix: shouldn't have target modules with target linear

* feat: add vram numbers

* fix: hf link

* fix: name

* fix: support multipack seedoss

* fix: merge error

* feat: update seedoss instructions for transformers release
2025-09-10 09:01:02 +07:00
salman
9640338d37 Default include_tkps to true (#3134)
* default true

* force e2e

* causal trainer only

* fix eval loggin [skip-ci]

* revert setup.py

* force tests

* guarding

* guarding

* fix test case

* use evaluate [skip-e2e]

* use evaluate [skip-e2e]

* kick off ci

* fixing

* reverting
2025-09-09 10:50:21 -04:00
Wing Lian
b5d4c7ff54 allow 1% deviation for codecov (#3138) [skip ci] 2025-09-07 11:01:03 -04:00
Seungduk Kim
8fd9221f13 Add ipo as an rl type that shares DPODataset config (#3128)
* Add `ipo` as an `rl` type that shares DPODataset config

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-09-07 10:49:10 -04:00
github-actions[bot]
bf00f29f3a chore: update pre-commit hooks (#3137) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-09-07 10:33:20 -04:00
NanoCode012
1d32278755 feat: upgrade transformers to v4.56.1 (#3127)
* feat: upgrade transformers to v4.56

* fix handling of CP/SP now that position_ids are default even for unpacked sequences

* feat: monkeypatch list_repo_templates

* fix: apply patch for tests only

* see if updated main works at least

* fix: update to patch release and remove monkeypatch

* remove fsdp2 eval patch

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-09-05 11:00:54 -04:00
NanoCode012
c6ae5c43cb fix: chat template jinja file not being loaded during inference (#3112)
* fix: chat template jinja file not being loaded during inference

* fix: bot comment
2025-09-03 16:25:09 -04:00
yardenhoch
efa1da52d5 Center rewards coefficient (#3124)
* feat: add center_rewards_coefficient for reward modeling

- Add center_rewards_coefficient parameter to Pydantic schema with paper reference
- Pass parameter through base builder and causal builder to training args
- Add documentation section with usage examples and theoretical background
- Enable parameter in reward modeling example configs with recommended value
- Enables reward centering for improved training stability in RLHF workflows

Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244)
to incentivize mean-zero reward outputs without post-training normalization.

* Update description

* test: add unit tests for center_rewards_coefficient integration

* Update src/axolotl/core/builders/base.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* reference to TRL documentation.

* add new reward model configuration for qwen3 with comprehensive parameters

* Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments.

* Refactor reward modeling documentation to consolidate information on center_rewards_coefficient

* Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup.

* linting

* nit

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-09-03 16:22:37 -04:00
mhenrichsen
48db520d92 Create 270m-qlora.yml (#3075) [skip ci]
Adds 270m gemma3 qlora
2025-09-03 16:20:32 -04:00
NanoCode012
53a0c1f39c feat: add peft_trainable_token_indices (#3062)
* feat: add peft_trainable_token_indices

* feat: add warning compat with fix_untrained_tokens
2025-09-03 01:48:01 -04:00
github-actions[bot]
4cc6038d52 chore: update pre-commit hooks (#3122) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-09-03 01:41:34 -04:00
NanoCode012
e48aa8a5b1 feat(doc): improve visibility for colab notebooks (#3110) [skip ci]
* feat: improve visibility for colab notebooks

* fix: link to GH colab

* feat: change to badge and move higher
2025-09-03 01:40:53 -04:00
xuyifann
24aba5caca Clamping the len of dataloader to minimum of 1 (#3100) [skip ci]
* Clamping the len of dataloader to minimum of 1

* linter reformat
2025-09-03 01:40:27 -04:00
Wing Lian
06bebcb65f run cu128-2.8.0 e2e tests on B200 (#3126)
* run cu128-2.8.0 e2e tests on B200

* not an int 🤦

* fix yaml
2025-09-02 13:13:23 -04:00
Dan Saunders
231a67e70b Streaming SFT support (#3101)
* working

* fixes

* deprecate --iterable; cleanup

* pretrain_multipack_buffer_size -> streaming_multipack_buffer_size

* improvements

* tests

* remove unused

* docs, examples

* nit

* nit

* add val_set_size validation

* val

* nit

* min

* coderabbito

* cleanup

* nit

* add depr warning, cleanup

* nit

* fix test, fix quarto

* fix

* review comments

* review comments

* fix
2025-09-02 12:08:44 -04:00
Wing Lian
0094a2d744 support for tiledmlp for GPT-OSS (#3116)
* fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS

* add logging back

* update deps
2025-08-29 13:52:49 -04:00
Wing Lian
7ed40f1d70 automatically set env vars for single gpu deepspeed zero3 (#3118) [skip ci]
* automatically set env vars for single gpu deepspeed zero3

* use setdefault
2025-08-29 13:36:47 -04:00
VED
5b6ec2820f patch for ds_grads_remaining in deepspeed (#3102) [skip ci]
* patch deepspeed

* deepspeed patch for ds_grads_remaining

* patch in Patchmanager

* chore: lint

* deepseed utils

* chore2

* patch ds_grads_remaining chore

* chore lint

* chore lint

* remove torch.nn patch

* lint

* Update src/axolotl/monkeypatch/utils.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patched with checkpointwarapper

* lint

* only apply deepspeed patch when using activation offloading

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-29 12:12:09 -04:00
Wing Lian
6afba3871d Add support for PyTorch 2.8.0 (#3106)
* Add support for PyTorch 2.8.0

* loosen triton requirements

* handle torch 2.8.0 in setup.py

* fix versions

* no vllm for torch 2.8.0

* remove comment

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-28 09:10:40 -04:00
Dan Saunders
dc338c3b0e Update .coderabbit.yaml (#3109) [skip ci]
Oops, should be false.
2025-08-27 09:50:52 -04:00
salman
d0d2fc5606 Tokens per second logging [skip-e2e] (#3072) 2025-08-27 09:10:14 +01:00
Wing Lian
e1131e9619 make always skip_move_to_device default as true (#3084) 2025-08-26 09:30:22 -04:00
Wing Lian
c4c4b90638 add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
2025-08-26 09:30:04 -04:00
Wing Lian
0e9945e3b9 deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci]
* deploy training jobs to baseten w truss in axolotl cli

* cleanup
2025-08-26 09:29:50 -04:00
NanoCode012
0de254a0d0 feat: add gemma3_text attention handling for lora kernels (#3103) 2025-08-26 16:47:26 +07:00
Dan Saunders
79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00
Dan Saunders
eea7a006e1 make multipack sampler patch explicit (#3096)
* make multipack sampler patch explicit

* combining
2025-08-22 14:29:10 -04:00
Wing Lian
ab4d604a8f upgrade peft for 0.17.1 (#3094)
* upgrade peft to 0.17.1

* upgrade for transformers too
2025-08-22 07:26:30 -04:00
Wing Lian
0fa752e58b upgrade flash-attn to 2.8.3 for gpt-oss attn sink support (#3082) 2025-08-21 15:04:10 -04:00
Dan Saunders
08e517ea48 Update .coderabbit.yaml (#3091) [skip ci] 2025-08-20 22:14:13 -04:00
Wing Lian
07fd22f39b better handling of lora w bias with fsdp2 and handling of files when saving model checkpoint (#3090) 2025-08-20 15:17:48 -04:00
Wing Lian
06eaf6c448 misc fixes (#3085) 2025-08-20 08:52:26 -04:00
goggle
050210e637 fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080)
* refactor: improve output_dir handling in generate_config_files

* fix typo

* cli: harden sweep output_dir handling with base fallback

- Ensure sweep permutations always resolve a valid output_dir
- Default to ./model-out if neither permutation nor base config sets output_dir
- Append sweepXXXX suffix consistently for each permutation
- Prevent Path(None) TypeError and improve robustness of sweep config generation

* fix typo

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-19 20:25:20 -04:00
Wing Lian
05cedbfb1e add baseten info for gpt-oss recipe (#3078)
* add bsaeten info for gpt-oss recipe

* incorporate PR review
2025-08-19 13:30:37 -04:00
384 changed files with 24084 additions and 12819 deletions

View File

@@ -1,3 +1,3 @@
 [bandit]
 exclude = tests
-skips = B101,B615
+skips = B101,B615,B102,B110

View File

@@ -12,5 +12,6 @@ reviews:
   auto_review:
     enabled: true
     drafts: false
+    auto_incremental_review: false
 chat:
   auto_reply: true

View File

@@ -1,5 +0,0 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

View File

@@ -36,6 +36,11 @@ jobs:
   python_version: "3.11"
   pytorch: 2.7.1
   axolotl_extras:
+ - cuda: 128
+   cuda_version: 12.8.1
+   python_version: "3.11"
+   pytorch: 2.8.0
+   axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout
@@ -110,6 +115,11 @@ jobs:
   python_version: "3.11"
   pytorch: 2.7.1
   axolotl_extras:
+ - cuda: 128
+   cuda_version: 12.8.1
+   python_version: "3.11"
+   pytorch: 2.8.0
+   axolotl_extras:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout
@@ -169,6 +179,12 @@ jobs:
   pytorch: 2.7.1
   axolotl_extras: vllm
   is_latest: true
+ - cuda: 128
+   cuda_version: 12.8.1
+   python_version: "3.11"
+   pytorch: 2.8.0
+   axolotl_extras:
+   is_latest:
 runs-on: axolotl-gpu-runner
 steps:
 - name: Checkout

View File

@@ -33,13 +33,6 @@ jobs:
   axolotl_extras:
   num_gpus: 2
   nightly_build: "true"
- - cuda: 126
-   cuda_version: 12.6.3
-   python_version: "3.11"
-   pytorch: 2.7.0
-   axolotl_extras:
-   num_gpus: 2
-   nightly_build: "true"
 - cuda: 126
   cuda_version: 12.6.3
   python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
   axolotl_extras: vllm
   num_gpus: 2
   nightly_build: "true"
+ - cuda: 128
+   cuda_version: 12.8.1
+   python_version: "3.11"
+   pytorch: 2.8.0
+   axolotl_extras: fbgemm-gpu
+   num_gpus: 2
+   nightly_build: "true"
 runs-on: [self-hosted, modal]
 timeout-minutes: 120
 steps:

View File

@@ -55,7 +55,7 @@ jobs:
 fail-fast: false
 matrix:
   python_version: ["3.11"]
-  pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
+  pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
 timeout-minutes: 20
 steps:
@@ -130,7 +130,7 @@
 fail-fast: false
 matrix:
   python_version: ["3.11"]
-  pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
+  pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
 timeout-minutes: 20
 steps:
@@ -240,7 +240,7 @@
 - cuda: 126
   cuda_version: 12.6.3
   python_version: "3.11"
-  pytorch: 2.6.0
+  pytorch: 2.7.1
   num_gpus: 1
   axolotl_extras:
   dockerfile: "Dockerfile-uv.jinja"
@@ -298,6 +298,13 @@
   pytorch: 2.7.1
   num_gpus: 1
   axolotl_extras:
+ - cuda: 128
+   cuda_version: 12.8.1
+   python_version: "3.11"
+   pytorch: 2.8.0
+   num_gpus: 1
+   gpu_type: "B200"
+   axolotl_extras: fbgemm-gpu
 steps:
 - name: Checkout
   uses: actions/checkout@v4
@@ -318,6 +325,7 @@
 echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
 echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
 echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
 echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
 echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
 - name: Run tests job on Modal
@@ -334,10 +342,10 @@
 fail-fast: false
 matrix:
   include:
-  - cuda: 124
-    cuda_version: 12.4.1
+  - cuda: 126
+    cuda_version: 12.6.3
     python_version: "3.11"
-    pytorch: 2.6.0
+    pytorch: 2.7.1
     num_gpus: 1
     axolotl_extras:
 steps:

.gitignore
View File

@@ -190,3 +190,6 @@ out/
 # vim
 *.swp
+
+# scm auto-versioning
+src/axolotl/_version.py

View File

@@ -1,4 +0,0 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,22 +10,12 @@ repos:
   - id: trailing-whitespace
   - id: no-commit-to-branch
     args: ['--branch', 'main']
- - repo: https://github.com/psf/black
-   rev: 25.1.0
-   hooks:
-   - id: black
- - repo: https://github.com/pycqa/isort
-   rev: 6.0.1
-   hooks:
-   - id: isort
- - repo: https://github.com/PyCQA/flake8
-   rev: 7.3.0
-   hooks:
-   - id: flake8
- - repo: https://github.com/pylint-dev/pylint
-   rev: v3.3.8
-   hooks:
-   - id: pylint
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+   rev: v0.12.12
+   hooks:
+   - id: ruff
+     args: [--fix]
+   - id: ruff-format
 - repo: https://github.com/pre-commit/mirrors-mypy
   rev: v1.17.1
   hooks:

View File

@@ -1,15 +0,0 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -1,6 +1,6 @@
 cff-version: 1.2.0
 type: software
-title: "Axolotl: Post-Training for AI Models"
+title: "Axolotl: Open Source LLM Post-Training"
 message: "If you use this software, please cite it as below."
 authors:
   - name: "Axolotl maintainers and contributors"

View File

@@ -5,6 +5,9 @@
     <img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
   </picture>
 </p>
+<p align="center">
+    <strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>
+</p>
 <p align="center">
     <img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
@@ -17,6 +20,7 @@
 <br/>
 <a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
 <a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
+<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
 <br/>
 <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
 <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
@@ -49,20 +53,21 @@
 ## ✨ Overview

-Axolotl is a tool designed to streamline post-training for various AI models.
+Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).

 Features:

-- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
-- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
-- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
+- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
+- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
+- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
+- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
 - **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
 - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
 - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

-## 🚀 Quick Start
+## 🚀 Quick Start - LLM Fine-tuning in Minutes

 **Requirements**:
@@ -70,6 +75,10 @@ Features:
 - Python 3.11
 - PyTorch ≥2.6.0
+
+### Google Colab
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)

 ### Installation

 #### Using pip
@@ -155,7 +164,7 @@ If you use Axolotl in your research or projects, please cite it as follows:
 ```bibtex
 @software{axolotl,
-  title = {Axolotl: Post-Training for AI Models},
+  title = {Axolotl: Open Source LLM Post-Training},
   author = {{Axolotl maintainers and contributors}},
   url = {https://github.com/axolotl-ai-cloud/axolotl},
   license = {Apache-2.0},

View File

@@ -153,7 +153,7 @@ quartodoc:
   - utils.distributed
   - utils.dict
   - utils.optimizers.adopt
-  - utils.data.pretraining
+  - utils.data.streaming
   - utils.data.sft
   - utils.quantization
 - title: Schemas
@@ -272,6 +272,7 @@ website:
   contents:
   - docs/batch_vs_grad.qmd
   - docs/dataset_preprocessing.qmd
+  - docs/streaming.qmd
   - docs/multipack.qmd
   - docs/mixed_precision.qmd
   - docs/optimizers.qmd
@@ -284,6 +285,7 @@ website:
   - docs/custom_integrations.qmd
   - docs/sequence_parallelism.qmd
   - docs/gradient_checkpointing.qmd
+  - docs/moe_backends.md
   - docs/nd_parallelism.qmd
 - section: "Troubleshooting"

View File

@@ -2,8 +2,6 @@
 modal application to run axolotl gpu tests in Modal
 """

-# pylint: disable=duplicate-code
-
 import os
 import pathlib
 import tempfile
@@ -63,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
     # Propagate errors from subprocess.
     if exit_code := subprocess.call(cmd.split(), cwd=run_folder):  # nosec
-        exit(exit_code)  # pylint: disable=consider-using-sys-exit
+        exit(exit_code)

 @app.function(

View File

@@ -1,7 +1,5 @@
 """Modal app to run axolotl GPU tests"""

-# pylint: disable=duplicate-code
-
 import os
 import pathlib
 import tempfile
@@ -59,7 +57,8 @@ VOLUME_CONFIG = {
 }

 N_GPUS = int(os.environ.get("N_GPUS", 1))
-GPU_CONFIG = f"L40S:{N_GPUS}"
+GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
+GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"

 def run_cmd(cmd: str, run_folder: str):
@@ -70,4 +69,4 @@ def run_cmd(cmd: str, run_folder: str):
     # Propagate errors from subprocess.
     if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env):  # nosec
-        exit(exit_code)  # pylint: disable=consider-using-sys-exit
+        exit(exit_code)

View File

@@ -12,7 +12,7 @@ coverage:
   default:
     # basic
     target: auto
-    threshold: 0%
+    threshold: 1%
     base: auto
     # advanced
     branches: null
@@ -27,7 +27,7 @@ coverage:
   default:
     # basic
     target: auto
-    threshold: 0%
+    threshold: 1%
     base: auto
     # advanced
     branches: null

View File

@@ -134,7 +134,7 @@ For providers supporting Docker:

 ### Google Colab {#sec-colab}

-Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
+[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)

 ## Platform-Specific Instructions {#sec-platform-specific}

docs/moe_backends.md (new file)
View File

@@ -0,0 +1,18 @@
# MoE Backends in Axolotl

Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):

- Set `moe_backend: auto|torch_grouped|naive`

Behavior

- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
- naive: keeps the reference per-expert loop.

Notes

- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
- No changes to training scripts are required; selection happens inside the model forward.

Example

    moe_backend: torch_grouped

    accelerate launch -m axolotl.cli.train path/to/config.yaml
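To show how the selector sits in a full training run, here is a minimal sketch of a config, assuming a Mixtral-style base model; apart from `moe_backend`, every field below is an illustrative placeholder rather than a recommendation from this doc:

```yaml
# Illustrative only: moe_backend is the documented key; the other values are assumed.
base_model: mistralai/Mixtral-8x7B-v0.1   # assumed MoE base model
moe_backend: torch_grouped                # auto | torch_grouped | naive

datasets:
  - path: tatsu-lab/alpaca                # assumed example dataset
    type: alpaca
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
learning_rate: 0.0002
output_dir: ./outputs/moe-out
```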

View File

@@ -63,15 +63,6 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
 :::

-::: {.callout-tip}
-
-Using ZeRO Stage 3 with Single-GPU training
-
-ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
-
-`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
-
-:::

 ## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}

 ::: {.callout-note}

View File

@@ -23,10 +23,17 @@ To enable QAT in axolotl, add the following to your configuration file:
 ```yaml
 qat:
-  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
-  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
+  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
+  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
   group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
   fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
 ```
+
+We support the following quantization schemas:
+
+- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
+- `Int8DynamicActivationInt4Weight`
+- `Float8DynamicActivationFloat8Weight`
+- `Float8DynamicActivationInt4Weight`
+- `NVFP4`
+
 Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.
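For instance, an NVFP4 recipe could look like the sketch below; the `group_size: 16` constraint is taken from the nvfp4 example config later in this changeset, and the rest should be treated as an illustration rather than a verified setting:

```yaml
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16  # per the example config, only group_size 16 is supported with nvfp4
```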

View File

@@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi
 ```yaml
 base_model: # The path to the model to quantize.
 quantization:
-  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
-  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
+  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
+  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
   group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
   quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
@@ -39,9 +39,8 @@ you used to train the model:
 # qat.yml
 qat:
   activation_dtype: int8
-  weight_dtype: int8
+  weight_dtype: int4
   group_size: 256
-  quantize_embedding: true
 output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
 ```
@@ -51,3 +50,11 @@ axolotl quantize qat.yml
 ```

 This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
+
+::: {.callout-note}
+If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
+e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`.
+:::
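As a hedged sketch of how that renaming plays out, assume a config along these lines (the model and org names are placeholders, not real repositories):

```yaml
# qat.yml (illustrative values only)
hub_model_id: your-org/qat-nvfp4-llama3B   # final upload becomes your-org/qat-nvfp4-llama3B-nvfp4w
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16
output_dir: ./outputs/qat_out/
```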

View File

@@ -11,6 +11,7 @@ We support the reward modelling techniques supported by `trl`.

 ### (Outcome) Reward Models

 Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
+For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).

 ```yaml
 base_model: google/gemma-2-2b
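A minimal sketch of where this option might sit in a reward-model config (only `center_rewards_coefficient` comes from the diff above; the 0.01 value, dataset, and remaining fields are assumptions for illustration):

```yaml
base_model: google/gemma-2-2b
reward_model: true                                    # assumed reward-modelling flag
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs     # assumed preference dataset
    type: bradley_terry.chat_template                 # assumed RM dataset type
val_set_size: 0.1
center_rewards_coefficient: 0.01                      # assumed value; encourages mean-zero rewards
```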

View File

@@ -47,7 +47,6 @@ class QuartoGenerator:
         """Check if a type is a Pydantic BaseModel."""
         return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)

-    # pylint: disable=too-many-return-statements
     def _extract_nested_type(self, field_type) -> Any:
         """Extract the actual type from complex type annotations."""
         # Handle Annotated types (Python 3.9+)
@@ -124,7 +123,6 @@ class QuartoGenerator:
         return field_type

-    # pylint: disable=too-many-return-statements
     def _extract_all_pydantic_models_from_type(
         self, field_type
     ) -> list[type[BaseModel]]:
@@ -318,7 +316,6 @@ class QuartoGenerator:
         return all_groups

-    # pylint: disable=too-many-return-statements
     def _extract_field_groups_from_source(
         self, model_class: type[BaseModel]
     ) -> list[dict]:
@@ -503,7 +500,7 @@ class QuartoGenerator:
             nested_schema = nested_model.model_json_schema()
             nested_properties = nested_schema.get("properties", {})
             nested_required = nested_schema.get("required", [])
-        except Exception:  # pylint: disable=broad-exception-caught
+        except Exception:
             # Fallback: use model fields directly
             nested_properties = {}
             nested_required = []
@@ -607,7 +604,7 @@ class QuartoGenerator:
             schema = model_class.model_json_schema()
             properties = schema.get("properties", {})
             required = schema.get("required", [])
-        except Exception as e:  # pylint: disable=broad-exception-caught
+        except Exception as e:
             print(
                 f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
             )

docs/streaming.qmd (new file)
View File

@@ -0,0 +1,120 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using the buffer. This requires additional
memory for the shuffle buffer.
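For example, a streaming SFT config that opts into shuffling might look like the sketch below (it mirrors the SFT example above; treat it as an illustration rather than a tuned recipe):

```yaml
streaming: true
shuffle_merged_datasets: true   # shuffle via the streaming buffer (uses extra memory)
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train
```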
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- You can control the memory usage by adjusting `streaming_multipack_buffer_size`
- Sample packing requires buffering multiple samples
- Shuffling requires additional memory for the shuffle buffer
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming

View File

@@ -0,0 +1,10 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

File diff suppressed because it is too large.

View File

@@ -20,7 +20,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```

-2. Run the finetuning example:
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
+
+```bash
+python scripts/cutcrossentropy_install.py | sh
+```
+
+3. Run the finetuning example:

 ```bash
 axolotl train examples/devstral/devstral-small-qlora.yml

View File

@@ -0,0 +1,68 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
# huggingface repo
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -41,6 +41,12 @@ model, and final model output, you may need at least 3TB of free disk space to k
 axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
 ```

+To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
+training of the 120B model using Baseten Truss. You can read more about this recipe on
+[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
+be found on their
+[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
+
 ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
 See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
@@ -61,9 +67,23 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/

 ### Inferencing your fine-tuned model

+#### vLLM
+
 GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
 for more information about using a special vllm-openai docker image for inferencing with vLLM.
+
+Optionally, vLLM can be installed from nightly:
+
+```bash
+pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
+```
+
+and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
+
+```bash
+vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
+```
+
+#### SGLang

 SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for information on installing
 SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
@@ -86,6 +106,16 @@ See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-to

 Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.

+### Thinking and chat_template masking conflict
+
+OpenAI's Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotl's `chat_template` masking.
+If your dataset has `thinking` content mid-turn, there are two paths we recommend:
+
+- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message).
+- Adjust your dataset to only have `thinking` content in the last turn.
+
 ### TIPS

 - Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -44,7 +44,7 @@ bf16: true
 tf32: true

 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3

 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
 tf32: true

 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3

 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
     field_thinking: thinking
     template_thinking_key: thinking

-dataset_prepared_path: last_run_prepared
+dataset_prepared_path: ./outputs/last_run_prepared
 val_set_size: 0
 output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
 tf32: true

 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3

 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
     field_thinking: thinking
     template_thinking_key: thinking

-dataset_prepared_path: last_run_prepared
+dataset_prepared_path: ./outputs/last_run_prepared
 val_set_size: 0
 output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
 tf32: true

 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3

 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
 tf32: true

 flash_attention: true
-attn_implementation: kernels-community/vllm-flash-attn3
+attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3

 gradient_checkpointing: true
 activation_offloading: true

View File

@@ -0,0 +1,85 @@
# Finetune HunYuan with Axolotl
Tencent released a family of open-source models called HunYuan, at 0.5B, 1.8B, 4B, and 7B parameter scales, in both Pre-trained and Instruct variants. The models can be found on [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune them with Axolotl using multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, as HunYuan support is currently only available there, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml
```
This config uses about 4.7 GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Dataset
HunYuan Instruct models can operate in either a slow-think or fast-think pattern. For best performance when fine-tuning the Instruct models, your dataset should be adjusted to match the pattern you want.
```python
# fast think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/no_think What color is the sun?" },
{"role": "assistant", "content": "<think>\n\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
# slow think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/no_think What color is the sun?" },
{"role": "assistant", "content": "<think>\nThe user is asking about the color of the sun. I need to ...\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
```
### TIPS
- For inference, the official Tencent team recommends
```json
{
"do_sample": true,
"top_k": 20,
"top_p": 0.8,
"repetition_penalty": 1.05,
"temperature": 0.7
}
```
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Tencent HunYuan Blog](https://hunyuan.tencent.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,64 @@
base_model: tencent/Hunyuan-0.5B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,20 +15,18 @@ liger_glu_activation: true
liger_layer_norm: true liger_layer_norm: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
datasets: datasets:
- path: yahma/alpaca-cleaned - path: yahma/alpaca-cleaned
type: alpaca type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/ output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
sample_packing: true sample_packing: false
sequence_len: 8192
sequence_len: 512 flash_attention: true
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat: qat:
activation_dtype: int8 activation_dtype: int8
@@ -67,7 +65,7 @@ fsdp:
fsdp_config: fsdp_config:
fsdp_version: 2 fsdp_version: 2
fsdp_offload_params: false fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true fsdp_cpu_ram_efficient_loading: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
@@ -76,6 +74,6 @@ fsdp_config:
fsdp_activation_checkpointing: true fsdp_activation_checkpointing: true
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config # save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,64 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
gradient_accumulation_steps: 1
micro_batch_size: 64
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -9,15 +9,18 @@ pretraining_dataset:
field: text field: text
plugins: plugins:
- diffusion.DiffusionPlugin - axolotl.integrations.diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.15 diffusion:
max_mask_ratio: 0.85 noise_schedule: cosine
eps: 5e-4 min_mask_ratio: 0.15
importance_weighting: true max_mask_ratio: 0.85
mask_token_id: 128002 num_diffusion_steps: 128
generate_samples: true eps: 5e-4
generation_interval: 10 importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out output_dir: ./outputs/model-out
@@ -27,21 +30,17 @@ sample_packing: true
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
micro_batch_size: 4 micro_batch_size: 4
max_steps: 10000 max_steps: 10000
warmup_ratio: 0.1
optimizer: adamw_8bit optimizer: adamw_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 3e-4 learning_rate: 3e-4
sdp_attention: true
bf16: auto bf16: auto
tf32: true tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
sdp_attention: true
warmup_steps: 1000
save_strategy: steps save_strategy: steps
save_steps: 1000 save_steps: 1000

View File

@@ -8,14 +8,18 @@ datasets:
val_set_size: 0.05 val_set_size: 0.05
plugins: plugins:
- diffusion.DiffusionPlugin - axolotl.integrations.diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.1 diffusion:
max_mask_ratio: 0.9 noise_schedule: cosine
num_diffusion_steps: 128 min_mask_ratio: 0.1
eps: 1e-3 max_mask_ratio: 0.9
importance_weighting: true num_diffusion_steps: 128
mask_token_id: 128002 eps: 1e-3
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out output_dir: ./outputs/model-out
@@ -26,6 +30,7 @@ eval_sample_packing: true
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 1 num_epochs: 1
warmup_steps: 0.1
optimizer: adamw_8bit optimizer: adamw_8bit
lr_scheduler: cosine lr_scheduler: cosine
@@ -36,15 +41,11 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1
sdp_attention: true sdp_attention: true
warmup_steps: 1000 logging_steps: 1
save_strategy: best
save_strategy: steps eval_strategy: epoch
eval_strategy: steps
save_steps: 500
eval_steps: 500
special_tokens: special_tokens:
pad_token: "<|end_of_text|>" pad_token: "<|end_of_text|>"

View File

@@ -18,7 +18,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
``` ```
2. Run the finetuning example: 2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
```bash ```bash
axolotl train examples/magistral/magistral-small-qlora.yaml axolotl train examples/magistral/magistral-small-qlora.yaml

View File

@@ -0,0 +1,53 @@
base_model: Qwen/Qwen1.5-MoE-A2.7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
# Keep VRAM low
load_in_8bit: false
load_in_4bit: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/qwen2-moe-qlora-10gb
# Train small to fit 10GB
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 5
flash_attention: true
warmup_ratio: 0.03
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
model_config:
output_router_logits: true
special_tokens:

View File

@@ -0,0 +1,44 @@
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
chat_template: qwen3
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
deepspeed: deepspeed_configs/zero1.json
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: linear
learning_rate: 0.00002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
warmup_ratio: 0.1
logging_steps: 1
weight_decay: 0.01

View File

@@ -0,0 +1,54 @@
# Finetune ByteDance's Seed-OSS with Axolotl
[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) is a series of 36B-parameter open-source models trained by ByteDance's Seed Team.
This guide shows how to fine-tune them with Axolotl using multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, as Seed-OSS support is only available on nightly, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml
```
This config uses about 27.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1` (see the sketch after this list).
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
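As a rough sketch, those recommendations map to a generation config along the lines below (only `top_p` and `temperature` come from the Seed Team; `do_sample` is assumed because sampling parameters are given):
```json
{
  "do_sample": true,
  "top_p": 0.95,
  "temperature": 1.1
}
```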
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [ByteDance Seed Website](https://seed.bytedance.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,56 @@
base_model: ByteDance-Seed/Seed-OSS-36B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,50 @@
# Streaming Dataset Examples
This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.
## Examples
Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` required!
### Pretraining (`pretrain.yaml`)
Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.
- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management
### SFT (`sft.yaml`)
Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.
- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed
## Key Configuration Options
### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`
### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory
### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer
### `sample_packing`
- Packs multiple samples into single sequences
- Minimizes per-step padding tokens
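Putting the options above together, a minimal streaming SFT snippet might look like the sketch below (dataset and values are illustrative; the full `sft.yaml` later on this page is the complete config):
```yaml
datasets:
  - path: tatsu-lab/alpaca  # illustrative dataset
    type: alpaca
streaming: true  # stream the SFT dataset instead of preprocessing it up front
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
sample_packing: true
```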
## Performance Tips
- Download small / frequently-used datasets locally for better performance
- Larger buffer sizes improve packing efficiency

View File

@@ -0,0 +1,57 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Streaming pretraining configuration
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
name: sample-10BT
type: pretrain
text_column: text
split: train
# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,55 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Dataset configuration
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -22,6 +22,9 @@ pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# audio # audio
pip3 install librosa==0.11.0 pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3' pip3 install 'mistral_common[audio]==1.8.3'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
``` ```
3. Run the finetuning example: 3. Run the finetuning example:

View File

@@ -26,3 +26,34 @@ include-package-data = true
[tool.setuptools.cmdclass] [tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand" build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B", "I"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long
"C901", # Too complex
"B019", # Use of functools.cache on methods
"E722", # Bare except
"F821", # Undefined name (for dynamic exec)
]
[tool.ruff.lint.isort]
known-third-party = ["wandb", "comet_ml"]
known-local-folder = ["src", "tests"]
# Black-compatible isort settings
force-single-line = false
combine-as-imports = true
split-on-trailing-comma = true
[tool.ruff.format]
# Use black's formatting style exactly
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false

View File

@@ -2,8 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0 bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE triton>=3.0.0
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
@@ -13,13 +12,13 @@ liger-kernel==0.6.1
packaging==23.2 packaging==23.2
huggingface_hub>=0.33.0 huggingface_hub>=0.33.0
peft==0.17.0 peft>=0.17.0
transformers==4.55.2 transformers==4.56.1
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.10.0 accelerate==1.10.1
datasets==4.0.0 datasets==4.0.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.21.0 trl==0.23.0
hf_xet==1.1.5 hf_xet==1.1.5
kernels==0.9.0 kernels==0.9.0
trackio trackio
@@ -65,7 +64,7 @@ langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.12.0 torchao==0.13.0
schedulefree==1.4.1 schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.6

209
scripts/bench_moe.py Normal file
View File

@@ -0,0 +1,209 @@
#!/usr/bin/env python
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
from __future__ import annotations
import argparse
import sys
import time
import weakref
from pathlib import Path
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for _ in range(warmup):
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
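# gate, up, and down projections are 3 matmuls; 2 FLOPs per multiply-add gives the factor of 6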
return 6.0 * tokens * top_k * hidden * inter
def load_hf_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
def main() -> None:
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=32)
p.add_argument("--top_k", type=int, default=4)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--profile", action="store_true")
p.add_argument(
"--compile",
action="store_true",
help="Torch.compile both paths before benchmarking",
)
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = load_hf_block(
args.hidden,
args.inter,
args.experts,
args.top_k,
device=device,
dtype=dtype,
)
tokens = args.bsz * args.seq
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
# Optional torch.compile
run_grouped_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile naive failed ({exc}); using eager")
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp, block.gate, block.experts, block.top_k
)
return y
try:
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile grouped failed ({exc}); using eager")
run_grouped_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
if impl is not None:
return impl(data)
if tg is None or not tg.available():
return torch.empty(0)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
)
with torch.no_grad():
y_ref = run_naive()
if tg is None or not tg.available():
print("torch_grouped\tN/A (unavailable)")
return
y_grouped = run_grouped()
if y_grouped.numel() == 0:
print("torch_grouped\tN/A (op not callable)")
return
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped
print(
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
print(
"torch_grouped_check: "
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
)
if args.profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_naive()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_grouped()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
if __name__ == "__main__":
main()

311
scripts/bench_moe_sweep.py Normal file
View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
from __future__ import annotations
import argparse
import csv
import sys
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def _parse_list(arg: str) -> List[int]:
return [int(v) for v in arg.split(",") if v]
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
run()
if device.type == "cuda":
torch.cuda.synchronize()
times: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _load_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
@dataclass
class Result:
bsz: int
seq: int
hidden: int
inter: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def main() -> None:
p = argparse.ArgumentParser(description="Grouped MoE sweep")
p.add_argument("--batch-sizes", default="4,8,16")
p.add_argument("--seq-lens", default="512,1024,2048")
p.add_argument("--hidden", default="2048,4096")
p.add_argument("--inter", default="5632,8192,14336")
p.add_argument("--experts", default="8,16,32")
p.add_argument("--top-k", default="1,2,4")
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--csv", type=Path, default=None)
p.add_argument("--compile", action="store_true")
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
if tg is None or not tg.available():
print("torch_grouped unavailable; sweep aborted")
return
bs_list = _parse_list(args.batch_sizes)
seq_list = _parse_list(args.seq_lens)
hidden_list = _parse_list(args.hidden)
inter_list = _parse_list(args.inter)
expert_list = _parse_list(args.experts)
topk_list = _parse_list(args.top_k)
results: List[Result] = []
print(
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
for bsz in bs_list:
for seq in seq_list:
tokens = bsz * seq
for hidden in hidden_list:
for inter in inter_list:
for experts in expert_list:
for top_k in topk_list:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = _load_block(
hidden,
inter,
experts,
top_k,
device=device,
dtype=dtype,
)
x = torch.randn(
bsz, seq, hidden, device=device, dtype=dtype
)
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile naive failed ({exc}); using eager"
)
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = (
weakref.ref(block)
) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp,
block.gate,
block.experts,
block.top_k,
)
return y
try:
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile grouped failed ({exc}); using eager"
)
compiled_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(
block=block_grouped, data=x, impl=compiled_impl
):
if impl is not None:
return impl(data)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
data,
block.gate,
block.experts,
block.top_k,
)
return y
naive_ms = _bench(
run_naive,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_naive = run_naive()
grouped_ms = _bench(
run_grouped,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
res = Result(
bsz,
seq,
hidden,
inter,
experts,
top_k,
args.dtype,
naive_ms,
grouped_ms,
naive_ms / grouped_ms,
_estimate_flops(tokens, hidden, inter, top_k)
/ ((naive_ms / 1000.0) * 1e12),
_estimate_flops(tokens, hidden, inter, top_k)
/ ((grouped_ms / 1000.0) * 1e12),
diff.max().item(),
diff.mean().item(),
(
(
diff.pow(2).sum()
/ (y_naive.float().pow(2).sum() + 1e-12)
)
.sqrt()
.item()
),
)
results.append(res)
print(
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
if args.csv:
fieldnames = [
"bsz",
"seq",
"hidden",
"inter",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with args.csv.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"bsz": r.bsz,
"seq": r.seq,
"hidden": r.hidden,
"inter": r.inter,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python
"""Benchmark Torchtitan MoE grouped vs naive expert execution."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
import torch
# Ensure torchtitan is importable when running from the axolotl tree
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE microbenchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=8)
p.add_argument("--top_k", type=int, default=2)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument(
"--score-before",
action="store_true",
help="Apply routing scores before expert computation (default: after)",
)
p.add_argument(
"--score-func",
choices=["softmax", "sigmoid"],
default="softmax",
)
p.add_argument(
"--route-norm",
action="store_true",
help="Enable Torchtitan router normalization when using sigmoid scores.",
)
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
# Two up projections + one down projection per expert/token combination.
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(
moe: MoE,
*,
device: torch.device,
dtype: torch.dtype,
) -> MoE:
moe = moe.to(device=device)
for param in moe.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
buffers = dict(moe.named_buffers())
for name, buf in buffers.items():
if name == "tokens_per_expert":
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
moe._buffers[name] = buf.to(device=device, dtype=dtype)
moe.eval()
return moe
@torch.inference_mode()
def _forward_fn(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(fn, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for _ in range(warmup):
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def main() -> None:
args = _parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = _map_dtype(args.dtype)
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=True,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=args.hidden, hidden_dim=args.inter)
moe_grouped.init_weights(args.init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=False,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=args.hidden, hidden_dim=args.inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
tokens = args.bsz * args.seq
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} "
f"inter={args.inter} experts={args.experts} top_k={args.top_k}"
)
def run_naive():
return _forward_fn(moe_naive, x)
def run_grouped():
return _forward_fn(moe_grouped, x)
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_naive = _bench(run_naive, iters=args.iters, warmup=args.warmup)
flops = _estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
tflops_naive = flops / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
f"{tflops_naive:.2f} TFLOP/s"
)
y_naive = run_naive()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_grouped = _bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped if t_grouped > 0 else float("nan")
print(
f"grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
print(
f"grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,328 @@
#!/usr/bin/env python
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""
from __future__ import annotations
import argparse
import csv
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List
import torch
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_int_list(value: str) -> List[int]:
return [int(v) for v in value.split(",") if v]
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
p.add_argument(
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
)
p.add_argument(
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
)
p.add_argument(
"--experts", default="8,16,32,64", help="Comma separated expert counts"
)
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument("--score-before", action="store_true")
p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
p.add_argument("--route-norm", action="store_true")
p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
module = module.to(device=device)
for param in module.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
for name, buf in module.named_buffers():
if name == "tokens_per_expert":
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
module._buffers[name] = buf.to(device=device, dtype=dtype)
module.eval()
return module
@torch.inference_mode()
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings.append((time.perf_counter() - start) * 1000.0)
return sum(timings) / len(timings)
@dataclass
class SweepResult:
bsz: int
seq: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def _run_case(
*,
bsz: int,
seq: int,
experts: int,
top_k: int,
hidden: int,
inter: int,
dtype: torch.dtype,
device: torch.device,
iters: int,
warmup: int,
init_std: float,
score_before: bool,
score_func: str,
route_norm: bool,
) -> SweepResult:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=True,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
moe_grouped.init_weights(init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=False,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)
def run_naive():
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
return _forward(moe_naive, x)
def run_grouped():
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
return _forward(moe_grouped, x)
naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
y_naive = run_naive()
grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
tokens = bsz * seq
flops = _estimate_flops(tokens, hidden, inter, top_k)
naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")
return SweepResult(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
dtype=str(dtype),
naive_ms=naive_ms,
grouped_ms=grouped_ms,
speedup=speedup,
naive_tflops=naive_tflops,
grouped_tflops=grouped_tflops,
max_abs=max_abs,
mean_abs=mean_abs,
rel_l2=rel_l2,
)
def _print_header(
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
print(
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
def _print_result(res: SweepResult) -> None:
print(
f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
fieldnames = [
"batch_size",
"seq_len",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with path.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"batch_size": r.bsz,
"seq_len": r.seq,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
def main() -> None:
args = _parse_args()
dtype = _map_dtype(args.dtype)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_sizes = _parse_int_list(args.batch_sizes)
seq_lens = _parse_int_list(args.seq_lens)
experts_list = _parse_int_list(args.experts)
top_ks = _parse_int_list(args.top_ks)
results: List[SweepResult] = []
_print_header(args.hidden, args.inter, dtype, device)
for bsz in batch_sizes:
for seq in seq_lens:
for experts in experts_list:
for top_k in top_ks:
try:
res = _run_case(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
hidden=args.hidden,
inter=args.inter,
dtype=dtype,
device=device,
iters=args.iters,
warmup=args.warmup,
init_std=args.init_std,
score_before=args.score_before,
score_func=args.score_func,
route_norm=args.route_norm,
)
except RuntimeError as err:
print(
f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
file=sys.stderr,
)
continue
results.append(res)
_print_result(res)
if args.csv and results:
_write_csv(args.csv, results)
print(f"Wrote {len(results)} rows to {args.csv}")
if __name__ == "__main__":
main()

View File

@@ -27,7 +27,7 @@ def parse_dataset(dataset=None, split="train"):
break break
if not field_messages: if not field_messages:
raise ValueError( raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}' f"No conversation field found in dataset: {', '.join(feature_keys)}"
) )
ds_cfg["field_messages"] = field_messages ds_cfg["field_messages"] = field_messages
@@ -40,7 +40,7 @@ def parse_dataset(dataset=None, split="train"):
break break
if not message_property_mappings["role"]: if not message_property_mappings["role"]:
raise ValueError( raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}' f"No role field found in messages: {', '.join(message_fields)}"
) )
for key in ["content", "text", "value"]: for key in ["content", "text", "value"]:
@@ -49,7 +49,7 @@ def parse_dataset(dataset=None, split="train"):
break break
if not message_property_mappings["content"]: if not message_property_mappings["content"]:
raise ValueError( raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}' f"No content field found in messages: {', '.join(message_fields)}"
) )
ds_cfg["message_property_mappings"] = message_property_mappings ds_cfg["message_property_mappings"] = message_property_mappings

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
) )

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
from __future__ import annotations
import sys
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
sys.path.extend(
[
str(ROOT / "transformers" / "src"),
str(ROOT / "src"),
]
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
def main() -> None:
cfg = Qwen2MoeConfig(
hidden_size=4096,
moe_intermediate_size=14336,
shared_expert_intermediate_size=14336,
num_experts=32,
num_experts_per_tok=4,
)
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
experts = block.experts
experts._ax_parent_block = block
impls = _iter_expert_impls(experts)
print(f"impl count: {len(impls)}")
for idx, impl in enumerate(impls[:8]):
has_gate = hasattr(impl, "gate_proj")
has_up = hasattr(impl, "up_proj")
print(
f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
)
if has_gate:
print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
print(f" up shape {tuple(impl.up_proj.weight.shape)}")
print(f" down shape {tuple(impl.down_proj.weight.shape)}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""
Probe PyTorch for grouped GEMM operator names and namespaces.
Run: python scripts/probe_torch_grouped_ops.py
"""
import sys
def main():
try:
import torch
except Exception as e:
print("Failed to import torch:", e)
sys.exit(1)
print("torch version:", torch.__version__)
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
print("ops namespaces:", namespaces)
found_any = False
for ns in namespaces:
obj = getattr(torch.ops, ns, None)
ops = []
if obj is not None:
try:
ops = dir(obj)
except Exception as e:
print(f"warning: failed to list ops for namespace {ns}: {e}")
cands = [
o
for o in ops
if ("group" in o.lower())
or ("mm_grouped" in o.lower())
or ("matmul_grouped" in o.lower())
or ("grouped" in o.lower())
]
if cands:
found_any = True
print(f"namespace {ns} candidates:", cands)
if not found_any:
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
if __name__ == "__main__":
main()

View File

@@ -1,11 +1,10 @@
# noqa # noqa
# pylint: skip-file
import sys import sys
try: try:
import torch import torch
except ImportError: except ImportError as error:
raise ImportError("Install torch via `pip install torch`") raise ImportError("Install torch via `pip install torch`") from error
from packaging.version import Version as V from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:] use_uv = "--uv" in sys.argv[1:]

View File

@@ -64,7 +64,9 @@ def parse_requirements(extras_require_map):
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 7): if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
_install_requires.append("xformers==0.0.30") _install_requires.append("xformers==0.0.30")
@@ -118,14 +120,14 @@ def get_package_version():
extras_require = { extras_require = {
"flash-attn": ["flash-attn==2.8.2"], "flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [ "ring-flash-attn": [
"flash-attn==2.8.2", "flash-attn==2.8.3",
"ring-flash-attn>=0.1.7", "ring-flash-attn>=0.1.7",
"yunchang==0.6.0", "yunchang==0.6.0",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.17.2", "deepspeed==0.17.5",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [
@@ -160,6 +162,7 @@ extras_require = {
"llmcompressor": [ "llmcompressor": [
"llmcompressor==0.5.1", "llmcompressor==0.5.1",
], ],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require extras_require

View File

@@ -4,5 +4,7 @@ import os
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
configure_logging() configure_logging()

View File

@@ -14,9 +14,13 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True) download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field( iterable: Optional[bool] = field(
default=None, default=False,
metadata={ metadata={
"help": "Use IterableDataset for streaming processing of large datasets" "help": (
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
"config, or pass --streaming instead in the CLI."
)
}, },
) )
@@ -111,6 +115,7 @@ class QuantizeCliArgs:
quantize_embedding: Optional[bool] = field(default=None) quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None) group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None) output_dir: Optional[str] = field(default=None)
hub_model_id: Optional[str] = field(default=None)
@dataclass @dataclass

View File

@@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False
def print_axolotl_text_art(): def print_axolotl_text_art():
"""Prints axolotl ASCII art.""" """Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement global HAS_PRINTED_LOGO
if HAS_PRINTED_LOGO: if HAS_PRINTED_LOGO:
return return
if is_main_process(): if is_main_process():

View File

@@ -7,6 +7,8 @@ from typing import Literal
import yaml import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -38,8 +40,15 @@ def do_cli_train(
cwd=None, cwd=None,
**kwargs, **kwargs,
) -> None: ) -> None:
cloud_cfg = load_cloud_cfg(cloud_config) cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg) provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file: with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read() config_yaml = file.read()
local_dirs = {} local_dirs = {}

View File

@@ -0,0 +1,48 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -0,0 +1,71 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)
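Taken together with run.sh above, the launcher settings flow through the AXOLOTL_LAUNCHER and AXOLOTL_LAUNCHER_ARGS environment variables; a rough sketch of that expansion with illustrative values:
launcher = "accelerate"
launcher_args = ["--num_processes", "8"]
launcher_args_str = ("-- " + " ".join(launcher_args)) if launcher_args else ""
# run.sh then effectively runs:
#   axolotl train train.yaml --launcher accelerate -- --num_processes 8
print(f"axolotl train train.yaml --launcher {launcher} {launcher_args_str}")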

View File

@@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
if exit_code := subprocess.call( # nosec B603 if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env cmd.split(), cwd=run_folder, env=new_env
): ):
exit(exit_code) # pylint: disable=consider-using-sys-exit exit(exit_code)
# Commit writes to volume. # Commit writes to volume.
if volumes: if volumes:
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res return res
def get_image(self): def get_image(self):
docker_tag = "main-py3.11-cu124-2.6.0" docker_tag = "main-py3.11-cu126-2.7.1"
if self.config.docker_tag: if self.config.docker_tag:
docker_tag = self.config.docker_tag docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}" docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -130,7 +130,6 @@ class ModalCloud(Cloud):
res = [] res = []
if self.config.secrets: if self.config.secrets:
for key in self.config.get("secrets", []): for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str): if isinstance(key, str):
if val := os.environ.get(key, ""): if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val})) res.append(modal.Secret.from_dict({key: val}))
@@ -177,8 +176,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True): with self.app.run(detach=True):
modal_fn.remote( modal_fn.remote(
config_yaml, config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args, *args,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs, **kwargs,
) )
@@ -187,7 +186,7 @@ class ModalCloud(Cloud):
return int(self.config.timeout) return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements def get_train_gpu(self):
count = self.config.gpu_count or 1 count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s" family = self.config.gpu.lower() or "l40s"
@@ -200,7 +199,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]: if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count) return modal.gpu.A10G(count=count)
if family == "h100": if family == "h100":
return modal.gpu.H100(count=count) return f"H100:{count}"
if family == "t4": if family == "t4":
return modal.gpu.T4(count=count) return modal.gpu.T4(count=count)
if family == "l4": if family == "l4":
@@ -277,7 +276,7 @@ def _train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None, launcher_args: list[str] | None = None,
volumes=None, volumes=None,
**kwargs, # pylint: disable=unused-argument **kwargs,
): ):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True) Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:

View File

@@ -23,7 +23,8 @@ from axolotl.utils.config import (
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -210,7 +211,7 @@ def load_cfg(
try: try:
device_props = torch.cuda.get_device_properties("cuda") device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722 except:
gpu_version = None gpu_version = None
prepare_plugins(cfg) prepare_plugins(cfg)
@@ -227,8 +228,11 @@ def load_cfg(
}, },
) )
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
# have to wait for cfg.output to be resolved. We could call this earlier if we write
# to a temporary file, and then move it later.
prepare_debug_log(cfg)
prepare_optim_env(cfg) prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg) normalize_config(cfg)
normalize_cfg_datasets(cfg) normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
@@ -241,7 +245,6 @@ def load_cfg(
for k, v in cfg.items() for k, v in cfg.items()
if v is not None if v is not None
} }
LOG.info( LOG.info(
"config:\n%s", "config:\n%s",
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True), json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),

View File

@@ -28,7 +28,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments. cli_args: CLI arguments.
""" """
# pylint: disable=duplicate-code
check_accelerate_default_config() check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0: if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token() check_user_token()
@@ -49,7 +49,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs) parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -14,10 +14,12 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import ( from axolotl.cli.utils.diffusion import (
get_chat_template, diffusion_inference,
get_chat_template_from_config, launch_diffusion_gradio_ui,
) )
from axolotl.integrations.base import PluginManager
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -32,10 +34,11 @@ def get_multi_line_input() -> str:
Possibly multi-line, possibly empty stdin input as a string. Possibly multi-line, possibly empty stdin input as a string.
""" """
print("Give me an instruction (Ctrl + D to submit): ") print("Give me an instruction (Ctrl + D to submit): ")
print("=" * 80)
instruction = "" instruction = ""
for line in sys.stdin: for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join instruction += line
return instruction return instruction
@@ -46,9 +49,9 @@ def do_inference(
cli_args: InferenceCliArgs, cli_args: InferenceCliArgs,
): ):
""" """
Runs inference on the command line in a loop. User input is accepted, a chat template Runs inference on the command line in a loop. User input is accepted, a chat
is (optionally) applied, and the model specified in the `axolotl` config is used to template is (optionally) applied, and the model specified in the `axolotl` config is
generate completions according to a default generation config. used to generate completions according to a default generation config.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -64,17 +67,31 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template) chat_template_str = get_chat_template_from_config(
elif cfg.datasets[0].type == "chat_template": cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config( chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
) )
model = model.to(cfg.device, dtype=cfg.torch_dtype) model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
print("=" * 80)
print("Commands:")
print(":complete N -> completion mode with N tokens (default 64)")
print(":mask R -> random masking with ratio R (0.01.0)")
while True: while True:
print("=" * 80) print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input() instruction = get_multi_line_input()
if not instruction: if not instruction:
return return
@@ -104,9 +121,19 @@ def do_inference(
else: else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40) print("=" * 80)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
if is_diffusion:
diffusion_inference(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
)
continue
generation_config = GenerationConfig( generation_config = GenerationConfig(
repetition_penalty=1.1, repetition_penalty=1.1,
max_new_tokens=1024, max_new_tokens=1024,
@@ -129,7 +156,7 @@ def do_inference(
generation_config=generation_config, generation_config=generation_config,
streamer=streamer, streamer=streamer,
) )
print("=" * 40) print("=" * 80)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -159,15 +186,37 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype) model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
launch_diffusion_gradio_ui(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompter_module=prompter_module,
chat_template_str=chat_template_str,
)
return
def generate(instruction): def generate(instruction):
if not instruction: if not instruction:
return return
if prompter_module: if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next( prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n")) prompter_module().build_prompt(instruction=instruction.strip("\n"))
) )
@@ -252,7 +301,7 @@ def do_cli(
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs) parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs) parser = transformers.HfArgumentParser(InferenceCliArgs)

View File

@@ -1,7 +1,5 @@
"""Click CLI definitions for various axolotl commands.""" """Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import os import os
import subprocess # nosec B404 import subprocess # nosec B404
from typing import Literal, Optional from typing import Literal, Optional
@@ -28,7 +26,7 @@ from axolotl.cli.utils import (
launch_training, launch_training,
) )
from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -46,7 +44,7 @@ def cli():
"""Axolotl CLI - Train and fine-tune large language models""" """Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art() print_axolotl_text_art()
load_dotenv() load_dotenv()
patch_optimized_env() set_pytorch_cuda_alloc_conf()
@cli.command() @cli.command()

View File

@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
progressbar=True, progressbar=True,
) )
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor: if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged")) processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""A custom planner to cast tensors to bfloat16 on the fly during loading.""" """A custom planner to cast tensors to bfloat16 on the fly during loading."""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument def commit_tensor(self, read_item, tensor):
tensor.copy_(tensor.to(torch.bfloat16)) tensor.copy_(tensor.to(torch.bfloat16))
@@ -59,10 +59,10 @@ def _distributed_checkpoint_to_merged_weights(
state_dict: Dict = {} state_dict: Dict = {}
save_path_ = Path(save_path) save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True) save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access dist_cp_format_utils._load_state_dict(
state_dict, state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir), storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access planner=BFloat16CastPlanner(),
no_dist=True, no_dist=True,
) )
@@ -191,7 +191,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]: for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key): if cfg.get(key):
LOG.error( LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead." f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
) )
return return
@@ -73,7 +83,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
AutoModelForCausalLM.from_pretrained( AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True model_name, trust_remote_code=True
) )
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841 except Exception: # nosec B110
pass pass
# fmt: on # fmt: on
@@ -95,9 +105,10 @@ def do_cli(
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1" os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs) is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg.is_preprocess = True parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs) parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from transformers import AutoModelForCausalLM from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq from axolotl.utils.quantization import (
TorchAOQuantDType,
get_quantization_config,
quantization_config_to_str,
quantize_model,
)
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -43,13 +48,13 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file." "No quantization configuration found. Please specify either qat or quantization in your config file."
) )
model_path = cli_args.get("model_path") or cfg.output_dir model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"): if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype] weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
else: else:
weight_dtype = quantize_cfg.weight_dtype weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"): if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype] activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
else: else:
activation_dtype = quantize_cfg.activation_dtype activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -57,10 +62,15 @@ def do_quantize(
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
) )
output_dir = cli_args.get("output_dir") or cfg.output_dir output_dir = cli_args.get("output_dir") or cfg.output_dir
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
LOG.info(f"Loading model from {model_path}...") LOG.info(f"Loading model from {model_path}.")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype
)
LOG.info( LOG.info(
f"Quantizing model with configuration: \n" f"Quantizing model with configuration: \n"
@@ -70,11 +80,21 @@ def do_quantize(
f"\tquantize_embedding: {quantize_embedding}" f"\tquantize_embedding: {quantize_embedding}"
) )
quantize_model_for_ptq( quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding model, weight_dtype, group_size, activation_dtype, quantize_embedding
) )
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...") quantization_config = get_quantization_config(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained( model.save_pretrained(
str(Path(output_dir) / "quantized"), str(Path(output_dir) / "quantized"),
safe_serialization=False, safe_serialization=False,
@@ -84,5 +104,16 @@ def do_quantize(
str(Path(output_dir) / "quantized"), str(Path(output_dir) / "quantized"),
safe_serialization=False, safe_serialization=False,
progressbar=True, progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
) )
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")

View File

@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import prepare_optim_env
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
@@ -59,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs) parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -92,6 +92,7 @@ def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict) # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers # also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"]) cfg = DictDefault(kwargs["cfg"])
prepare_optim_env(cfg)
normalize_config(cfg) normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype

View File

@@ -65,7 +65,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
for field in reversed(dataclasses.fields(config_class)): for field in reversed(dataclasses.fields(config_class)):
field_type = _strip_optional_type(field.type) field_type = _strip_optional_type(field.type)
if field_type == bool: if field_type is bool:
field_name = field.name.replace("_", "-") field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}" option_name = f"--{field_name}/--no-{field_name}"
function = click.option( function = click.option(
@@ -103,7 +103,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()): for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation) field_type = _strip_optional_type(field.annotation)
if field_type == bool: if field_type is bool:
field_name = name.replace("_", "-") field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}" option_name = f"--{field_name}/--no-{field_name}"
function = click.option( function = click.option(
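As a quick illustration of the boolean flag naming these branches produce (the field name here is hypothetical):
field_name = "sample_packing".replace("_", "-")
print(f"--{field_name}/--no-{field_name}")  # --sample-packing/--no-sample-packing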

View File

@@ -0,0 +1,374 @@
"""Helpers for diffusion-mode inference in CLI and Gradio."""
from __future__ import annotations
import gradio as gr
from colorama import Fore, Style
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
from axolotl.utils.dict import DictDefault
def diffusion_inference(
model,
tokenizer,
cfg,
prompt: str,
chat_template_str: str | None = None,
):
"""Diffusion inference helper method."""
mode = "random"
completion_tokens = 0
target_mask_ratio = None
mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt)
if cleaned:
prompt = cleaned
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=mode,
target_mask_ratio=target_mask_ratio,
completion_tokens=completion_tokens,
)
masked_text = info["masked_text"]
mask_ratio = info["mask_ratio"]
generated_ids = info["generated_ids"]
masked_positions = info["masked_positions"]
orig_ids = info["orig_ids"]
# Display with masked preview and colored diff
if masked_text is not None and mask_ratio is not None:
print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n")
if generated_ids is not None:
# Compute per-token style
styles: list[str] = []
for i, tid in enumerate(generated_ids):
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
styles.append("green") # correct fill
elif i < len(orig_ids):
styles.append("red") # incorrect fill
else:
styles.append("normal") # appended
else:
same = i < len(orig_ids) and tid == orig_ids[i]
styles.append("dim" if same else "normal")
# Group contiguous spans by style
styled_spans: list[tuple[str, int, int]] = []
if generated_ids:
current_style = styles[0]
start = 0
for i in range(1, len(generated_ids)):
s = styles[i]
if s != current_style:
styled_spans.append((current_style, start, i))
current_style, start = s, i
styled_spans.append((current_style, start, len(generated_ids)))
out_parts = []
for style_name, a, b in styled_spans:
chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
elif style_name == "red":
out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
else:
if style_name == "dim":
out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
else:
out_parts.append(chunk_text)
print("Generated:\n" + "".join(out_parts))
else:
print("Generated:\n(no output)")
def _parse_commands(text: str):
"""
Parse leading diffusion commands.
Supported at start of input (can be chained):
:complete N -> completion mode with N tokens (default 64)
:mask R -> random masking with ratio R in [0, 1]
"""
tokens = text.strip().split()
i = 0
mode = "random"
completion_tokens = 0
target_mask_ratio = None
consumed = 0
while i < len(tokens) and tokens[i].startswith(":"):
cmd = tokens[i]
i += 1
consumed = i
if cmd == ":complete":
mode = "completion"
if i < len(tokens):
try:
completion_tokens = int(tokens[i])
i += 1
consumed = i
except Exception:
completion_tokens = 64
else:
completion_tokens = 64
elif cmd == ":mask":
mode = "random"
if i < len(tokens):
try:
target_mask_ratio = float(tokens[i])
i += 1
consumed = i
except Exception:
target_mask_ratio = None
else:
i -= 1
consumed = i
break
cleaned = " ".join(tokens[consumed:])
return mode, completion_tokens, target_mask_ratio, cleaned
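A small sanity check of the command parsing above, with illustrative inputs (assuming the module is importable as axolotl.cli.utils.diffusion, matching the import added to the inference CLI):
from axolotl.cli.utils.diffusion import _parse_commands

assert _parse_commands(":complete 32 tell me a story") == (
    "completion", 32, None, "tell me a story"
)
assert _parse_commands(":mask 0.3 the quick brown fox") == (
    "random", 0, 0.3, "the quick brown fox"
)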
def run_diffusion(
*,
model,
tokenizer,
cfg: DictDefault,
prompt: str,
chat_template_str: str | None,
mode: str = "random",
target_mask_ratio: float | None = None,
completion_tokens: int = 0,
):
"""Run a single diffusion generation and return a structured result dict."""
if chat_template_str:
batch = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)
seq = batch["input_ids"].to(cfg.device)
gen_mode = "completion" if mode == "completion" else "random"
comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0
result = generate(
model,
tokenizer,
original_sequence=seq[:1],
num_diffusion_steps=cfg.diffusion.num_diffusion_steps,
temperature=cfg.diffusion.generation_temperature,
mask_token_id=int(mask_token_id),
mode=gen_mode, # type: ignore[arg-type]
completion_tokens=comp_tokens,
target_mask_ratio=target_mask_ratio,
)
masked_text = result.get("masked") if isinstance(result, dict) else None
mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None
generated_ids = result.get("generated_ids") if isinstance(result, dict) else None
masked_positions = (
set(result.get("masked_positions") or []) if isinstance(result, dict) else set()
)
orig_ids = seq[0].detach().cpu().tolist()
return {
"masked_text": masked_text,
"mask_ratio": mask_ratio,
"generated_ids": generated_ids,
"masked_positions": masked_positions,
"orig_ids": orig_ids,
}
def render_html(
*,
generated_ids: list[int] | None,
orig_ids: list[int],
masked_positions: set[int],
tokenizer,
) -> str:
"""Render HTML visualizing diffusion outputs."""
if not generated_ids:
return "<pre>Generated:\n(no output)</pre>"
def _style_for(i: int, tid: int) -> str:
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
return "green"
if i < len(orig_ids):
return "red"
return "normal"
same = i < len(orig_ids) and tid == orig_ids[i]
return "dim" if same else "normal"
# Group contiguous spans by style to reduce HTML size
spans: list[tuple[str, int, int]] = []
if generated_ids:
cur = _style_for(0, generated_ids[0])
start = 0
for i in range(1, len(generated_ids)):
s = _style_for(i, generated_ids[i])
if s != cur:
spans.append((cur, start, i))
cur, start = s, i
spans.append((cur, start, len(generated_ids)))
html_parts = []
for style_name, a, b in spans:
txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
html_parts.append(f'<span style="color:#2e7d32">{txt}</span>')
elif style_name == "red":
html_parts.append(f'<span style="color:#c62828">{txt}</span>')
elif style_name == "dim":
html_parts.append(f'<span style="opacity:0.6">{txt}</span>')
else:
html_parts.append(txt)
legend = (
'<div style="font-size:0.9em;margin-bottom:4px">'
'<span style="color:#2e7d32">correct</span>, '
'<span style="color:#c62828">incorrect</span>, '
'<span style="opacity:0.6">unchanged</span>'
"</div>"
)
return (
legend
+ '<pre style="white-space:pre-wrap">Generated:\n'
+ "".join(html_parts)
+ "</pre>"
)
def launch_diffusion_gradio_ui(
*,
model,
tokenizer,
cfg: DictDefault,
prompter_module=None,
chat_template_str: str | None = None,
):
"""Build and launch a simple Gradio UI for diffusion inference."""
with gr.Blocks(
title=cfg.get("gradio_title", "Axolotl Diffusion Interface")
) as demo:
gr.Markdown(
"""
## Axolotl Diffusion Inference
- Mode "Random" masks tokens at a target ratio and fills them.
- Mode "Completion" appends N masked tokens at the end and fills them.
"""
)
with gr.Row():
mode = gr.Radio(
choices=["random", "completion"],
value="random",
label="Mode",
)
mask_ratio = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.4,
label="Mask ratio (random mode)",
interactive=True,
)
completion_tokens = gr.Number(
value=64,
precision=0,
label="Completion tokens (completion mode)",
interactive=True,
visible=False,
)
instruction = gr.Textbox(label="Instruction", lines=6)
run_btn = gr.Button("Generate")
masked_preview = gr.Textbox(label="Masked preview", lines=6)
html_out = gr.HTML(label="Generated")
def _toggle_controls(selected_mode: str):
return (
gr.update(visible=(selected_mode == "random")),
gr.update(visible=(selected_mode == "completion")),
)
mode.change(
_toggle_controls,
inputs=[mode],
outputs=[mask_ratio, completion_tokens],
)
def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):
if not instruction_text:
return "", "<pre>Generated:\n(no output)</pre>"
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(
instruction=instruction_text.strip("\n")
)
)
else:
prompt = instruction_text.strip()
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=selected_mode,
target_mask_ratio=mratio if selected_mode == "random" else None,
completion_tokens=int(ctoks) if selected_mode == "completion" else 0,
)
masked_text = info.get("masked_text")
mask_ratio_val = info.get("mask_ratio")
generated_ids = info.get("generated_ids")
masked_positions = info.get("masked_positions") or set()
orig_ids = info.get("orig_ids") or []
preview = (
f"Masked ({mask_ratio_val:.1%}):\n{masked_text}"
if masked_text is not None and mask_ratio_val is not None
else ""
)
html = render_html(
generated_ids=generated_ids,
orig_ids=orig_ids,
masked_positions=masked_positions,
tokenizer=tokenizer,
)
return preview, html
run_btn.click(
_gen,
inputs=[instruction, mode, mask_ratio, completion_tokens],
outputs=[masked_preview, html_out],
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
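The launch call above reads a few optional config keys; a hypothetical config fragment (every key is optional, with defaults as shown in the call):
gradio_cfg = {
    "gradio_title": "Axolotl Diffusion Interface",
    "gradio_share": False,            # disable the public share link
    "gradio_server_name": "0.0.0.0",
    "gradio_server_port": 7860,
}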

View File

@@ -3,11 +3,12 @@
import random import random
from copy import deepcopy from copy import deepcopy
from itertools import product from itertools import product
from typing import Any
def generate_sweep_configs( def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list] base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]: ) -> list[dict[str, Any]]:
""" """
Recursively generates all possible configurations by applying sweeps to the base config. Recursively generates all possible configurations by applying sweeps to the base config.
@@ -48,7 +49,10 @@ def generate_sweep_configs(
new_config = {} new_config = {}
# new_config = deepcopy(base_config) # new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters # Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set} full_combo = {
**dict(zip(param_names, reg_combo, strict=False)),
**paired_set,
}
for param_name, param_value in full_combo.items(): for param_name, param_value in full_combo.items():
new_config[param_name] = param_value new_config[param_name] = param_value
print(new_config) print(new_config)
@@ -57,7 +61,7 @@ def generate_sweep_configs(
# If no paired values, just use regular combinations # If no paired values, just use regular combinations
# new_config = deepcopy(base_config) # new_config = deepcopy(base_config)
new_config = {} new_config = {}
for param_name, param_value in zip(param_names, reg_combo): for param_name, param_value in zip(param_names, reg_combo, strict=False):
new_config[param_name] = param_value new_config[param_name] = param_value
print(new_config) print(new_config)
all_combinations.append(new_config) all_combinations.append(new_config)
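For plain (unpaired) sweep parameters, the loop above reduces to a cartesian product over the swept values; a minimal sketch of that behaviour with made-up parameters:
from itertools import product

sweeps_config = {"learning_rate": [1e-5, 2e-5], "micro_batch_size": [1, 2]}
param_names = list(sweeps_config.keys())
for reg_combo in product(*sweeps_config.values()):
    print(dict(zip(param_names, reg_combo, strict=False)))
# four configs: every combination of learning_rate and micro_batch_size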

View File

@@ -4,6 +4,7 @@ import os
import subprocess # nosec import subprocess # nosec
import sys import sys
import tempfile import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal from typing import Any, Iterator, Literal
import yaml import yaml
@@ -88,8 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations # Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config) permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1 is_group = len(permutations) > 1
for permutation in permutations: base_output_dir = base_config.get("output_dir", "./model-out")
# pylint: disable=consider-using-with for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
temp_file = tempfile.NamedTemporaryFile( temp_file = tempfile.NamedTemporaryFile(
mode="w", mode="w",
suffix=".yaml", suffix=".yaml",

View File

@@ -39,7 +39,7 @@ def do_vllm_serve(
model = cfg.base_model model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve") serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main") vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1 tensor_parallel_size = 1
data_parallel_size = 1 data_parallel_size = 1
@@ -68,7 +68,6 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
) )
# pylint: disable=unexpected-keyword-arg
vllm_script_args = AxolotlScriptArguments( vllm_script_args = AxolotlScriptArguments(
model=model, model=model,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass
from datasets import Dataset from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
@@ -55,13 +55,11 @@ def load_datasets(
""" """
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets( train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg, cfg,
tokenizer, tokenizer,
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable,
) )
if ( if (

View File

@@ -67,9 +67,7 @@ class JsonToJsonlConverter:
self.json_parser = json_parser self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer self.jsonl_serializer = jsonl_serializer
def convert( def convert(self, input_file_path, output_file_path):
self, input_file_path, output_file_path
): # pylint: disable=unused-argument
content = self.file_reader.read(input_file_path) content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content) data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -84,9 +84,7 @@ def create_causal_mask(
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None: if attention_mask is not None:
def causal_doc_mask_mod( def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
""" """
Defines the logic of a block causal mask by combining both a standard causal mask Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask. and a block diagonal document mask.
@@ -103,9 +101,7 @@ def create_causal_mask(
mask_factory_function = causal_doc_mask_mod mask_factory_function = causal_doc_mask_mod
else: else:
mask_factory_function = causal_mask_function mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[ mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
config._attn_implementation # pylint: disable=protected-access
]
# Do not allow skip if we are compiling (this is to match BC) # Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = ( allow_is_causal_skip = (

View File

@@ -24,9 +24,7 @@ from pathlib import Path
from typing import Any from typing import Any
import torch import torch
from transformers import ( from transformers import TrainerCallback
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
@@ -44,7 +42,7 @@ from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
with suppress(ImportError): with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports import torch._dynamo
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
@@ -260,14 +258,14 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon") adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon": if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module from axolotl.contribs.mit.muon import (
MuonOptimizerFactory, MuonOptimizerFactory,
) )
optimizer_cls = MuonOptimizerFactory optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs) optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion": elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module from axolotl.contribs.mit.dion import (
DionOptimizerFactory, DionOptimizerFactory,
) )
@@ -414,12 +412,8 @@ class TrainerBuilderBase(abc.ABC):
def _configure_torch_compile(self, training_args_kwargs: dict): def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None): if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access torch._dynamo.config.suppress_errors = True
True torch._dynamo.config.accumulated_cache_size_limit = 256
)
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend: if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = ( training_args_kwargs["torch_compile_backend"] = (
@@ -441,7 +435,7 @@ class TrainerBuilderBase(abc.ABC):
# don't use the HF gradient checkpointing, manually wrap # don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False training_args_kwargs["gradient_checkpointing"] = False
training_args_kwargs["activation_offloading"] = True training_args_kwargs["activation_offloading"] = True
elif self.cfg.gradient_checkpointing: elif self.cfg.gradient_checkpointing is not None:
training_args_kwargs["gradient_checkpointing"] = ( training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing self.cfg.gradient_checkpointing
) )
@@ -516,6 +510,7 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size self.cfg.eval_batch_size
) )
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1 training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
) )
from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import ( from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
@@ -75,6 +76,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.qat: if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat)) callbacks.append(QATCallback(self.cfg.qat))
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -341,20 +348,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model: if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig training_args_cls = AxolotlRewardConfig
if self.cfg.center_rewards_coefficient is not None:
training_arguments_kwargs["center_rewards_coefficient"] = (
self.cfg.center_rewards_coefficient
)
elif self.cfg.process_reward_model: elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig training_args_cls = AxolotlPRMConfig
else: else:
training_args_cls = AxolotlTrainingArguments training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg training_args = training_args_cls(
**training_arguments_kwargs, **training_arguments_kwargs,
) )
training_args = self.hook_post_create_training_args(training_args) training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names # unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir: if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init training_args.run_name = None
None
)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, # True/"longest" is the default "padding": True, # True/"longest" is the default
@@ -408,6 +417,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs, **trainer_kwargs,
) )
trainer = self.hook_post_create_trainer(trainer) trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer): for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback) trainer.add_callback(callback)

View File

@@ -168,16 +168,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if plugin_training_args: if plugin_training_args:
training_args_kwargs.update(plugin_training_args) training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg training_args = training_args_cls(
logging_first_step=True, logging_first_step=True,
**training_args_kwargs, **training_args_kwargs,
) )
# unset run_name so wandb sets up experiment names # unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir: if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init training_args.run_name = None
None
)
return training_args, trainer_kwargs return training_args, trainer_kwargs

View File

@@ -10,7 +10,7 @@ from .shared import wrap_tools
def format_message( def format_message(
message: Messages, message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument message_index: Optional[int] = None,
) -> Messages: ) -> Messages:
if message.is_chat_formatted: if message.is_chat_formatted:
return message return message

View File

@@ -15,11 +15,11 @@ class MessageRoles(str, Enum):
Message roles for the system, user, assistant, and tools Message roles for the system, user, assistant, and tools
""" """
system = "system" # pylint: disable=invalid-name system = "system"
user = "user" # pylint: disable=invalid-name user = "user"
assistant = "assistant" # pylint: disable=invalid-name assistant = "assistant"
tool = "tool" # pylint: disable=invalid-name tool = "tool"
ipython = ( # pylint: disable=invalid-name ipython = (
# for responses from builtin tools # for responses from builtin tools
"ipython" "ipython"
) )
@@ -30,12 +30,12 @@ class MessageContentTypes(str, Enum):
Message content types for text, image, audio, tool calls, and tool responses Message content types for text, image, audio, tool calls, and tool responses
""" """
special_token = "special_token" # pylint: disable=invalid-name # nosec B105 special_token = "special_token" # nosec B105
text = "text" # pylint: disable=invalid-name text = "text"
image = "image" # pylint: disable=invalid-name image = "image"
audio = "audio" # pylint: disable=invalid-name audio = "audio"
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant tool_call = "tool_call"
tool_response = "tool_response" # pylint: disable=invalid-name tool_response = "tool_response"
class SpecialToken(str, Enum): class SpecialToken(str, Enum):
@@ -43,8 +43,8 @@ class SpecialToken(str, Enum):
Special tokens for beginning of string and end of string Special tokens for beginning of string and end of string
""" """
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105 bos_token = "bos_token" # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105 eos_token = "eos_token" # nosec B105
class ToolCallFunction(BaseModel): class ToolCallFunction(BaseModel):
@@ -73,7 +73,7 @@ class ToolCallContents(BaseModel):
name: str name: str
arguments: dict[str, Union[str, int]] arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name id: Optional[str] = None
def __str__(self) -> str: def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments} data = {"name": self.name, "arguments": self.arguments}
@@ -89,7 +89,7 @@ class ToolResponseContents(BaseModel):
name: str name: str
content: Union[str, dict[str, Union[str, int, float]]] content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name id: Optional[str] = None
def __str__(self) -> str: def __str__(self) -> str:
data = {"name": self.name, "content": self.content} data = {"name": self.name, "content": self.content}

View File

@@ -1,23 +1,17 @@
""" """
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. This module contains a function that builds a transform that takes a row from the
dataset and converts it to a Chat.
""" """
from typing import Any, Mapping, Union from typing import Any, Mapping
def chat_message_transform_builder( # pylint: disable=dangerous-default-value def chat_message_transform_builder(
train_on_inputs=False, train_on_inputs=False,
conversations_field: str = "conversations", conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role" message_field_role: str | list[str] | None = None, # commonly "role"
message_field_content: Union[str, list[str]] = [ message_field_content: str | list[str] | None = None, # commonly "content"
"value", message_field_training: str | list[str] | None = None, # commonly "weight"
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
): ):
"""Builds a transform that takes a row from the dataset and converts it to a Chat """Builds a transform that takes a row from the dataset and converts it to a Chat
@@ -39,6 +33,12 @@ def chat_message_transform_builder( # pylint: disable=dangerous-default-value
A function that takes a list of conversations and returns a list of messages. A function that takes a list of conversations and returns a list of messages.
""" """
if message_field_training is None:
message_field_training = ["train", "weight"]
if message_field_content is None:
message_field_content = ["value", "text", "content"]
if message_field_role is None:
message_field_role = ["role", "from"]
message_field_role = ( message_field_role = (
[message_field_role] [message_field_role]
if isinstance(message_field_role, str) if isinstance(message_field_role, str)

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.core.trainers""" """Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .base import AxolotlTrainer from .base import AxolotlTrainer

View File

@@ -1,7 +1,5 @@
"""Module for customized trainers""" """Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations from __future__ import annotations
import os import os
@@ -44,12 +42,20 @@ from axolotl.core.trainers.utils import (
) )
from axolotl.utils import get_not_null from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__) LOG = get_logger(__name__)
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
"max": torch.max,
"sum": torch.sum,
}
class AxolotlTrainer( class AxolotlTrainer(
PackingMixin, PackingMixin,
@@ -65,6 +71,15 @@ class AxolotlTrainer(
     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
     tag_names = ["axolotl"]
+    _axolotl_cfg: DictDefault | None = None
+
+    @property
+    def axolotl_cfg(self):
+        return self._axolotl_cfg
+
+    @axolotl_cfg.setter
+    def axolotl_cfg(self, cfg):
+        self._axolotl_cfg = cfg
     def __init__(
         self,
@@ -80,7 +95,6 @@ class AxolotlTrainer(
         self._signature_columns = None  # workaround for pylint
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(
             lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
@@ -274,18 +288,6 @@ class AxolotlTrainer(
             num_workers=self.args.dataloader_num_workers,
             rank=self.args.process_index,
         )
-        if (self.args.accelerator_config is not None
-            and self.args.accelerator_config.split_batches
-            and self.args.accelerator_config.dispatch_batches
-        ):
-            if self.args.sample_packing and self.args.pretraining:
-                if not self.args.eval_sample_packing and not is_training:
-                    dataloader_params["batch_size"] *= self.accelerator.num_processes
-                else:
-                    dataloader_params["batch_size"] = self.accelerator.num_processes
-        elif not self.args.sample_packing and self.args.pretraining:
-            dataloader_params["batch_size"] *= self.accelerator.num_processes
         if self.args.sample_packing and (
             (is_training and not self.args.pretraining)
             or (not is_training and self.args.eval_sample_packing is not False)
@@ -299,9 +301,9 @@ class AxolotlTrainer(
         # fmt: off
         if dataloader_key is not None and self.args.dataloader_persistent_workers:
             if hasattr(self, "_eval_dataloaders"):
-                self._eval_dataloaders[dataloader_key] = dataloader  # type: ignore # pylint: disable=access-member-before-definition
+                self._eval_dataloaders[dataloader_key] = dataloader  # type: ignore
             else:
-                self._eval_dataloaders = {dataloader_key: dataloader}  # pylint: disable=attribute-defined-outside-init
+                self._eval_dataloaders = {dataloader_key: dataloader}
         # fmt: on
         return self.accelerator.prepare(dataloader)
@@ -343,6 +345,17 @@ class AxolotlTrainer(
         # outputs = model(**inputs)
         # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         # return (loss, outputs) if return_outputs else loss
+        # track number of tokens for tokens per second calculation
+        if self.args.include_tkps:
+            inputs_key = "labels" if "labels" in inputs else "input_ids"
+            if hasattr(self.state, "num_tokens"):
+                self.state.num_tokens = (
+                    self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
+                )
+            else:
+                self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
         if self.args.orpo_alpha:
             return self.orpo_compute_loss(
                 model,
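
Aside: the block added above counts the tokens that actually contribute to the loss (labels other than -100) so a tokens-per-second figure can be derived at logging time. A rough self-contained sketch of that counting step, assuming the usual -100 ignore index (illustrative only, not the trainer's exact code):

    import torch

    IGNORE_INDEX = -100  # tokens with this label are excluded from the loss

    def count_trainable_tokens(labels: torch.Tensor) -> int:
        # Count every position whose label is not the ignore index.
        return int((labels != IGNORE_INDEX).sum().item())

    labels = torch.tensor([[1, 2, -100, -100], [5, -100, 7, 8]])
    print(count_trainable_tokens(labels))  # 5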
@@ -358,6 +371,11 @@ class AxolotlTrainer(
             num_items_in_batch=num_items_in_batch,
         )
+
+    @override
+    def evaluate(self, *args, **kwargs):
+        LOG.info("Running evaluation step...")
+        return super().evaluate(*args, **kwargs)
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}
@@ -457,7 +475,7 @@ class AxolotlTrainer(
         model,
         inputs,
         return_outputs=False,
-        num_items_in_batch=None,  # pylint: disable=unused-argument
+        num_items_in_batch=None,
     ):
         concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
             inputs,
@@ -538,15 +556,10 @@ class AxolotlTrainer(
         accelerator_config = self.args.accelerator_config.to_dict()
         use_configured_state = accelerator_config.get("use_configured_state", False)
         if not use_configured_state:
-            AcceleratorState._reset_state(  # pylint: disable=protected-access
-                reset_partial_state=True
-            )
+            AcceleratorState._reset_state(reset_partial_state=True)
         super().create_accelerator_and_postprocess()
-        # now we need to put parallelism_config back on the PartialState since we rely on that info in other places
-        # PartialState().parallelism_config = self.accelerator.state.parallelism_config
         if self.is_fsdp_enabled:
             if (
                 "limit_all_gathers" in self.args.fsdp_config
@@ -554,7 +567,6 @@ class AxolotlTrainer(
             ):
                 self.accelerator.state.fsdp_plugin.limit_all_gathers = True
-    # pylint: disable=unused-argument
     def additional_accelerator_args(
         self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
     ) -> dict[str, Any]:
@@ -588,36 +600,34 @@ class AxolotlTrainer(
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
-        # Add reduced stored metrics to logs
         for key, metric_data in self._stored_metrics[train_eval].items():
-            values = torch.tensor(metric_data["values"])
+            values = torch.tensor(metric_data["values"])  # type: ignore[arg-type]
             reduction_type = metric_data["reduction"]
-            if reduction_type == "mean":
-                logs[key] = values.mean().item()
-            elif reduction_type == "min":
-                logs[key] = values.min().item()
-            elif reduction_type == "max":
-                logs[key] = values.max().item()
-            elif reduction_type == "sum":
-                logs[key] = values.sum().item()
-            else:
+            fn = REDUCTION_FNS.get(reduction_type)
+            if fn is None:
                 raise NotImplementedError(
                     "Metric reduction must be one of [mean, min, max, sum]"
                 )
-            logs[key] = round(logs[key], 4)
+            logs[key] = round(fn(values).item(), 4)
         if is_main_process():
             # Add memory usage
             try:
                 active, allocated, reserved = get_gpu_memory_usage()
-                logs["memory/max_mem_active(gib)"] = round(active, 2)
-                logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
-                logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
+                logs["memory/max_active (GiB)"] = round(active, 2)
+                logs["memory/max_allocated (GiB)"] = round(allocated, 2)
+                logs["memory/device_reserved (GiB)"] = round(reserved, 2)
             except (ValueError, TypeError, FileNotFoundError):
                 pass
+            if self.args.include_tkps and train_eval == "train":
+                # each rank will log its own tokens per second
+                # for logging_steps > 1 we obtain a moving average of this metric
+                logs["tokens_per_second_per_gpu"] = round(
+                    self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
+                )
         del self._stored_metrics[train_eval]
         return super().log(logs, start_time)
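
Aside: the log() hunk above replaces an if/elif chain with a lookup into the REDUCTION_FNS table added earlier in this diff. A minimal standalone sketch of the same dispatch-table idea (illustrative only, not the trainer's exact code):

    import torch

    # Map reduction names to torch functions once, instead of branching per call.
    REDUCTIONS = {"mean": torch.mean, "min": torch.min, "max": torch.max, "sum": torch.sum}

    def reduce_metric(values: list[float], reduction: str = "mean") -> float:
        fn = REDUCTIONS.get(reduction)
        if fn is None:
            raise NotImplementedError("Metric reduction must be one of [mean, min, max, sum]")
        return round(fn(torch.tensor(values)).item(), 4)

    print(reduce_metric([0.1, 0.2, 0.7]))         # 0.3333
    print(reduce_metric([0.1, 0.2, 0.7], "max"))  # 0.7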
@@ -638,12 +648,12 @@ class AxolotlTrainer(
         """
         for key, value in metrics.items():
             if isinstance(value, tuple):
-                metric_value, metric_reduction = value
+                value, _reduction = value  # type: ignore[assignment]
             else:
-                metric_value, metric_reduction = value, reduction
-            self._stored_metrics[train_eval][key]["values"].append(metric_value)
-            self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
+                value, _reduction = value, reduction
+            self._stored_metrics[train_eval][key]["values"].append(value)
+            self._stored_metrics[train_eval][key]["reduction"] = _reduction
     def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
@@ -710,6 +720,11 @@ class AxolotlTrainer(
             LOG.info(
                 "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
             )
-            self.data_collator.tokenizer.save_pretrained(output_dir)
+            save_jinja_files = True
+            if self.axolotl_cfg:
+                save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
+            self.data_collator.tokenizer.save_pretrained(
+                output_dir, save_jinja_files=save_jinja_files
+            )
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -101,11 +101,11 @@ class AxolotlDPOTrainer(
     ) -> dict[str, torch.Tensor]:
         if self.args.dpo_norm_loss:
             # fmt: off
-            loss_type: str = self.loss_type  # type: ignore[has-type] # pylint: disable=access-member-before-definition
+            loss_type: str = self.loss_type  # type: ignore[has-type]
             # fmt: on
             # concatenated_forward handles avg token logprob for ipo case already
-            self.loss_type = "ipo"  # pylint: disable=attribute-defined-outside-init
+            self.loss_type = "ipo"
             res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
-            self.loss_type = loss_type  # pylint: disable=attribute-defined-outside-init
+            self.loss_type = loss_type
             return res
         return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -128,9 +128,7 @@ class GRPOStrategy:
         return grpo_args_kwargs
     @classmethod
-    def set_trainer_args(
-        cls, cfg: DictDefault
-    ) -> list[Any]:  # pylint: disable=unused-argument
+    def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
         trainer_args = []
         if cfg.trl and cfg.trl.reward_funcs:
             reward_funcs = []
@@ -151,7 +149,7 @@ class GRPOStrategy:
         return trainer_kwargs
     @classmethod
-    def get_collator(cls, *args, **kwargs):  # pylint: disable=unused-argument
+    def get_collator(cls, *args, **kwargs):
         # No data collation is needed in GRPO, handled by trl's trainer __init__
         return None

View File

@@ -1,7 +1,5 @@
 """Axolotl GRPO trainers (with and without sequence parallelism handling)"""
-# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
 import warnings
 from functools import partial
 from typing import Any
@@ -52,7 +50,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
 from axolotl.monkeypatch.ring_attn import get_ring_attn_group
 if is_peft_available():
-    # pylint: disable=unused-import
     from peft import PeftConfig
@@ -253,7 +250,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
     def get_train_dataloader(self) -> DataLoader:
         """Get dataloader for training"""
         train_dataset = self.train_dataset
-        # pylint: disable=access-member-before-definition
         data_collator = self.data_collator  # type: ignore
         # Handle dataset preprocessing
@@ -266,7 +263,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
                 train_dataset, description="training"
             )
         else:
-            self.data_collator = self._get_collator_with_removed_columns(  # pylint: disable=attribute-defined-outside-init
+            self.data_collator = self._get_collator_with_removed_columns(
                 data_collator,
                 description="training",
             )
@@ -308,10 +305,10 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             # First, have main process load weights if needed
-            # pylint: disable=access-member-before-definition
             if self.state.global_step != self._last_loaded_step:  # type: ignore[has-type]
                 self._move_model_to_vllm()
-                # pylint: disable=attribute-defined-outside-init
                 self._last_loaded_step = self.state.global_step
             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
@@ -333,8 +330,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             # Extract prompts from this SP group, accounting for num_generations duplicates
             # We only need prompts from one rank in each SP group
             group_prompts = all_prompts_text[
-                group_leader_rank
-                * len(prompts_text) : (group_leader_rank + 1)
+                group_leader_rank * len(prompts_text) : (
+                    group_leader_rank + 1
+                )
                 * len(prompts_text) : self.num_generations
             ]
@@ -485,7 +483,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
         )
         if is_conversational(inputs[0]):
             completions = []
-            for prompt, completion in zip(prompts, completions_text):
+            for prompt, completion in zip(prompts, completions_text, strict=False):
                 bootstrap = (
                     prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                 )
@@ -503,6 +501,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
                 self.reward_funcs,
                 self.reward_processing_classes,
                 self.reward_func_names,
+                strict=False,
             )
         ):
             with profiling_context(self, reward_func_name):
@@ -511,14 +510,17 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
             ):  # Module instead of PretrainedModel for compat with compiled models
                 if is_conversational(inputs[0]):
                     messages = [
-                        {"messages": p + c} for p, c in zip(prompts, completions)
+                        {"messages": p + c}
+                        for p, c in zip(prompts, completions, strict=False)
                     ]
                     texts = [
                         apply_chat_template(x, reward_processing_class)["text"]
                         for x in messages
                     ]
                 else:
-                    texts = [p + c for p, c in zip(prompts, completions)]
+                    texts = [
+                        p + c for p, c in zip(prompts, completions, strict=False)
+                    ]
                 reward_inputs = reward_processing_class(
                     text=texts,
                     return_tensors="pt",
@@ -564,7 +566,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
                 row_reward_kwargs["completion"] = completions[nan_row_idx]
                 warnings.warn(
                     f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
-                    "Please ensure that at least one reward function returns a valid reward."
+                    "Please ensure that at least one reward function returns a valid reward.",
+                    stacklevel=2,
                 )
         # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
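
Aside: several hunks above spell out strict=False on zip(). Since Python 3.10, zip() accepts a strict flag; writing it explicitly documents that silent truncation on unequal lengths is intentional rather than an oversight. A tiny illustration of the two behaviors:

    prompts = ["p1", "p2", "p3"]
    completions = ["c1", "c2"]

    # strict=False (the default): the longer iterable is silently truncated.
    print(list(zip(prompts, completions, strict=False)))  # [('p1', 'c1'), ('p2', 'c2')]

    # strict=True: a length mismatch raises instead of being hidden.
    try:
        list(zip(prompts, completions, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is shorter than argument 1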

View File

@@ -5,7 +5,6 @@ import torch
 from axolotl.core.trainers.base import AxolotlTrainer
-# pylint: disable=too-many-ancestors
 class AxolotlMambaTrainer(AxolotlTrainer):
     """Mamba specific trainer to handle loss calculation"""
@@ -15,8 +14,8 @@ class AxolotlMambaTrainer(AxolotlTrainer):
         self,
         model,
         inputs,
-        return_outputs=False,  # pylint: disable=unused-argument
-        num_items_in_batch=None,  # pylint: disable=unused-argument
+        return_outputs=False,
+        num_items_in_batch=None,
     ):
         input_ids = inputs.pop("input_ids")
         lm_logits = model(input_ids).logits

View File

@@ -1,6 +1,5 @@
 """Init for axolotl.core.trainers.mixins"""
-# pylint: disable=unused-import
 # flake8: noqa
 from .activation_checkpointing import ActivationOffloadingMixin

View File

@@ -92,7 +92,7 @@ def get_lora_act_offloading_ctx_manager(
         `contextlib.ContextDecorator`:
             Activation offloading context manager for the model.
     """
-    # pylint: disable=unnecessary-dunder-call
    activations_handling_ctx = OffloadActivations(
        use_pin_memory=use_pin_memory,
        use_streams=use_streams,

Some files were not shown because too many files have changed in this diff.