Compare commits

...

74 Commits

Author SHA1 Message Date
Wing Lian
a4a3b618e7 force torch to match when installing fa and deepspeed using uv 2026-03-04 10:00:08 -05:00
Wing Lian
b6b8db805a fix python version typo for building 3.11 (#3454) 2026-03-04 09:53:35 -05:00
Wing Lian
653f90be25 Add torch 2.10.0 to unit tests and use python 3.14 (#3450)
* Add torch 2.10.0 to unit tests and use python 3.14

* hold on python 3.14 checks due to mistral common

* add base option to matrix
2026-03-03 13:01:52 -05:00
NanoCode012
945c8aeb10 Fix: quantize and target moe layers in transformers v5 for adapters and many misc fixes (#3439)
* fix: saving clones state dict

* fix: apply fix for only CP mode

* fix: add dropout check when using lora target param

* fix: re-add patch from transformers PR #39866

* feat: add moe quant to test by ved

* fix: try match target param properly end with

* fix: clear cache per param quant

* fix: attempt on-load quantize experts instead of post-load

* fix: attempt disable async load

* chore: add log

* chore: adjust log

* fix: remove cuda alloc for moe and enable async load

* chore: remove leftover logs

* chore: add extra empty cache

* fix(doc): clarify support

* fix: handle fsdp2 for paramwrapper dtensor

* feat: attempt to quant experts in 8bit mode too

* feat: attempt to release bf16 experts from vram

* feat: upgrade cce

* fix: fsdp2 init_sharded_param load int8/uint4 dtensor as
require_grad=true on init

* fix: remove unnecessary gc and empty cache

* Revert "fix: remove unnecessary gc and empty cache"

This reverts commit 1d54518990.

* fix: do not call full_tensor on non-dtensors

* fix: attempt to address fsdp2 with quant exp high loss

* fix: attempt lora quant experts wrong dim

* fix: ensure require_grad patch applied for lora 8bit

* fix: attempt lora 8bit fsdp2

* fix: attribute access on save for lora 8bit fsdp2

* fix: wrong weight attrib access

* chore(refactor): add config, re-arrange position of patches, clean
comments

* feat: add example docs

* chore: cherry pick trinity fixes from PR 3399

* chore: comments refactor; add guards

* fix: guard using wrong key

* fix: mamba save does not accept main process param

* fix: guard prevent double hook

* fix: move gc to upper scope

* chore: add comment on proxy forward patch

* fix: add comment to clarify

* feat: add test idempotency

* fix: AttributeError: `e_score_correction_bias` is not an nn.Parameter

* fix: AttributeError: 'NoneType' object has no attribute 'to'

* fix: update docs on cpu_ram_efficient_loading
2026-03-03 10:06:23 -05:00
NanoCode012
e672d37f33 fix: qwen3-next to use fla causal-conv1d to support packing (#3437)
* fix: qwen3-next to use fla causal-conv1d to support packing

* fix: causal import and update doc for v5

* fix: hard fail for packing without fla
2026-03-03 09:26:46 -05:00
Wing Lian
77828d3559 uv cloud image should use uv w pip (#3449) 2026-03-02 16:39:26 -05:00
Wing Lian
4272817109 don't install torch ao on arm64 (#3448) 2026-03-02 14:24:54 -05:00
Manas Vardhan
474208b794 fix: Save de-duplicated dataset during pre-processing (#3427)
* fix: run deduplication before saving dataset during preprocessing

Move deduplicate_and_log_datasets call before save_preprocessed_dataset
in both SFT and RL data loading pipelines. This ensures the saved
preprocessed dataset is already de-duplicated, so subsequent loads
from cache don't contain duplicates.

Fixes #2719

* fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup

- Add dataset_exact_deduplication to the hash string in
  generate_dataset_hash_from_config so cached datasets are invalidated
  when the dedup setting changes.
- Log a warning when skip_prepare_dataset=True and
  dataset_exact_deduplication=True, since dedup will be silently
  skipped in that configuration (both SFT and RL paths).

* fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting

- Add config validator (check_deduplication_with_skip_prepare) that raises
  ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True
- Replace runtime warnings in sft.py/rl.py with the validator check
- Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of
  axolotl.loaders.load_tokenizer to properly mock the imported reference
- Fix ruff lint (remove unused imports) and formatting issues

* refactor: inline deduplicate function per review feedback

* fix test fixture, lint

---------

Co-authored-by: ManasVardhan <manasvardhan@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-03-02 12:55:59 -05:00
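The cache-invalidation idea in the commit above — folding `dataset_exact_deduplication` into the dataset hash so previously cached preprocessed datasets are rebuilt when the setting flips — can be sketched roughly like this. This is a minimal illustration, not axolotl's actual `generate_dataset_hash_from_config`; the function body and field names here are assumptions.

```python
import hashlib


def generate_dataset_hash_from_config(dataset_cfg: dict, dedup: bool) -> str:
    """Illustrative sketch: mix the dedup flag into the hash input so that
    toggling dataset_exact_deduplication produces a different cache key."""
    hash_input = f"{sorted(dataset_cfg.items())}|dedup={dedup}"
    return hashlib.sha256(hash_input.encode()).hexdigest()


cfg = {"path": "tatsu-lab/alpaca", "type": "alpaca"}
# Same dataset config, different dedup setting -> different cache key,
# so a stale non-deduplicated cache is never reused.
assert generate_dataset_hash_from_config(cfg, dedup=True) != generate_dataset_hash_from_config(cfg, dedup=False)
```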
Wing Lian
444020b332 mark slow tests that are timing out in CI (#3428) [skip ci] 2026-03-02 12:26:30 -05:00
Wing Lian
aa88c2e30b fix uv cache subcommand (#3447) 2026-03-02 12:26:08 -05:00
NanoCode012
f447bce1db fix: do not push telemetry on non-master rank (#3438) 2026-03-02 15:31:20 +07:00
kallewoof
7f23b302d1 bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler() (#3435) [skip ci]
* bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler()

* nit: raise if self.optimizer is also unset

* optimizer properly optional in create_scheduler()
2026-03-02 15:30:07 +07:00
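The fallback described in #3435 — use the passed optimizer, else `self.optimizer`, and raise if both are unset — can be sketched as follows. The class body and return value are hypothetical placeholders; only the fallback/raise logic mirrors the commit message, and this is not axolotl's actual `SchedulerMixin`.

```python
class SchedulerMixin:
    """Hypothetical sketch of the optimizer fallback from #3435."""

    def __init__(self, optimizer=None):
        self.optimizer = optimizer

    def create_scheduler(self, num_training_steps, optimizer=None):
        # Optimizer is properly optional: fall back to self.optimizer
        # when no optimizer is passed in.
        optimizer = optimizer if optimizer is not None else self.optimizer
        if optimizer is None:
            # Raise if self.optimizer is also unset.
            raise ValueError(
                "create_scheduler() requires an optimizer and self.optimizer is unset"
            )
        # ... a real implementation would build and return an LR scheduler here ...
        return {"optimizer": optimizer, "steps": num_training_steps}
```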
Wing Lian
18f26c19ef add uv axolotl builds (#3431) 2026-02-25 14:46:02 -05:00
Robert Ronan
2b6f4a6c9b Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying

* Update the handle_long_sequences function to correctly use Map instead of filter for the truncation strategy. Also remove the minimal length filtering from the truncate_long_samples function, and run it separately and before.

* fix: refactor and add test truncate for non-input id fields

* fix: refactor long seq handling fn

* fix: refactor duplicate fn and simplify route

* add additional tests and make them work on mac

* handle logging exception on empty datasets

---------

Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-02-25 11:31:11 +07:00
madScientist10
8f54b4eb25 fix: pass revision parameter to tokenizer and processor loaders (#3388) [skip ci]
* fix: pass revision parameter to tokenizer and processor loaders

* fix: address revision=None passed to .from_pretrained

* add tests and address review feedback for revision parameter

- Reformat modify_tokenizer_files signature and from_pretrained call
- Use kwargs pattern for modify_tokenizer_files call to avoid passing None revision
- Add 6 unit tests for revision parameter in tokenizer/processor loaders

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-02-25 11:11:20 +07:00
VED
a131e4d0e5 sample gen support sft (#3240) [skip ci]
* add:parameters + callback

* sft core + logging

* indentation fix

* logger fix

* loger fix in sft

* gen sample on eval

* lint

* deprecation
2026-02-25 11:10:57 +07:00
Wing Lian
1791d87b6f build axolotl images with torch 2.10.0 (#3430) 2026-02-24 22:35:25 -05:00
Wing Lian
b40803da51 build base images for torch 2.10.0 (#3429) 2026-02-24 20:32:34 -05:00
Wing Lian
68f1b7004c ScatterMoE LoRA support (#3410)
* scattermoe lora support

* fsdp, bf16, dim fixes

* expert weights aren't needed in save for bwd since they are frozen

* use sonicmoe optim options

* update save model from upstream

* fixes per code review feedback and add tests

* revert removal of CP fix

* misc fixes
2026-02-24 14:59:55 -05:00
NanoCode012
08441fed17 fix: set allowed values for adapter config (#3415) 2026-02-23 11:39:53 -05:00
NanoCode012
86ca1e27c0 fix: update MistralProcessor to be v5 compat (#3423)
* fix: update MistralProcessor to be v5 compat

* feat: add test for mistral3 processor

* chore: comment
2026-02-23 11:39:13 -05:00
Manas Vardhan
5ed455715e feat: support dot-notation CLI args for nested config options (#3419)
* feat: support dot-notation CLI args for nested config options

Add support for overriding nested config fields (like TRL config) via
CLI using dot-notation, e.g.:
  axolotl train grpo.yaml --trl.vllm-server-host=10.0.0.1 --trl.beta=0.1

Changes:
- args.py: Detect BaseModel subclass fields and generate dot-notation
  CLI options (--parent.child) that map to double-underscore kwargs
  (parent__child). Also fix _strip_optional_type for Python 3.10+
  union syntax (X | None).
- config.py: Handle double-underscore kwargs in load_cfg by setting
  nested dict values on the config.
- Add tests for nested option handling.

Fixes #2702

* Address CodeRabbit review: fix string parent bug, add type hints and docstring

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>

* Add type coercion for CLI kwargs and fix pre-commit issues

- Add _coerce_value() for YAML-style type inference on string CLI args
- When existing config value has a type (int/float/bool), cast to match
- When no existing value, infer type from string (true/false, ints, floats, null)
- Apply coercion to both flat and nested (dot-notation) kwargs
- Fix unused pytest import (pre-commit/ruff)
- Update tests to pass string values (matching real CLI behavior)
- Add dedicated TestCoerceValue test class

Addresses maintainer feedback on type casting for nested kwargs.

---------

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
2026-02-23 10:10:06 -05:00
Lorenzo Baraldi
3f30572d4a Fix typo in dataset_processes field (#3426)
* Fix typo in dataset_processes field

* fix: use updated config name

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-02-23 14:18:37 +07:00
NanoCode012
43d60c7439 bump cut-cross-entropy to 58d6572 (#3424) 2026-02-20 14:24:51 -05:00
Wing Lian
0ea252d392 update to trackio 0.16.1 (#3425) [skip ci] 2026-02-20 14:24:33 -05:00
Wing Lian
29722dec60 use bunnycdn for CI assets (#3422) [skip ci] 2026-02-20 00:09:25 -05:00
NanoCode012
7fbedbd300 fix(doc): add limitation for unfrozen_parameters (#3416) 2026-02-19 18:32:26 -05:00
Wing Lian
145ffc9be1 upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0

* upgrade trl for parity

* handle trl api changes

* orpo doesn't have max_prompt_len to check anymore

* cpoconfig doesn't take max_prompt_length and fix cpu offload

* slow fsdp1 test

* triton min 3.4.0 and liger to 0.7.0

* use transformers main for now for zero3 fix

* handle group_by_length change

* fix changes upstream

* mark skip flaky test

* use transformers latest release 5.2.0
2026-02-19 18:27:27 -05:00
NanoCode012
4f1b5ad29f fix: clarify how to use lm_eval plugin (#3404) [skip ci] 2026-02-15 07:52:30 -05:00
NanoCode012
d6a2532dd7 feat(doc): clarify how to use scattermoe (#3408) [skip ci]
* feat(doc): clarify how to use scattermoe

* chore: fix wording
2026-02-15 07:51:28 -05:00
Wing Lian
5eb265513c fix generic patch for cce (#3405) 2026-02-12 08:58:04 -05:00
NanoCode012
06ac407b92 feat: improve telemetry log (#3398)
* fix: redact trackio and data_files

* fix: add new orgs to whitelist

* feat: add run id to logs for users to easily share

* fix: update to add more metrics

* fix: add missed experiment tracker

* chore: formatting in main
2026-02-10 23:01:34 +07:00
NanoCode012
4e22cf0651 fix: remove telemetry warning (#3397) [skip ci] 2026-02-10 23:01:16 +07:00
VED
a4ee56c315 fix: set rollout in GRPO training_kwargs (#3392) 2026-02-10 18:06:15 +07:00
NanoCode012
c67cbcb0f5 fix: ignore add_special_tokens and use test mode for generation for mistral tokenizer (#3396) [skip ci]
* fix: ignore add_special_tokens and use test mode for generation

* fix: incorrectly setting kwarg
2026-02-10 18:03:26 +07:00
NanoCode012
a2da852576 fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code

* chore: re-order model guides
2026-02-10 17:58:40 +07:00
madScientist10
37e9da7a53 add hub_revision support for specifying branch when pushing checkpoints (#3387) [skip ci] 2026-02-10 17:53:09 +07:00
NanoCode012
ed7105dba7 fix: GRPO config not accept max_prompt_length (#3390) [skip ci] 2026-02-10 17:52:09 +07:00
NanoCode012
b6d3653f74 feat: add step3p5 for cce (#3384) [skip ci]
* feat: add step3p5 for cce

* chore: reorder model
2026-02-10 17:51:43 +07:00
NanoCode012
fcc4cfdb63 feat: add sageattention (#2823) [skip ci]
* feat: add sageattention

* feat: call path on pre model load

* fix: patch to use register to correct var

* fix: add strict check import at start

* chore: fix comments

* chore: refactor

* feat: add capability check

* fix: missed underscore

* fix: let sageattention use FA backend in transformers

* feat: update sage attention for attention mask and position ids

* feat: allow sample packing but add warning without packing

* fix: loss hitting 0 with packing and attention mask note

* feat: downcast embeds if sage attention too

* feat: add config validation

* feat: add attention docs

* chore: docs
2026-02-10 17:49:21 +07:00
VED
97a4f28511 fix: saving state dict and eval for Context Parallel (#3382) [skip ci]
* clone state_dict if none

* patch calculating  eval loss for cp
2026-02-10 17:47:26 +07:00
VED
86a5803212 train_per_sec_per_gpu metric (#3364) [skip ci]
* fix token count

* guard for none n zero
2026-02-10 17:44:55 +07:00
tgoab
530a0c0bf0 Changes from dataset_processes to dataset_num_proc (#3352) [skip ci]
* changes from dataset_processes to dataset_num_proc

* deprecation message improved

---------

Co-authored-by: Juliana Nieto Cárdenas <jnietoca@purdue.edu>
2026-02-10 17:44:17 +07:00
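A rename with a deprecation message, as in the `dataset_processes` → `dataset_num_proc` change above, is typically handled by a small remapping shim on config load. The sketch below is an assumption about the general pattern, not axolotl's actual deprecation code.

```python
import warnings

# Hypothetical table of old config keys and their replacements.
DEPRECATED_KEYS = {"dataset_processes": "dataset_num_proc"}


def remap_deprecated(cfg: dict) -> dict:
    """Move deprecated keys to their new names, warning the user once per key."""
    for old, new in DEPRECATED_KEYS.items():
        if old in cfg:
            warnings.warn(
                f"`{old}` is deprecated, use `{new}` instead", DeprecationWarning
            )
            # Keep an explicitly set new value; otherwise carry the old one over.
            cfg.setdefault(new, cfg.pop(old))
    return cfg
```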
VED
0343a72cc9 add glm support + patch (#3329) [skip ci]
* add glm support + patch

* lint

* lint

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/processing_strategies.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patch removed

* lint

* lint2

* docs + rename

* rmv moe

* docs

* removed processor

* sdpa T_T"

* ddp_find_unused_parameters: true

* muti gpu yaml tested both

* muti gpu yaml tested both

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* rmv text only section + v5 comments

* rename

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-02-10 17:43:53 +07:00
Wing Lian
236dad3bb7 set 0.15.0.dev0 version (#3380) 2026-01-30 21:28:01 -05:00
Wing Lian
be00978bc2 tag for v0.14.0 release (#3379)
2026-01-30 14:10:27 -05:00
Wing Lian
3738978394 Add support for batched_mm, grouped_mm and scattermoe for MoE models (#3377)
* kernels plugin for moe for v5

* add support for native batched_mm or grouped_mm
2026-01-29 14:25:47 -05:00
Wing Lian
6132a30cda handle warnings from v5 upgrade (#3376) 2026-01-28 06:45:01 -05:00
NanoCode012
3dd86d35b8 feat: add new cce support for glm series and exaone4 (#3373) [skip ci] 2026-01-28 06:44:44 -05:00
salman
dd9ebaeba1 EAFT (#3366) [skip ci]
* wip eaft

* fix eaft loss fn

* adding ref

---------

Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
2026-01-28 06:44:15 -05:00
Wing Lian
fc4e37920b transformers v5 upgrade (#3272)
* Prepare for transformers v5 upgrade

* fix hf cli

* update for hf hub changes

* fix tokenizer apply_chat_template args

* remap include_tokens_per_second

* fix tps

* handle migration for warmup

* use latest hf hub

* Fix scan -> ls

* fix import

* fix for renaming of mistral common tokenizer -> backend

* update for fixed tokenziation for llama

* Skip phi35 tests for now

* remove mistral patch fixed upstream in huggingface/transformers#41439

* use namespacing for patch

* don't rely on sdist for e2e tests for now

* run modal ci without waiting too

* Fix dep for ci

* fix imports

* Fix fp8 check

* fsdp2 fixes

* fix version handling

* update fsdp version tests for new v5 behavior

* Fail multigpu tests after 3 failures

* skip known v5 broken tests for now and cleanup

* bump deps

* unmark skipped test

* re-enable test_fsdp_qlora_prequant_packed test

* increase multigpu ci timeout

* skip broken gemma3 test

* reduce timout back to original 120min now that the hanging test is skipped

* fix for un-necessary collator for pretraining with bsz=1

* fix: safe_serialization deprecated in transformers v5 rc01 (#3318)

* torch_dtype deprecated

* load model in float32 for consistency with tests

* revert some test fixtures back

* use hf cache ls instead of scan

* don't strip fsdp_version

more fdsp_Version fixes for v5
fix version in fsdp_config
fix aliasing
fix fsdp_version check
check fsdp_version is 2 in both places

* Transformers v5 rc2 (#3347)

* bump dep

* use latest fbgemm, grab model config as part of fixture, un-skip test

* import AutoConfig

* don't need more problematic autoconfig when specifying config.json manually

* add fixtures for argilla ultrafeedback datasets

* download phi4-reasoning

* fix arg

* update tests for phi fast tokenizer changes

* use explicit model types for gemma3

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>

* fix: AutoModelForVision2Seq -> AutoModelForImageTextToText

* chore: remove duplicate

* fix: attempt fix gemma3 text mode

* chore: lint

* ga release of v5

* need property setter for name_or_path for mistral tokenizer

* vllm not compatible with transformers v5

* setter for chat_template w mistral too

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2026-01-27 17:08:24 -05:00
Wing Lian
a531e9d946 upgrade vllm to v0.14.0 (#3345) 2026-01-21 20:00:18 -05:00
Wing Lian
04328aeb97 cu129 targets for ci builds (#3369)
* cu129 targets for ci builds

* remove copy-paste is_latest
2026-01-21 17:24:44 -05:00
VED
d0d26d5064 feat: Add GDPO Support (#3353)
* gdpo support - test left

* lint

* fixxes for vllm serv

* test advantages

* docss

* lint

* lint =

* gdpo simple + lint

* lint nit

* example

* lint

* trl 0.27.0

* blocklist

* test assert rmv

* add validation check for GDPO + sum_then_normalize

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-21 17:22:45 -05:00
Wing Lian
8623dd8a72 strip only starting 'v' char; e.g don't strip from '.dev' (#3368) [skip ci] 2026-01-21 14:19:03 -05:00
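The fix above can be illustrated with the `v0.15.0.dev0` tag from this changelog: an over-eager strip or replace also eats the `v` inside `.dev`, while removing only a leading `v` is safe. This is a sketch of the failure mode the commit message describes, not the exact code from #3368.

```python
tag = "v0.15.0.dev0"

# A blanket replace also removes the 'v' inside '.dev', corrupting the version:
assert tag.replace("v", "") == "0.15.0.de0"

# Stripping only the starting 'v' char keeps the dev suffix intact:
assert tag.removeprefix("v") == "0.15.0.dev0"
```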
Wing Lian
8cd75cff9f use cuda 12.9.1 and add python 3.12 to base images (#3367) 2026-01-21 13:34:14 -05:00
Wing Lian
8ab9d9ea88 Version dev (#3365) 2026-01-20 22:58:29 -05:00
Wing Lian
6e42def14b set version to v0.13.1 (#3363)
2026-01-20 08:58:32 -05:00
Wing Lian
c413480b35 upgrade transformers to 4.57.6 and peft to 0.17.1 and datasets to 4.5.0 (#3361) 2026-01-16 11:48:50 -05:00
Wing Lian
8f25124269 upgrade transformers to 4.57.5 (#3358)
* upgrade transformers to 4.57.5

* explicitly set versions for fbgemm-gpu

* handle index url for cuda version

* explicitly set cu version for fbgemm deps, skip for 130

* cu suffix not needed on version if using whl subpath
2026-01-16 11:17:43 -05:00
Wing Lian
790df757cb don't install xformers in for arm64 (#3359)
* install xformers in the base docker image

* install numba and numpy first

* set CUDA_HOME for xformers install

* Set cuda  home env

* don't install xformers by default on aarch64/arm64
2026-01-16 09:02:37 -05:00
Wing Lian
d282f32481 don't install deepspeed in arm64 images (#3357) 2026-01-14 12:03:55 -05:00
Wing Lian
6331e4a130 fix amd64 and set 2.9.1 as latest cloud image (#3356) 2026-01-14 11:56:36 -05:00
salman
1410e4474e update PR template (#3349) [skip ci] 2026-01-14 09:39:21 -05:00
Wing Lian
dc77b5bf42 fix arm64 builds (#3355)
* fix syntax  for secrets in gha yaml

* setup env for uv too

* arm64 for base  uv too

* don't build causal-conv1d or mamba for arm64 and use arm64 wheels

* fix dockerfile syntax

* fix shell syntax
2026-01-14 09:38:48 -05:00
NanoCode012
359b7ad85e fix: gemma3_text model loading vision config (#3354)
* fix: gemma3-text mode loading vision config

* fix: improve defaults to use lora kernels
2026-01-13 09:49:23 -05:00
VED
258ce8d4fa feat : scaled softmax support (#3338)
* scaled softmax

* comment

* lint

* remove egear

* validation for flash

* lint

* val imporve + neet

* fix correct softmax scale val(learned)

* learned scale val 4 ssm

* lint

* fix model_type rmv

* sdpa_atten

* test fix + lint

* test fix

* sdp_a val rmv

* flex fix

* main flash

* lint

* flex attn

* lint comment

* fix score_mod

* Update src/axolotl/utils/schemas/validation.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-01-13 14:33:11 +07:00
@TT
3e0bbd33ec feat: add ARM64/AArch64 build support to Dockerfile-base (#3346)
* Add support for capability to build arm64 image

* Fixing wrong variable TARGETPLATFORM bug

* Adding missing semicolons

* skip docker hub login if PR (no push) or no credentials

* Enabling arm64 builds for Dockerfile-base in Github actions

* TARGETARCH automatically default to platform arch under build

* Enabling arm64 builds for axolotl docker builds

* Enabling arm64 builds for axolotl-cloud docker build Github actions

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-12 12:00:02 -05:00
salman
4ae6f766ad bump bnb to v0.49.1 (#3351) 2026-01-12 09:42:04 -05:00
VED
e7f0d4ba5b Increased test coverage for lora/qlora (#3147)
* config_val tests

* remove config val(not needed)

* config validation

* parameter freeze validation

* merge/unmerge tests

* removal unwanted

* rename

* lint

* updated lint

* Update tests/utils/lora/test_config_validation_lora.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* pytest skip + mock fix

* nitpicks

* revert some nitpicks

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-01-06 11:44:48 -05:00
VED
7bf6f70e96 fix total/trainable tokens log (#3344)
* fix total/trainable tokens log

* fix total/trainable tokens log
2026-01-06 09:25:17 -05:00
PraMamba
8aab807e67 feat: Add SwanLab integration for experiment tracking (#3334)
* feat(swanlab): add SwanLab integration for experiment tracking

SwanLab integration provides comprehensive experiment tracking and monitoring for Axolotl training.

Features:
- Hyperparameter logging
- Training metrics tracking
- RLHF completion logging
- Performance profiling
- Configuration validation and conflict detection

Includes:
- Plugin in src/axolotl/integrations/swanlab/
- Callback in src/axolotl/utils/callbacks/swanlab.py
- Tests in tests/integrations/test_swanlab.py
- Examples in examples/swanlab/

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(swanlab): address PR #3334 review feedback from winglian and CodeRabbit

- Change use_swanlab default to True (winglian)
- Clear buffer after periodic logging to prevent duplicates (CodeRabbit Major)
- Add safe exception handling in config fallback (CodeRabbit)
- Use context managers for file operations (CodeRabbit)
- Replace LOG.error with LOG.exception for better debugging (CodeRabbit)
- Sort __all__ alphabetically (CodeRabbit)
- Add language specifiers to README code blocks (CodeRabbit)
- Fix end-of-file newline in README (pre-commit)

Resolves actionable comments and nitpicks from CodeRabbit review.
Addresses reviewer feedback from @winglian.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* only run swanlab integration tests if package is available

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-06 09:19:18 -05:00
Wing Lian
ee59e4de97 add cu130 + torch 2.9.1 to test matrices (#3343)
* add cu130 + torch 2.9.1 to test matrices

* uv can't use pip3 directly
2026-01-05 15:24:29 -05:00
Wing Lian
4e61b8aa23 use updated version of prebuilt wheels for flash attention for cu130 (#3342)
* use updated version of prebuilt wheels for flash attention for cu130

* use elif

* fix the uv base installs of FA also

* make wget less verbose
2026-01-05 13:48:12 -05:00
222 changed files with 16292 additions and 771 deletions


@@ -15,6 +15,11 @@
 <!--- Include details of your testing environment, tests ran to see how -->
 <!--- your change affects other areas of the code, etc. -->
+## AI Usage Disclaimer
+<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
+<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
 ## Screenshots (if appropriate)
 ## Types of changes


@@ -21,6 +21,8 @@ jobs:
timeout-minutes: 480 timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@@ -32,6 +34,7 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -39,6 +42,7 @@ jobs:
pytorch: 2.9.0 pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -46,6 +50,31 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -53,6 +82,23 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128" # - cuda: "128"
# cuda_version: 12.8.1 # cuda_version: 12.8.1
# cudnn_version: "" # cudnn_version: ""
@@ -79,6 +125,7 @@ jobs:
axolotlai/axolotl-base axolotlai/axolotl-base
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -89,6 +136,7 @@ jobs:
with: with:
context: . context: .
file: ./docker/${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
@@ -103,6 +151,8 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -114,6 +164,7 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -121,6 +172,7 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -128,6 +180,31 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -135,6 +212,23 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -146,6 +240,7 @@ jobs:
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -156,6 +251,7 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}


@@ -20,22 +20,44 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
-is_latest: true
+platforms: "linux/amd64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
-# - cuda: 130
-# cuda_version: 13.0.0
-# python_version: "3.11"
+platforms: "linux/amd64,linux/arm64"
+is_latest: true
+- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -61,6 +83,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
@@ -75,6 +98,77 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -87,22 +181,44 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
-is_latest: true
+platforms: "linux/amd64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
-# - cuda: 130
-# cuda_version: 13.0.0
-# python_version: "3.11"
+is_latest: true
+platforms: "linux/amd64,linux/arm64"
+- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -127,6 +243,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
@@ -137,6 +254,73 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -147,11 +331,11 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
-pytorch: 2.8.0
+pytorch: 2.9.1
axolotl_extras:
-is_latest:
+is_latest: true
-- cuda: 128
+- cuda: 130
-cuda_version: 12.8.1
+cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
@@ -180,6 +364,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}


@@ -35,14 +35,26 @@ jobs:
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
-axolotl_extras: fbgemm-gpu
+axolotl_extras: "fbgemm-gpu"
num_gpus: 2
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
# axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -64,8 +76,8 @@
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run -m cicd.multigpu


@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
-pip3 install wheel packaging==23.2
+pip3 install wheel packaging==26.0
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
@@ -48,9 +48,9 @@
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
-- name: Update version in setup.py
+- name: Update version in VERSION file
run: |
-sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
+echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
- name: Build a source dist
run: |


@@ -37,7 +37,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
-curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
+curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -48,7 +48,7 @@
- name: upgrade pip
run: |
pip3 install --upgrade pip
-pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
+pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |


@@ -54,8 +54,13 @@ jobs:
strategy:
fail-fast: false
matrix:
-python_version: ["3.11"]
+python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
-pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
+pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -70,7 +75,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
-curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
+curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -82,7 +87,7 @@
- name: upgrade pip
run: |
pip3 install --upgrade pip
-pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
+pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -110,10 +115,10 @@ jobs:
- name: Pre-Download dataset fixture
run: |
-huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
+hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
-run: hf cache scan
+run: hf cache ls
- name: Run tests
run: |
@@ -127,7 +132,7 @@
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
-run: hf cache scan
+run: hf cache ls
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -144,8 +149,13 @@ jobs:
strategy:
fail-fast: false
matrix:
-python_version: ["3.11"]
+python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
-pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
+pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -160,7 +170,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
-curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
+curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -172,7 +182,7 @@
- name: upgrade pip
run: |
pip3 install --upgrade pip
-pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
+pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install PyTorch
run: |
@@ -200,7 +210,7 @@ jobs:
axolotl --help
- name: Show HF cache
-run: hf cache scan
+run: hf cache ls
- name: Run tests
run: |
@@ -209,10 +219,10 @@
pytest -v --durations=10 tests/cli/
- name: Show HF cache
-run: hf cache scan
+run: hf cache ls
gate-skip-e2e:
-needs: [pre-commit, pytest, pytest-sdist]
+needs: [pre-commit]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
@@ -248,16 +258,16 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
-needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
+needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
-- cuda: 128
-cuda_version: 12.8.1
-python_version: "3.11"
-pytorch: 2.8.0
+- cuda: 130
+cuda_version: 13.0.0
+python_version: "3.12"
+pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -316,6 +326,18 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -353,8 +375,8 @@ jobs:
fail-fast: false
matrix:
include:
-- cuda: 128
-cuda_version: 12.8.1
+- cuda: 129
+cuda_version: 12.9.1
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1


@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
-| `dataset_processes` | `4` | Number of preprocessing processes |
+| `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |


@@ -39,7 +39,6 @@
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
@@ -107,7 +106,7 @@
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
-# dataset_processes: # defaults to os.cpu_count() if not set
+# dataset_num_proc: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
@@ -224,9 +223,6 @@
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
-# # Save model as safetensors (require safetensors package)
-# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
@@ -352,8 +348,6 @@
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
@@ -412,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
-dataset_processes: ${DATASET_PROCESSES}
+dataset_num_proc: ${DATASET_NUM_PROC}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
@@ -512,7 +506,6 @@ profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
-save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}


@@ -88,7 +88,7 @@ Features:
#### Using pip
```bash
-pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

VERSION Normal file

@@ -0,0 +1 @@
0.15.0.dev0


@@ -251,7 +251,6 @@ website:
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
-- docs/models/mistral.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
@@ -266,6 +265,7 @@ website:
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
@@ -320,6 +320,7 @@ website:
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features"
contents:


@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
-RUN uv pip install packaging==23.2 setuptools==75.8.0
+RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \


@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
-RUN pip install packaging==23.2 setuptools==75.8.0 psutil
+RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \


@@ -17,7 +17,8 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
-df_template = template_env.get_template("Dockerfile.jinja")
+dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
@@ -27,8 +28,11 @@ df_args = {
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
}
dockerfile_contents = df_template.render(**df_args)


@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
-pytest -v --durations=10 -n2 --maxfail=4 \
+pytest -v --durations=10 -n2 --maxfail=3 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \


@@ -6,6 +6,7 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
+ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -20,13 +21,17 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
-# If AXOLOTL_EXTRAS is set, append it in brackets
-RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
-else \
-pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
-fi && \
+# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
+RUN if [ "$TARGETARCH" = "arm64" ]; then \
+BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
+else \
+BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
+fi && \
+if [ "$AXOLOTL_EXTRAS" != "" ]; then \
+pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+else \
+pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
+fi && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge


@@ -2,14 +2,16 @@ ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
+ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
-ARG PYTHON_VERSION="3.10"
+ARG TARGETARCH
+ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.1.2"
-ARG CUDA="118"
+ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -22,11 +24,17 @@ RUN apt-get update \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
-&& wget \
-https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+&& if [ "$TARGETARCH" = "amd64" ]; then \
+MINICONDA_ARCH="x86_64"; \
+elif [ "$TARGETARCH" = "arm64" ]; then \
+MINICONDA_ARCH="aarch64"; \
+else \
+echo "Unsupported architecture: $TARGETARCH"; exit 1; \
+fi \
+&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& mkdir /root/.conda \
-&& bash Miniconda3-latest-Linux-x86_64.sh -b \
-&& rm -f Miniconda3-latest-Linux-x86_64.sh \
+&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
+&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
@@ -35,7 +43,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
-RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
+RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip cache purge
@@ -51,8 +59,18 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
-RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
-wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-fi
+# Map Python version (e.g., 3.12 -> cp312)
+RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
+# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
+TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
+# Map architecture
+case "$TARGETARCH" in \
+amd64) ARCH_TAG="x86_64" ;; \
+arm64) ARCH_TAG="aarch64" ;; \
+*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
+esac && \
+WHL_VERSION="v0.7.16" && \
+WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
+wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
+pip3 install --no-cache-dir "${WHL_FILE}" && \
+rm "${WHL_FILE}"
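The wheel-filename mapping added in this hunk can be checked outside Docker with a few lines of shell. This is a standalone sketch of the same logic; the build-arg values below are only illustrative, not a recommendation:

```shell
# Sketch of the wheel-filename construction used in the Dockerfile above.
# Illustrative build-arg values:
PYTHON_VERSION="3.11"
PYTORCH_VERSION="2.10.0"
CUDA="128"
TARGETARCH="amd64"

PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')"                 # cp311
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')"  # torch2.10
case "$TARGETARCH" in
  amd64) ARCH_TAG="x86_64" ;;
  arm64) ARCH_TAG="aarch64" ;;
esac
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl"
echo "$WHL_FILE"
# flash_attn-2.8.3+cu128torch2.10-cp311-cp311-linux_x86_64.whl
```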


@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
-RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
+RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \


@@ -0,0 +1,30 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

docker/Dockerfile-uv (new file)

@@ -0,0 +1,47 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc


@@ -2,9 +2,11 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
+ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
+ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
@@ -31,12 +33,25 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
-&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
-&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
-RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
-wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-fi
+RUN if [ "$TARGETARCH" = "amd64" ]; then \
+uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
+uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
+fi
+# Map Python version (e.g., 3.12 -> cp312)
+RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
+# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
+TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
+# Map architecture
+case "$TARGETARCH" in \
+amd64) ARCH_TAG="x86_64" ;; \
+arm64) ARCH_TAG="aarch64" ;; \
+*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
+esac && \
+WHL_VERSION="v0.7.16" && \
+WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
+wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
+uv pip install --no-cache-dir "${WHL_FILE}" && \
+rm "${WHL_FILE}"


@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
Download a base model using the Hugging Face CLI:
```bash
-huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
+hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration

docs/attention.qmd (new file)

@@ -0,0 +1,140 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` errors while training, ensure you installed PyTorch before Axolotl. Alternatively, try reinstalling flash-attn or downgrading to an earlier version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using the latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA is recommended at the moment. We found the loss drops to 0 with full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested in adding it or improving the SageAttention integration, please open an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend this for Turing GPUs or older (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to one of the methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::


@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
+lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
-See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
+See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
### delinearize-llama4


@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
```
4. (Optional) Login to Hugging Face:
```{.bash}
-huggingface-cli login
+hf auth login
```
## Troubleshooting {#sec-troubleshooting}


@@ -89,6 +89,10 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
+::: {.callout-warning}
+LoRA kernels do not support remote modeling code.
+:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)


@@ -19,6 +19,7 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
+- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
@@ -183,6 +184,18 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
+### GLM-4.6V {#sec-glm-4-6v}
+Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
+```yaml
+# GLM-4.6V (106B MoE version)
+base_model: zai-org/GLM-4.6V
+# OR GLM-4.6V-Flash (9B version)
+base_model: zai-org/GLM-4.6V-Flash
+```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}


@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
+- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
## RLHF using Axolotl
@@ -720,6 +721,102 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
::: {.callout-tip}
Use GDPO when training with multiple reward functions. For a single reward function, GRPO and GDPO produce equivalent results.
:::
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, Axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
rl: gdpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: true
num_generations: 4
reward_funcs:
- rewards.format_reward
- rewards.correctness_reward
reward_weights: [1.0, 2.0]
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
multi_objective_aggregation: normalize_then_sum # GDPO behavior
# or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO
| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```
GDPO normalizes each reward independently, preserving their relative differences.
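The difference is easy to see numerically. Below is a toy sketch of group normalization applied to two reward streams on different scales (this is an illustration of the aggregation orders, not TRL's implementation; the zero-std guard is an assumption of this sketch):

```python
import statistics

def group_normalize(xs):
    # GRPO-style group normalization: (x - mean) / std.
    # The `or 1.0` guard against zero std is an assumption of this sketch.
    mu = statistics.mean(xs)
    sd = statistics.pstdev(xs) or 1.0
    return [(x - mu) / sd for x in xs]

# Two reward streams on different scales: format in {0, 1}, correctness in {0, 10}
format_r  = [1, 0, 1, 0]
correct_r = [0, 10, 10, 0]

# GRPO (sum_then_normalize): the large-scale correctness reward dominates
# the summed reward, so the format signal barely shows in the advantages.
grpo = group_normalize([f + c for f, c in zip(format_r, correct_r)])

# GDPO (normalize_then_sum): each stream is normalized first,
# so both contribute on the same scale.
gdpo = [f + c for f, c in zip(group_normalize(format_r),
                              group_normalize(correct_r))]

print([round(a, 2) for a in grpo])  # [-0.9, 0.9, 1.09, -1.09]
print(gdpo)                         # [0.0, 0.0, 2.0, -2.0]
```

With GDPO, only the sample that satisfies both rewards gets a clearly positive advantage, while GRPO assigns nearly the same advantage to "correct but badly formatted" as to "correct and well formatted".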
#### Reward Functions
GDPO uses the same reward function format as GRPO:
```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
return [1.0 if len(c) > 10 else 0.0 for c in completions]
def correctness_reward(completions, answers, **kwargs) -> list[float]:
rewards = []
for completion, answer in zip(completions, answers):
# Your scoring logic here; e.g. a simple exact-match check:
score = 1.0 if str(answer) in completion else 0.0
rewards.append(score)
return rewards
```
#### Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.


@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy


@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy


@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
+"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
]
},
{


@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```


@@ -52,6 +52,7 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
+scaling_softmax: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -0,0 +1,77 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/eaft-gemma-3-1b
use_eaft: true
eaft_alpha: 1.0
eaft_k: 20
sequence_len: 1024
sample_packing: false
adapter:
lora_model_dir:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
max_steps: 1000
evaluation_strategy: "no"
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
debug:
deepspeed:
fsdp:
fsdp_config:
special_tokens:


@@ -1,6 +1,7 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
+cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
-lora_dropout: 0.05
+lora_dropout: 0
lora_target_linear: true
sequence_len: 2048


@@ -1,6 +1,7 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
+cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
-lora_dropout: 0.05
+lora_dropout: 0
lora_target_linear: true
sequence_len: 2048


@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
+cls_model_config: Gemma3TextConfig
load_in_4bit: true
@@ -32,8 +33,8 @@ sample_packing: true
lora_r: 32
lora_alpha: 16
-lora_dropout: 0.05
+lora_dropout: 0
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+lora_target_linear: true
wandb_project:
wandb_entity:


@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
-lora_dropout: 0.05
+lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:


@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```


@@ -0,0 +1,77 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm4.7-flash/qlora.yaml
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
```
```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
axolotl train examples/glm4.7-flash/lora.yaml
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
```
### Expert LoRA
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router, untested but should work, not normally targeted
```
## Limitations
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2; otherwise training hangs.
- **lora_target_linear**: Incompatible with this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
### Tips
- For inference, the official Z.ai team recommends these default settings (most tasks):
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetune by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is resource-heavy, so we have not tested it.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1


@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true


@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1


@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true

examples/glm46v/README.md (new file, 44 lines)

@@ -0,0 +1,44 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI, available on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
GLM-4.6V-Flash (9B):
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -0,0 +1,53 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0


@@ -0,0 +1,50 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0


@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 ```bash
 # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```


@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation -e '.[flash-attn]'
 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy


@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation -e '.[flash-attn]'
 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy


@@ -19,7 +19,6 @@ datasets:
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: jamba-large-fsdp-qlora-ft
-save_safetensors: true
 adapter: qlora
 sequence_len: 2048
 sample_packing: true


@@ -0,0 +1,68 @@
base_model: meta-llama/Llama-3.2-1B-Instruct
chat_template: llama3
rl: gdpo
trl:
  beta: 0.001
  max_completion_length: 128
  num_generations: 2
  temperature: 0.7
  top_p: 0.95
  use_vllm: false
  multi_objective_aggregation: normalize_then_sum
  reward_funcs:
    - rwd.format_reward
    - rwd.correctness_reward
  reward_weights: [1.0, 2.0]
  log_completions: true
  num_completions_to_print: 3
  scale_rewards: true
datasets:
- path: openai/gsm8k
    name: main
    split: train[:1000]
    type: rwd.gsm8k_transform
val_set_size: 0.0
output_dir: ./outputs/llama3-gdpo-out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
max_steps: 100
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
weight_decay: 0.01
warmup_steps: 10
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
flash_attention: true
logging_steps: 1
save_steps: 50
save_safetensors: true
special_tokens:
  pad_token: "<|end_of_text|>"
seed: 42


@@ -12,7 +12,6 @@ datasets:
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./outputs/out/qlora-llama3_1-405b
-save_safetensors: true
 adapter: qlora


@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
 ```bash
 # Ensure you have Pytorch installed (Pytorch 2.7.0 min)
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```


@@ -47,6 +47,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   tokens:
-save_safetensors: False
 # save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -59,6 +59,7 @@ gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
 flash_attention: true
+scaling_softmax: true
 warmup_ratio: 0.1
 evals_per_epoch: 1


@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 ## Getting started
-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
-Here is an example of how to install from main for pip:
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
-```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
-# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
-python scripts/cutcrossentropy_install.py | sh
-```
-2. Install Qwen3-Next transformers commit
-```bash
-pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
-```
 3. Install FLA for improved performance
 ```bash
-pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
+pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
 ```
 4. Run the finetuning example:
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
 axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
 ```
-This config uses about 45.62 GiB VRAM.
+This config uses about 47 GiB (no target experts) and 71 GiB (target experts) of VRAM.
 Let us know how it goes. Happy finetuning! 🚀


@@ -9,6 +9,8 @@ plugins:
 load_in_8bit: false
 load_in_4bit: true
+quantize_moe_experts: true
 datasets:
 - path: fozziethebeat/alpaca_messages_2k_test
   type: chat_template
@@ -25,7 +27,7 @@ sample_packing: true
 lora_r: 16
 lora_alpha: 8
-lora_dropout: 0.05
+lora_dropout: 0
 lora_target_modules:
 - linear_attn.in_proj_ba
 - linear_attn.in_proj_qkvz
@@ -34,12 +36,19 @@ lora_target_modules:
 - shared_expert.down_proj
 - shared_expert.gate_proj
 - shared_expert_gate
+- mlp.gate
 - q_proj
 - v_proj
 - k_proj
 - o_proj
+# lora_target_parameters:
+# - mlp.experts.gate_up_proj
+# - mlp.experts.down_proj
+lora_mlp_kernel: false
+lora_qkv_kernel: false
+lora_o_kernel: false
 wandb_project:
 wandb_entity:
 wandb_watch:

examples/swanlab/README.md (new file, 285 lines)

@@ -0,0 +1,285 @@
# SwanLab Integration Examples
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
## Examples Overview
### 1. DPO with Completion Logging
**File**: `dpo-swanlab-completions.yml`
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
**Features**:
- Basic SwanLab experiment tracking
- Completion table logging (prompts, chosen/rejected responses, rewards)
- Memory-bounded buffer for long training runs
- Cloud sync configuration
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
```
---
### 2. LoRA with Performance Profiling
**File**: `lora-swanlab-profiling.yml`
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
**Features**:
- SwanLab experiment tracking
- Automatic profiling of trainer methods
- Profiling metrics visualization
- Performance optimization guidance
**Best for**: Engineers optimizing training performance and comparing different configurations
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
```
---
### 3. Full-Featured DPO Production Setup
**File**: `dpo-swanlab-full-featured.yml`
Comprehensive production-ready configuration with ALL SwanLab features enabled.
**Features**:
- Experiment tracking with team workspace
- RLHF completion logging
- Performance profiling
- Lark (Feishu) team notifications
- Private deployment support
- Production checklist and troubleshooting
**Best for**: Production RLHF training with team collaboration
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
```
---
### 4. Custom Trainer Profiling (Python)
**File**: `custom_trainer_profiling.py`
Python code examples showing how to add SwanLab profiling to custom trainers.
**Features**:
- `@swanlab_profile` decorator examples
- Context manager profiling for fine-grained timing
- `ProfilingConfig` for advanced filtering and throttling
- Multiple profiling patterns and best practices
**Best for**: Advanced users creating custom trainers
**Usage**:
```python
from custom_trainer_profiling import CustomTrainerWithProfiling
# See file for detailed examples and patterns
```
---
## Feature Matrix
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|---------|----------|-------------------|-----------|-------------------|----------------|
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | (commented) | (commented) |
| lora-swanlab-profiling.yml | ✅ | (disabled) | ✅ (auto) | (commented) | (commented) |
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
---
## Configuration Quick Reference
### Basic SwanLab Setup
```yaml
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # cloud, local, offline, disabled
```
### RLHF Completion Logging
```yaml
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
```
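The "memory-bounded buffer" behaves like a fixed-size FIFO: once `swanlab_completion_max_buffer` entries are held, the oldest completions are evicted. A minimal sketch of that behavior (an illustration of the concept, not SwanLab's actual implementation):

```python
from collections import deque

# Fixed-size completion buffer: appending beyond maxlen evicts the oldest
# entry, so memory stays bounded over arbitrarily long training runs.
buffer = deque(maxlen=128)  # mirrors swanlab_completion_max_buffer: 128

for step in range(1000):
    buffer.append({"step": step, "prompt": "...", "chosen": "...", "rejected": "..."})

print(len(buffer))        # capped at 128
print(buffer[0]["step"])  # oldest surviving entry: step 872
```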
### Lark Team Notifications
```yaml
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret # Required for production
```
### Team Workspace
```yaml
swanlab_workspace: my-research-team
```
### Private Deployment
```yaml
swanlab_web_host: https://swanlab.yourcompany.com
swanlab_api_host: https://api.swanlab.yourcompany.com
```
---
## Authentication
### Recommended: Environment Variable
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
```
### Alternative: Config File (less secure)
```yaml
swanlab_api_key: your-api-key
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret
```
---
## Common Use Cases
### Use Case 1: Migrate from WandB to SwanLab
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
```yaml
use_swanlab: true
use_wandb: false
```
### Use Case 2: Analyze DPO Model Outputs
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
```yaml
swanlab_completion_log_interval: 50    # shorter runs: log more frequently
# swanlab_completion_log_interval: 200  # longer runs: log less frequently
```
### Use Case 3: Optimize Training Performance
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
- Baseline: `flash_attention: false, gradient_checkpointing: false`
- Flash Attention: `flash_attention: true`
- Gradient Checkpointing: `gradient_checkpointing: true`
- Both: `flash_attention: true, gradient_checkpointing: true`
Compare profiling metrics in SwanLab dashboard.
### Use Case 4: Production RLHF with Team Collaboration
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
```yaml
swanlab_workspace: ml-team
swanlab_lark_webhook_url: ...
swanlab_lark_secret: ...
```
---
## Viewing Your Experiments
### Cloud Mode
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
**Dashboard sections**:
- **Metrics**: Training loss, learning rate, profiling metrics
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
- **Config**: Hyperparameters and configuration
- **System**: Resource usage (GPU, memory, CPU)
- **Files**: Logged artifacts
### Local Mode
```bash
swanlab watch ./swanlog
# Open browser to http://localhost:5092
```
---
## Troubleshooting
### SwanLab not initializing
```bash
# Check API key
echo $SWANLAB_API_KEY
# Verify SwanLab is installed
pip show swanlab
# Check config
grep -A 5 "use_swanlab" your-config.yml
```
### Completions not appearing
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
- Check `swanlab_log_completions: true`
- Wait for `swanlab_completion_log_interval` steps
- Look for "Registered SwanLab RLHF completion logging" in logs
### Lark notifications not working
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
- Verify `SWANLAB_LARK_SECRET` is set correctly
- Check bot is added to Lark group chat
- Look for "Registered Lark notification callback" in logs
### Profiling metrics not appearing
- Verify `use_swanlab: true`
- Check SwanLab is initialized (look for init log message)
- Profiling metrics are under "profiling/" namespace
- Profiling auto-enabled when SwanLab is enabled
---
## Performance Notes
### Overhead Comparison
| Feature | Overhead per Step | Memory Usage |
|---------|------------------|--------------|
| Basic tracking | < 0.1% | ~10 MB |
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
| Profiling | < 0.1% | ~1 KB |
| **Total** | **< 0.7%** | **~10 MB** |
### Best Practices
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
2. Adjust completion log interval based on training length (100-200 steps)
3. Keep completion buffer size reasonable (128-512)
4. Profile critical path methods first (training_step, compute_loss)
5. Use ProfilingConfig to throttle high-frequency operations
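The throttling knobs from points 2-3 (`log_interval`, `min_duration_ms`) compose as two filters. A minimal sketch of that filtering logic (`ThrottledLogger` is an illustrative name, not SwanLab's actual code):

```python
class ThrottledLogger:
    """Sketch of min-duration + interval throttling for profiling logs."""

    def __init__(self, min_duration_ms=0.5, log_interval=50):
        self.min_duration_ms = min_duration_ms
        self.log_interval = log_interval
        self.calls = 0

    def should_log(self, duration_ms):
        self.calls += 1
        if duration_ms < self.min_duration_ms:
            return False  # skip very fast operations (noise)
        return self.calls % self.log_interval == 0  # log every Nth call only

logger = ThrottledLogger()
decisions = [logger.should_log(2.0) for _ in range(100)]
print(sum(decisions))  # → 2 (calls 50 and 100 pass both filters)
```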
---
## Further Reading
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
---
## Contributing
Found an issue or have an improvement? Please submit a PR or open an issue:
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)


@@ -0,0 +1,299 @@
"""Example: Custom Trainer with SwanLab Profiling
This example demonstrates how to add SwanLab profiling to your custom trainer.
Features:
- @swanlab_profile decorator for automatic profiling
- swanlab_profiling_context for fine-grained profiling
- ProfilingConfig for advanced filtering and throttling
Usage:
1. Create your custom trainer extending AxolotlTrainer
2. Add @swanlab_profile decorators to methods you want to profile
3. Use swanlab_profiling_context for fine-grained profiling within methods
4. Enable SwanLab in your config (use_swanlab: true)
See also:
- examples/swanlab/lora-swanlab-profiling.yml for config
- src/axolotl/integrations/swanlab/profiling.py for implementation
"""
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
ProfilingConfig,
swanlab_profile,
swanlab_profiling_context,
swanlab_profiling_context_advanced,
)
class CustomTrainerWithProfiling(AxolotlTrainer):
"""Custom trainer with SwanLab profiling enabled.
This trainer demonstrates three profiling patterns:
1. Decorator-based profiling (@swanlab_profile)
2. Context manager profiling (swanlab_profiling_context)
3. Advanced profiling with filtering (ProfilingConfig)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Create custom profiling config for high-frequency operations
self.fast_op_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.5, # Only log if duration > 0.5ms
log_interval=50, # Log every 50th call
)
# ========================================================================
# Pattern 1: Decorator-based Profiling
# ========================================================================
# Best for: Methods you always want to profile
# Overhead: ~2-5 microseconds per call (negligible)
@swanlab_profile
def training_step(self, model, inputs):
"""Main training step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
"""
return super().training_step(model, inputs)
@swanlab_profile
def compute_loss(self, model, inputs, return_outputs=False):
"""Loss computation - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
"""
return super().compute_loss(model, inputs, return_outputs)
@swanlab_profile
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
"""Prediction step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
"""
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# ========================================================================
# Pattern 2: Fine-grained Context Manager Profiling
# ========================================================================
# Best for: Profiling specific code blocks within a method
# Use case: When you want to profile forward vs backward separately
def complex_training_step(self, model, inputs):
"""Training step with fine-grained profiling.
Profiling metrics:
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
"""
# Profile just the forward pass
with swanlab_profiling_context(self, "forward_pass"):
outputs = model(**inputs)
loss = outputs.loss
# Profile just the backward pass
with swanlab_profiling_context(self, "backward_pass"):
loss.backward()
# Profile optimizer step
with swanlab_profiling_context(self, "optimizer_step"):
self.optimizer.step()
self.optimizer.zero_grad()
return outputs
# ========================================================================
# Pattern 3: Advanced Profiling with Filtering
# ========================================================================
# Best for: High-frequency operations where you want to throttle logging
# Use case: Methods called 100+ times per step
def _prepare_inputs(self, inputs):
"""Prepare inputs - throttled profiling.
This method is called frequently (once per batch), so we throttle
profiling to reduce overhead:
- Only log if duration > 0.5ms (skip very fast operations)
- Only log every 50th call (reduce logging frequency)
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
"""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_op_config
):
return super()._prepare_inputs(inputs)
def _prepare_input_for_model(self, input_ids):
"""Another high-frequency operation - throttled profiling.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
"""
with swanlab_profiling_context_advanced(
self, "prepare_input_for_model", config=self.fast_op_config
):
# Your custom input preparation logic
return input_ids
# ========================================================================
# Pattern 4: Exception-safe Profiling
# ========================================================================
# Profiling is exception-safe: duration is logged even if method raises
@swanlab_profile
def potentially_failing_method(self):
"""This method may raise an exception.
SwanLab profiling will still log the duration before re-raising.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
"""
# Do some work
result = self._do_risky_computation()
# If this raises, profiling duration is still logged
if result < 0:
raise ValueError("Invalid result")
return result
def _do_risky_computation(self):
"""Placeholder for risky computation."""
return 42
# ============================================================================
# Advanced Example: Custom ProfilingConfig Per Method
# ============================================================================
class AdvancedProfilingTrainer(AxolotlTrainer):
"""Trainer with method-specific profiling configurations."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Different profiling configs for different method types
self.critical_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything on critical path
log_interval=1, # Log every call
)
self.fast_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=1.0, # Only log if > 1ms
log_interval=100, # Log every 100th call
)
self.debug_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything
log_interval=1, # Log every call
)
def training_step(self, model, inputs):
"""Critical path - log everything."""
with swanlab_profiling_context_advanced(
self, "training_step", config=self.critical_path_config
):
return super().training_step(model, inputs)
def _prepare_inputs(self, inputs):
"""Fast path - throttle logging."""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_path_config
):
return super()._prepare_inputs(inputs)
def _debug_method(self, data):
"""Debug-only method - verbose logging."""
with swanlab_profiling_context_advanced(
self, "debug_method", config=self.debug_config
):
# Your debug logic
pass
# ============================================================================
# How to Use This Custom Trainer
# ============================================================================
"""
To use this custom trainer:
1. Save this file to your project (e.g., my_custom_trainer.py)
2. Create a config file that uses your custom trainer:
# config.yml
base_model: NousResearch/Llama-3.2-1B
# ... other config ...
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-profiling-experiment
# Optional: Specify custom trainer
# (Or modify axolotl to use your custom trainer class)
3. Run training:
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train config.yml
4. View profiling metrics in SwanLab dashboard:
- profiling/Time taken: CustomTrainerWithProfiling.training_step
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- etc.
5. Compare profiling metrics across runs:
- Run baseline without optimizations
- Run with flash_attention enabled
- Run with gradient_checkpointing enabled
- Compare profiling metrics to see performance impact
"""
# ============================================================================
# Tips for Effective Profiling
# ============================================================================
"""
1. Profile the critical path first:
- training_step, compute_loss, prediction_step
   - These methods are called most frequently and have the biggest impact
2. Use throttling for high-frequency operations:
- Methods called 100+ times per step
- Use log_interval=50 or log_interval=100
- Reduces profiling overhead and dashboard clutter
3. Filter noise with min_duration_ms:
- Set min_duration_ms=1.0 to skip very fast operations
- Focus on operations that actually take time
4. Compare across runs:
- Run same config multiple times to check consistency
- Compare different optimization strategies
- Track profiling trends over time
5. Monitor distributed training:
- Check for per-rank timing differences
- Look for stragglers (slower ranks)
- Identify synchronization bottlenecks
6. Disable profiling in production:
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
- DEFAULT_PROFILING_CONFIG.enabled = False
7. Exception handling:
- Profiling is exception-safe
- Duration logged even if method raises
- Useful for debugging methods that fail intermittently
"""


@@ -0,0 +1,168 @@
# SwanLab DPO Training Example with Completion Logging
#
# This example demonstrates DPO (Direct Preference Optimization) training
# with SwanLab integration for experiment tracking and completion table logging.
#
# Features enabled:
# - SwanLab experiment tracking
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
# - Lark (Feishu) team notifications (optional)
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
# Model Configuration
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>
# Quantization
load_in_8bit: true
load_in_4bit: false
# LoRA Configuration
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# DPO Configuration
chat_template: llama3
rl: dpo
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
# Dataset and Output
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-out
# Training Configuration
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# Optimization
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: dpo-training
swanlab_experiment_name: llama-3-dpo-completions-demo
swanlab_description: "DPO training with completion table logging"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-research-team
# ============================================================================
# RLHF Completion Table Logging
# ============================================================================
#
# Automatically logs model completions to SwanLab for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View the table in SwanLab dashboard under "rlhf_completions"
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 training steps
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
# Memory Usage Notes:
# - Buffer size 128: ~64 KB (default, recommended)
# - Buffer size 512: ~256 KB (for more historical completions)
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
# Performance Notes:
# - Completion logging overhead: < 0.5% per training step
# - Only logs every N steps to minimize impact
# - Memory-bounded buffer prevents memory leaks
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get real-time training notifications in your team chat
# Uncomment to enable:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # Recommended for production
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# ============================================================================
# Optional: Private SwanLab Deployment
# ============================================================================
#
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# wandb_entity:
# use_wandb: false


@@ -0,0 +1,329 @@
# SwanLab Full-Featured DPO Training Example
#
# This example demonstrates ALL SwanLab integration features:
# - Experiment tracking with cloud sync
# - RLHF completion table logging
# - Performance profiling
# - Lark (Feishu) team notifications
# - Team workspace collaboration
#
# Use this as a reference for production RLHF training setups.
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
# ============================================================================
# Model Configuration
# ============================================================================
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>
# Quantization for efficient training
load_in_8bit: true
load_in_4bit: false
# ============================================================================
# LoRA Configuration
# ============================================================================
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true # Target all linear layers
# ============================================================================
# DPO (Direct Preference Optimization) Configuration
# ============================================================================
chat_template: llama3
rl: dpo # Enable DPO trainer
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
# ============================================================================
# Dataset and Output Configuration
# ============================================================================
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-full-featured-out
# ============================================================================
# Training Configuration
# ============================================================================
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# ============================================================================
# Optimization
# ============================================================================
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# ============================================================================
# Precision and Performance
# ============================================================================
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true
# ============================================================================
# Checkpointing and Logging
# ============================================================================
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration - Full Configuration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# ------------------------------------------------------------------------------
# Basic SwanLab Configuration
# ------------------------------------------------------------------------------
use_swanlab: true
swanlab_project: dpo-production
swanlab_experiment_name: llama-3-dpo-full-featured-v1
swanlab_description: |
  Production DPO training with all SwanLab features enabled:
  - Completion table logging for qualitative analysis
  - Performance profiling for optimization
  - Lark notifications for team collaboration
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# ------------------------------------------------------------------------------
# Team Collaboration
# ------------------------------------------------------------------------------
# Workspace for team collaboration (shared experiments)
swanlab_workspace: ml-research-team
# Authentication (recommended: use environment variable)
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# ------------------------------------------------------------------------------
# RLHF Completion Table Logging
# ------------------------------------------------------------------------------
# Automatically logs model completions for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View in SwanLab dashboard under "rlhf_completions" table
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
# Buffer size recommendations:
# - 128: Default, ~64 KB memory (recommended for most cases)
# - 256: ~128 KB memory (this config, good for longer training)
# - 512: ~256 KB memory (maximum for very long runs)
# ------------------------------------------------------------------------------
# Lark (Feishu) Team Notifications
# ------------------------------------------------------------------------------
# Get real-time training notifications in your team chat
#
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# Recommended: Set via environment variables
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# Or set in config (less secure):
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
# unauthorized parties from sending fake notifications to your team chat.
# ------------------------------------------------------------------------------
# Performance Profiling
# ------------------------------------------------------------------------------
# Profiling is automatically enabled when SwanLab is enabled.
# Metrics logged to SwanLab under "profiling/" namespace:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# Use these metrics to:
# - Identify bottlenecks in training loop
# - Compare performance across different configurations
# - Monitor performance regressions over time
# - Debug unexpected slowdowns
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# ------------------------------------------------------------------------------
# Optional: Private SwanLab Deployment
# ------------------------------------------------------------------------------
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ------------------------------------------------------------------------------
# Optional: Model Checkpointing to SwanLab
# ------------------------------------------------------------------------------
# Log model checkpoints to SwanLab (coming soon)
swanlab_log_model: false
# ============================================================================
# Disable Other Logging Tools (Recommended)
# ============================================================================
# Using multiple logging tools simultaneously can impact performance:
# - Expected overhead: ~1-2% per logger
# - Potential config/callback conflicts
#
# For production training, use ONLY SwanLab:
# wandb_project:
# use_wandb: false
#
# use_mlflow: false
#
# use_comet: false
# ============================================================================
# Expected Training Behavior
# ============================================================================
# With this configuration, you should see:
#
# 1. SwanLab Initialization (rank 0 only):
# INFO: SwanLab initialized for project: dpo-production
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
# INFO: SwanLab mode: cloud
# INFO: SwanLab workspace: ml-research-team
#
# 2. Completion Logging (rank 0 only):
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
# (log_interval=100, max_buffer=256)
#
# 3. Lark Notifications (rank 0 only):
# INFO: Registered Lark notification callback with HMAC authentication
#
# 4. Distributed Training Detection (if multi-GPU):
# INFO: Distributed training detected (world_size=N)
# INFO: Only rank 0 will initialize SwanLab
# INFO: Other ranks will skip SwanLab to avoid conflicts
#
# 5. Training Start Notification (Lark):
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
#
# 6. Periodic Completion Logging:
# Every 100 steps, completion table is updated in SwanLab dashboard
#
# 7. Training Complete Notification (Lark):
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
# With link to SwanLab dashboard and final metrics
#
# 8. SwanLab Dashboard Shows:
# - Training metrics (loss, learning rate, etc.)
# - Completion table (rlhf_completions)
# - Profiling metrics (profiling/Time taken: ...)
# - Hyperparameters and configuration
# - System resource usage
# ============================================================================
# Production Checklist
# ============================================================================
# Before deploying to production, verify:
# ✅ SwanLab API key is set via environment variable (not in config)
# ✅ Lark webhook secret is set (required for HMAC authentication)
# ✅ Workspace is set to your team's workspace
# ✅ Experiment name is descriptive and unique
# ✅ Only SwanLab is enabled (other loggers disabled)
# ✅ Completion logging buffer size is appropriate for your training duration
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
# ✅ Test run completes successfully and shows up in SwanLab dashboard
# ✅ Lark notifications are received in team chat
# ✅ Profiling metrics are logged correctly
# ============================================================================
# Troubleshooting
# ============================================================================
# If SwanLab initialization fails:
# 1. Check SWANLAB_API_KEY environment variable is set
# 2. Verify swanlab_project is set in config
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
# 4. Verify internet connectivity (for cloud mode)
# If Lark notifications not received:
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
# 4. Check training logs for "Registered Lark notification callback"
# 5. Verify bot is added to the target Lark group chat
# If completions not appearing in SwanLab:
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
# 2. Check swanlab_log_completions is true
# 3. Wait for log_interval steps (default: 100)
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
# If profiling metrics not appearing:
# 1. Verify use_swanlab is true
# 2. Check SwanLab is initialized (check logs)
# 3. Look under "profiling/" namespace in dashboard
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
# For more help:
# - SwanLab docs: https://docs.swanlab.cn
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues
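The HMAC authentication mentioned in the security and troubleshooting notes above can be sketched as follows, assuming Lark's documented custom-bot signing scheme: HMAC-SHA256 keyed by the string `"{timestamp}\n{secret}"` over an empty message, base64-encoded. The helper name `lark_sign` is illustrative, not the integration's actual API.

```python
import base64
import hashlib
import hmac
import time


def lark_sign(secret: str, timestamp: int) -> str:
    # Lark/Feishu custom bots use "{timestamp}\n{secret}" as the HMAC *key*
    # and an empty byte string as the message.
    string_to_sign = f"{timestamp}\n{secret}"
    digest = hmac.new(string_to_sign.encode("utf-8"), b"", hashlib.sha256).digest()
    return base64.b64encode(digest).decode("utf-8")


ts = int(time.time())
# Payload shape for a signed text message to the webhook (illustrative):
payload = {
    "timestamp": str(ts),
    "sign": lark_sign("my-webhook-secret", ts),
    "msg_type": "text",
    "content": {"text": "Training started: llama-3-dpo-full-featured-v1"},
}
print(payload["sign"])
```

Because the signature binds the timestamp to the secret, a third party who only knows the webhook URL cannot forge notifications — which is why the config above marks the secret as required for production.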


@@ -0,0 +1,178 @@
# SwanLab LoRA Training Example with Performance Profiling
#
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
# for performance profiling and optimization.
#
# Features enabled:
# - SwanLab experiment tracking
# - Performance profiling (training step, forward/backward pass timing)
# - Real-time metrics visualization
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
# Model Configuration
base_model: NousResearch/Llama-3.2-1B
# Dataset Configuration
datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
val_set_size: 0.1
output_dir: ./outputs/lora-swanlab-profiling-out
# LoRA Configuration
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
# Training Configuration
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
# Optimization
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# Loss Monitoring
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
special_tokens:
  pad_token: "<|end_of_text|>"
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: lora-profiling
swanlab_experiment_name: llama-3.2-1b-profiling-demo
swanlab_description: "LoRA fine-tuning with performance profiling"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-ml-team
# ============================================================================
# Performance Profiling
# ============================================================================
#
# SwanLab automatically profiles trainer methods when enabled.
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
#
# Built-in profiling:
# - Minimal overhead (< 0.1% per step)
# - High-precision timing (microsecond accuracy)
# - Exception-safe (logs duration even if method fails)
#
# View profiling metrics in SwanLab dashboard:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# Completion logging is disabled for non-RLHF trainers
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
# ============================================================================
# Optional: Compare with Multiple Runs
# ============================================================================
#
# To compare profiling metrics across different configurations:
#
# 1. Run baseline without flash attention:
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
# flash_attention: false
#
# 2. Run with gradient checkpointing:
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
# gradient_checkpointing: true
#
# 3. Run with both:
# swanlab_experiment_name: llama-3.2-1b-optimized
# flash_attention: true
# gradient_checkpointing: true
#
# Then compare profiling metrics in SwanLab dashboard to see performance impact
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get notified when profiling experiments complete:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret
# ============================================================================
# Profiling Best Practices
# ============================================================================
#
# 1. Run multiple epochs to see profiling trends over time
# 2. Ignore first ~10 steps (warmup period, slower)
# 3. Look for outliers (steps that take significantly longer)
# 4. Compare profiling metrics before/after optimization changes
# 5. Monitor per-rank profiling in distributed training
#
# Common bottlenecks to profile:
# - training_step: Overall step time (should be consistent)
# - compute_loss: Loss computation (scales with sequence length)
# - prediction_step: Evaluation time (can be slow for large val sets)
#
# If you see inconsistent timing:
# - Check for data loading bottlenecks
# - Monitor GPU utilization (may be CPU-bound)
# - Check for gradient accumulation effects
# - Verify CUDA kernel synchronization
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# use_wandb: false
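The best practices above (skip the warmup steps, flag outliers) can be sketched as a small post-processing helper over exported step timings. `summarize_step_times` is a hypothetical name, not part of axolotl; the threshold of mean + 2 standard deviations is one reasonable choice, not a prescribed one.

```python
import statistics


def summarize_step_times(times_ms, warmup=10, z=2.0):
    """Drop the first `warmup` steps, then flag steps slower than mean + z*stdev."""
    steady = times_ms[warmup:]
    mean = statistics.mean(steady)
    stdev = statistics.pstdev(steady)
    outliers = [t for t in steady if t > mean + z * stdev]
    return mean, stdev, outliers


# 10 slow warmup steps, then steady ~100 ms steps with one 250 ms straggler
times = [500.0] * 10 + [100.0, 101.0, 99.0, 100.0, 250.0, 100.0, 101.0, 99.0]
mean, stdev, outliers = summarize_step_times(times)
print(outliers)  # [250.0]
```

Feeding exported `profiling/Time taken: ...` series through a helper like this makes it easy to spot the data-loading or synchronization stragglers mentioned above.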


@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
-2. Run the finetuning example:
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+3. Run the finetuning example:

 ```bash
 axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
 ```

-This config uses about 24.9 GiB VRAM.
+This config uses about 24.9 GiB VRAM (w/o CCE).

 Let us know how it goes. Happy finetuning! 🚀
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀

 Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

-## Limitations
-
-**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
-
 ## Related Resources

 - [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)


@@ -1,5 +1,4 @@
base_model: arcee-ai/Trinity-Nano-Preview base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0 revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF


@@ -12,7 +12,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r

 ```bash
 # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```


@@ -1,5 +1,5 @@
[build-system] [build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"] requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
@@ -24,6 +24,9 @@ Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
py-modules = ["setuptools_axolotl_dynamic_dependencies"] py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true include-package-data = true
[tool.setuptools.dynamic]
version = { file = "VERSION" }
[tool.setuptools.cmdclass] [tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand" build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
@@ -57,3 +60,8 @@ indent-style = "space"
skip-magic-trailing-comma = false skip-magic-trailing-comma = false
line-ending = "auto" line-ending = "auto"
docstring-code-format = false docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]
flash-attn = [{ requirement = "torch", match-runtime = true }]
deepspeed = [{ requirement = "torch", match-runtime = true }]


@@ -1,26 +1,26 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.48.2
+bitsandbytes==0.49.1
-triton>=3.0.0
+triton>=3.4.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
-liger-kernel==0.6.4
+liger-kernel==0.7.0
 # END section
-packaging==23.2
+packaging==26.0
-huggingface_hub>=0.36.0
+huggingface_hub>=1.1.7
-peft>=0.18.0
+peft>=0.18.1
 tokenizers>=0.22.1
-transformers==4.57.1
+transformers==5.2.0
 accelerate==1.12.0
-datasets==4.4.2
+datasets==4.5.0
 deepspeed>=0.18.3
-trl==0.25.1
+trl==0.28.0
 hf_xet==1.2.0
-kernels==0.11.5
+kernels==0.12.1
-trackio>=0.13.0
+trackio>=0.16.1
 typing-extensions>=4.15.0
 optimum==1.16.2
@@ -63,7 +63,7 @@ langdetect==1.0.9
 immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
-torchao==0.13.0
+torchao==0.16.0
 openenv-core==0.1.0
 schedulefree==1.4.1
@@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.6

 # telemetry
 posthog==6.7.11
-mistral-common==1.8.6
+mistral-common==1.8.8


@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
 )


@@ -1,6 +1,5 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
import ast
import os import os
import platform import platform
import re import re
@@ -26,6 +25,12 @@ def parse_requirements(extras_require_map):
_install_requires.append(line) _install_requires.append(line)
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip torchao on ARM64
_install_requires = [
req for req in _install_requires if "torchao" not in req
]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# skip packages not compatible with OSX # skip packages not compatible with OSX
skip_packages = [ skip_packages = [
@@ -62,44 +67,68 @@ def parse_requirements(extras_require_map):
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
 if (major, minor) >= (2, 9):
     extras_require_map.pop("fbgemm-gpu")
-    extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
+    extras_require_map["fbgemm-gpu"] = [
+        "fbgemm-gpu==1.4.0",
+        "fbgemm-gpu-genai==1.4.2",
+    ]
     extras_require_map["vllm"] = ["vllm==0.11.1"]
+    if not install_xformers:
+        _install_requires.pop(_install_requires.index(xformers_version))
-    extras_require_map["vllm"] = ["vllm==0.13.0"]
+    if patch == 0:
+        extras_require_map["vllm"] = ["vllm==0.13.0"]
+    else:
+        extras_require_map["vllm"] = ["vllm==0.14.0"]
 elif (major, minor) >= (2, 8):
     extras_require_map.pop("fbgemm-gpu")
     extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
     extras_require_map["vllm"] = ["vllm==0.11.0"]
+    if not install_xformers:
+        _install_requires.pop(_install_requires.index(xformers_version))
 elif (major, minor) >= (2, 7):
     _install_requires.pop(_install_requires.index(xformers_version))
     if patch == 0:
-        _install_requires.append("xformers==0.0.30")
+        if install_xformers:
+            _install_requires.append("xformers==0.0.30")
         # vllm 0.9.x is incompatible with latest transformers
         extras_require_map.pop("vllm")
     else:
-        _install_requires.append("xformers==0.0.31")
+        if install_xformers:
+            _install_requires.append("xformers==0.0.31")
         extras_require_map["vllm"] = ["vllm==0.10.1"]
 elif (major, minor) >= (2, 6):
     _install_requires.pop(_install_requires.index(xformers_version))
-    _install_requires.append("xformers==0.0.29.post3")
+    if install_xformers:
+        _install_requires.append("xformers==0.0.29.post3")
     # since we only support 2.6.0+cu126
     _dependency_links.append("https://download.pytorch.org/whl/cu126")
     extras_require_map.pop("vllm")
 elif (major, minor) >= (2, 5):
     _install_requires.pop(_install_requires.index(xformers_version))
-    if patch == 0:
-        _install_requires.append("xformers==0.0.28.post2")
-    else:
-        _install_requires.append("xformers>=0.0.28.post3")
+    if install_xformers:
+        if patch == 0:
+            _install_requires.append("xformers==0.0.28.post2")
+        else:
+            _install_requires.append("xformers>=0.0.28.post3")
     extras_require_map.pop("vllm")
 elif (major, minor) >= (2, 4):
     extras_require_map.pop("vllm")
-    if patch == 0:
-        _install_requires.pop(_install_requires.index(xformers_version))
-        _install_requires.append("xformers>=0.0.27")
-    else:
-        _install_requires.pop(_install_requires.index(xformers_version))
-        _install_requires.append("xformers==0.0.28.post1")
+    if install_xformers:
+        if patch == 0:
+            _install_requires.pop(_install_requires.index(xformers_version))
+            _install_requires.append("xformers>=0.0.27")
+        else:
+            _install_requires.pop(_install_requires.index(xformers_version))
+            _install_requires.append("xformers==0.0.28.post1")
 else:
     raise ValueError("axolotl requires torch>=2.4")
@@ -110,15 +139,11 @@ def parse_requirements(extras_require_map):
 def get_package_version():
     with open(
-        Path(os.path.dirname(os.path.abspath(__file__)))
-        / "src"
-        / "axolotl"
-        / "__init__.py",
+        Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
         "r",
         encoding="utf-8",
     ) as fin:
-        version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
-        version_ = ast.literal_eval(version_match.group(1))
+        version_ = fin.read().strip()
     return version_

View File

@@ -1,7 +1,11 @@
 """Axolotl - Train and fine-tune large language models"""

 import pkgutil
+from importlib.metadata import PackageNotFoundError, version

 __path__ = pkgutil.extend_path(__path__, __name__)  # Make this a namespace package

-__version__ = "0.13.0.dev"
+try:
+    __version__ = version("axolotl")
+except PackageNotFoundError:
+    __version__ = "unknown"

View File

@@ -5,6 +5,6 @@ import os
 from axolotl.logging_config import configure_logging

 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")

 configure_logging()

View File

@@ -44,7 +44,7 @@ def check_user_token() -> bool:
         return bool(user_info)
     except LocalTokenNotFoundError:
         LOG.warning(
-            "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
+            "Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
         )
         return False
     except HTTPError:

View File

@@ -5,7 +5,7 @@ import os
 import tempfile
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import Union
+from typing import Any, Optional, Union
 from urllib.parse import urlparse

 import requests

@@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars

 LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
    """Coerce a string CLI value to its most likely Python type.

    If an existing value is present in the config, its type is used to guide
    casting. Otherwise, YAML-style inference is applied: booleans, ints,
    floats, and None literals are recognised automatically.

    Args:
        value: The raw value (typically a string from the CLI).
        existing: An optional existing config value whose type guides coercion.

    Returns:
        The value cast to the inferred or expected type.
    """
    if not isinstance(value, str):
        return value

    # If the config already has a typed value, cast to match
    if existing is not None:
        if isinstance(existing, bool):
            return value.lower() in ("true", "1", "yes")
        if isinstance(existing, int):
            try:
                return int(value)
            except (ValueError, TypeError):
                return value
        if isinstance(existing, float):
            try:
                return float(value)
            except (ValueError, TypeError):
                return value
        # For other types (str, list, dict, etc.), return as-is
        return value

    # No existing value -- use YAML-style inference
    lower = value.lower()
    if lower in ("true", "yes"):
        return True
    if lower in ("false", "no"):
        return False
    if lower in ("null", "none", "~"):
        return None

    # Try int then float
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        pass

    return value
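A standalone re-implementation of the YAML-style inference rules above, for illustration only (this is not the axolotl function itself, just the no-`existing` branch):

```python
# Sketch of the YAML-style inference above: booleans, None literals, ints,
# and floats are recognised; anything else stays a string.
def coerce(value: str):
    lower = value.lower()
    if lower in ("true", "yes"):
        return True
    if lower in ("false", "no"):
        return False
    if lower in ("null", "none", "~"):
        return None
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value

print([coerce(v) for v in ["true", "no", "~", "42", "0.5", "cosine"]])
# [True, False, None, 42, 0.5, 'cosine']
```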
API_KEY_FIELDS = {"comet_api_key"}

TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -208,13 +265,37 @@ def load_cfg(
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
     cfg_keys = cfg.keys()
+
+    # Separate nested (dot-notation) kwargs from flat kwargs
+    nested_kwargs: dict[str, dict[str, Any]] = {}
+    flat_kwargs: dict[str, Any] = {}
     for key, value in kwargs.items():
+        if "__" in key:
+            parent, child = key.split("__", 1)
+            nested_kwargs.setdefault(parent, {})[child] = value
+        else:
+            flat_kwargs[key] = value
+
+    # Apply flat kwargs
+    for key, value in flat_kwargs.items():
         # If not strict, allow writing to cfg even if it's not in the yml already
         if key in cfg_keys or not cfg.strict:
-            if isinstance(cfg[key], bool):
-                cfg[key] = bool(value)
-            else:
-                cfg[key] = value
+            cfg[key] = _coerce_value(value, cfg.get(key))
+
+    # Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
+    for parent, children in nested_kwargs.items():
+        if parent not in cfg_keys and cfg.strict:
+            continue
+        if cfg[parent] is None:
+            cfg[parent] = {}
+        if not isinstance(cfg[parent], dict):
+            LOG.warning(
+                "Overwriting non-dict value for '%s' with nested CLI overrides", parent
+            )
+            cfg[parent] = {}
+        for child_key, child_value in children.items():
+            existing_child = cfg[parent].get(child_key)
+            cfg[parent][child_key] = _coerce_value(child_value, existing_child)
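The double-underscore routing above can be sketched in isolation (a simplified stand-in for the config object, without coercion or strict-mode checks):

```python
# Sketch of the double-underscore splitting above: CLI kwargs such as
# trl__beta=0.1 are routed into a nested cfg dict (cfg["trl"]["beta"]).
def apply_overrides(cfg: dict, kwargs: dict) -> dict:
    for key, value in kwargs.items():
        if "__" in key:
            parent, child = key.split("__", 1)
            cfg.setdefault(parent, {})[child] = value
        else:
            cfg[key] = value
    return cfg

cfg = {"learning_rate": 1e-4}
apply_overrides(cfg, {"trl__beta": 0.1, "learning_rate": 2e-4})
print(cfg)  # {'learning_rate': 0.0002, 'trl': {'beta': 0.1}}
```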
    try:
        device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -24,7 +24,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
         cfg: Dictionary mapping `axolotl` config keys to values.
     """
     model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
-    safe_serialization = cfg.save_safetensors is True

     LOG.info("Running merge of LoRA with base model...")
     model = model.merge_and_unload(progressbar=True)

@@ -42,7 +41,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
     LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
     model.save_pretrained(
         str(Path(cfg.output_dir) / "merged"),
-        safe_serialization=safe_serialization,
         progressbar=True,
     )
     tokenizer.save_pretrained(

View File

@@ -14,8 +14,6 @@ from accelerate import PartialState
 from accelerate.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME,
-    WEIGHTS_NAME,
     is_torch_version,
 )
 from huggingface_hub import split_torch_state_dict_into_shards

@@ -40,17 +38,15 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):

 def _distributed_checkpoint_to_merged_weights(
     checkpoint_dir: Union[str, Path],
     save_path: str,
-    safe_serialization: bool = False,
     max_shard_size: str = "5GB",
 ) -> Path:
     """
     Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
-    save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
+    save under `save_path` as `model.safetensors`.

     Args:
         checkpoint_dir: Directory where distributed checkpoint is saved.
         save_path: Path to save model to.
-        safe_serialization: Whether to save in safetensors format.
         max_shard_size: Max size of model shards to save.

     Returns:

@@ -76,11 +72,7 @@ def _distributed_checkpoint_to_merged_weights(
         if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
             state_dict[key] = value.to(torch.bfloat16)

-    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
-        ".safetensors", "{suffix}.safetensors"
-    )
+    filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
     state_dict_split = split_torch_state_dict_into_shards(
         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )

@@ -98,19 +90,12 @@ def _distributed_checkpoint_to_merged_weights(
     for shard_file, tensors in filename_to_tensors:
         shard = {tensor: state_dict[tensor] for tensor in tensors}
-        if safe_serialization:
-            safe_save_file(
-                shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
-            )
-        else:
-            torch.save(shard, os.path.join(save_path_, shard_file))
+        safe_save_file(
+            shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
+        )

     if index is not None:
-        save_index_file = (
-            SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
-        )
-        save_index_file = os.path.join(save_path_, save_index_file)
+        save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
         # Save the index as well
         with open(save_index_file, "w", encoding="utf-8") as fout:
             content = json.dumps(index, indent=2, sort_keys=True) + "\n"
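The `{suffix}` filename pattern used above expands as follows (a sketch; `SAFE_WEIGHTS_NAME` is redefined locally here to keep the example self-contained, matching the constant from `accelerate.utils`):

```python
# How the safetensors shard filename pattern above expands: "{suffix}" is
# empty for a single file, or an "-00001-of-00002"-style counter when the
# state dict is split into shards.
SAFE_WEIGHTS_NAME = "model.safetensors"  # mirrors accelerate.utils constant
pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")

print(pattern.format(suffix=""))                 # model.safetensors
print(pattern.format(suffix="-00001-of-00002"))  # model-00001-of-00002.safetensors
```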
@@ -123,13 +108,11 @@ def _distributed_checkpoint_to_merged_weights(
 def merge_fsdp_weights(
     checkpoint_dir: str,
     output_path: str,
-    safe_serialization: bool = False,
     remove_checkpoint_dir: bool = False,
 ):
     """
     Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
-    `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
-    `safe_serialization` else `pytorch_model.bin`.
+    `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.

     Note: this is a CPU-bound process.

@@ -138,8 +121,6 @@ def merge_fsdp_weights(
         The directory containing the FSDP checkpoints (can be either the model or optimizer).
     output_path (`str`):
         The path to save the merged checkpoint.
-    safe_serialization (`bool`, *optional*, defaults to `True`):
-        Whether to save the merged weights with safetensors (recommended).
     remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
         Whether to remove the checkpoint directory after merging.

@@ -177,7 +158,7 @@ def merge_fsdp_weights(
     if state.is_main_process:
         LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
         save_path = _distributed_checkpoint_to_merged_weights(
-            checkpoint_dir_, output_path, safe_serialization
+            checkpoint_dir_, output_path
         )
         LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
     if remove_checkpoint_dir:

@@ -210,7 +191,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     merge_fsdp_weights(
         checkpoint_dir=str(fsdp_dir),
         output_path=output_path,
-        safe_serialization=True,
     )
     state = PartialState()
     state.wait_for_everyone()

View File

@@ -102,12 +102,10 @@ def do_quantize(
     LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
     model.save_pretrained(
         str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
         progressbar=True,
     )
     tokenizer.save_pretrained(
         str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
         progressbar=True,
         save_jinja_files=cfg.tokenizer_save_jinja_files,
     )

@@ -121,7 +119,7 @@ def do_quantize(
             hub_model_id.rstrip("-")
             + f"-{quantization_config_to_str[type(quantization_config)]}"
         )
-        model.push_to_hub(hub_model_id, safe_serialization=False)
+        model.push_to_hub(hub_model_id)
         tokenizer.push_to_hub(hub_model_id)
         if processor:
             processor.push_to_hub(hub_model_id)

View File

@@ -2,7 +2,7 @@
 import dataclasses
 from functools import wraps
-from types import NoneType
+from types import NoneType, UnionType
 from typing import Any, Callable, Type, Union, get_args, get_origin

 import click

@@ -20,7 +20,8 @@ def _strip_optional_type(field_type: type | str | None):
     If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
     returns the input type unchanged.
     """
-    if get_origin(field_type) is Union and type(None) in get_args(field_type):
+    is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
+    if is_union and type(None) in get_args(field_type):
         field_type = next(
             t for t in get_args(field_type) if not isinstance(t, NoneType)
         )
@@ -87,10 +88,70 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
    return decorator


def _is_pydantic_model(field_type: type) -> bool:
    """Check if a type is a Pydantic BaseModel subclass."""
    try:
        return isinstance(field_type, type) and issubclass(field_type, BaseModel)
    except TypeError:
        return False


def _get_field_description(field) -> str | None:
    """Get description from a Pydantic field, checking both .description and json_schema_extra."""
    if field.description:
        return field.description
    if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
        return field.json_schema_extra.get("description")
    return None


def _add_nested_model_options(
    function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
    """
    Add Click options for all fields of a nested Pydantic model using dot-notation.

    Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
    Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.

    Args:
        function: Click command function to add options to.
        parent_name: Parent field name (e.g., "trl").
        model_class: Nested Pydantic model class.

    Returns:
        Function with added Click options.
    """
    for sub_name, sub_field in reversed(model_class.model_fields.items()):
        sub_type = _strip_optional_type(sub_field.annotation)

        # Use dot notation: --parent.sub_field
        cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
        # The kwarg name uses double-underscore as separator
        param_name = f"{parent_name}__{sub_name}"

        description = _get_field_description(sub_field)

        if sub_type is bool:
            option_name = f"--{cli_name}/--no-{cli_name}"
            function = click.option(
                option_name, param_name, default=None, help=description
            )(function)
        else:
            option_name = f"--{cli_name}"
            click_type = {str: str, int: int, float: float}.get(sub_type)
            function = click.option(
                option_name, param_name, default=None, type=click_type, help=description
            )(function)

    return function
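The naming scheme above maps each nested field to a dotted CLI flag and a double-underscore kwarg. A standalone sketch of just that mapping (`option_names` is an illustrative helper, not axolotl code):

```python
# Sketch of the naming scheme above: each sub-field of a nested model becomes
# "--parent.sub-field" on the CLI, while the Python kwarg uses
# "parent__sub_field" so Click can pass it as a valid identifier.
def option_names(parent: str, sub_name: str) -> tuple[str, str]:
    cli_name = f"{parent}.{sub_name}".replace("_", "-")
    return f"--{cli_name}", f"{parent}__{sub_name}"

print(option_names("trl", "beta"))       # ('--trl.beta', 'trl__beta')
print(option_names("trl", "ref_model"))  # ('--trl.ref-model', 'trl__ref_model')
```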
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
    """
    Create Click options from the fields of a Pydantic model.

    For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
    generated for each sub-field (e.g., ``--trl.beta=0.1``).

    Args:
        config_class: PyDantic model with fields to parse from the CLI
@@ -103,6 +164,11 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
    for name, field in reversed(config_class.model_fields.items()):
        field_type = _strip_optional_type(field.annotation)

        # Handle nested Pydantic models with dot-notation options
        if _is_pydantic_model(field_type):
            function = _add_nested_model_options(function, name, field_type)
            continue

        if field_type is bool:
            field_name = name.replace("_", "-")
            option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
    "gpt_oss": "GptOssDecoderLayer",
    "lfm2_moe": "Lfm2MoeSparseMoeBlock",
    "afmoe": "AfmoeMoE",
    "glm4_moe": "Glm4MoeDecoderLayer",
    "glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
    "glm_moe_dsa": "GlmMoeDsaDecoderLayer",
}

View File

@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
     def _configure_warmup_and_logging(
         self, total_num_steps: int, training_args_kwargs: dict
     ):
-        warmup_steps = 0
+        warmup_steps: int | float = 0
         warmup_ratio = 0.0
         if self.cfg.warmup_steps is not None:
             warmup_steps = self.cfg.warmup_steps

@@ -230,6 +230,10 @@ class TrainerBuilderBase(abc.ABC):
         else:
             warmup_ratio = 0.03

+        # transformers v5
+        if warmup_ratio > 0.0 and warmup_steps == 0:
+            warmup_steps = warmup_ratio
+
         if warmup_steps == 1:
             warmup_steps = 2

@@ -242,7 +246,6 @@ class TrainerBuilderBase(abc.ABC):
             else max(min(int(0.005 * total_num_steps), 10), 1)
         )

-        training_args_kwargs["warmup_ratio"] = warmup_ratio
         training_args_kwargs["warmup_steps"] = warmup_steps
    def _configure_precision_settings(self, training_args_kwargs: dict):
@@ -406,6 +409,9 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.hub_strategy:
             training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
+        if self.cfg.hub_revision:
+            training_args_kwargs["hub_revision"] = self.cfg.hub_revision

     def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
         # save_strategy and save_steps
         if self.cfg.save_steps:

@@ -530,9 +536,7 @@ class TrainerBuilderBase(abc.ABC):
             "loraplus_lr_ratio",
             "loraplus_lr_embedding",
             "output_dir",
-            "save_safetensors",
             "save_only_model",
-            "include_tokens_per_second",
             "weight_decay",
             "seed",
             "dion_momentum",

@@ -545,6 +549,7 @@ class TrainerBuilderBase(abc.ABC):
         arg_map = {
             "dion_learning_rate": "dion_lr",
+            "include_num_input_tokens_seen": "include_tokens_per_second",
         }
         for kwarg, cfg_arg in arg_map.items():
             if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:

View File

@@ -122,6 +122,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ColabCallback = colab_inference_post_train_callback(trainer)
             callbacks.append(ColabCallback(self.cfg))

+        if getattr(self.cfg, "generate_samples", False):
+            from axolotl.utils.callbacks.generation import SFTGenerationCallback
+
+            callbacks.append(SFTGenerationCallback(trainer))
+            LOG.info("SFT sample generation enabled")
+
         callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
         return callbacks

@@ -246,7 +252,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 ddp_find_unused_parameters
             )

-        training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+        if self.cfg.group_by_length:
+            training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)

@@ -373,6 +380,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = multiple

+        if self.cfg.use_eaft:
+            from functools import partial
+
+            from axolotl.monkeypatch.loss.eaft import eaft_loss
+
+            configured_eaft_loss = partial(
+                eaft_loss,
+                alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
+                k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
+            )
+            trainer_kwargs["compute_loss_func"] = configured_eaft_loss
+
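The `functools.partial` wiring above pre-binds the loss hyperparameters before handing the callable to the trainer. A self-contained sketch (the `eaft_loss` stub and its `alpha`/`k` defaults here are stand-ins mirroring the config handling, not the real implementation):

```python
# Sketch of pre-binding loss hyperparameters with functools.partial, as in
# the eaft wiring above. eaft_loss is a hypothetical stub for illustration.
from functools import partial

def eaft_loss(outputs, labels, alpha=1.0, k=20):
    ...  # real loss computation lives in axolotl.monkeypatch.loss.eaft

cfg_alpha, cfg_k = None, 40  # values as they might come from the YAML config
configured = partial(
    eaft_loss,
    alpha=cfg_alpha if cfg_alpha is not None else 1.0,
    k=cfg_k if cfg_k is not None else 20,
)
print(configured.keywords)  # {'alpha': 1.0, 'k': 40}
```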
         trainer_cls = self._get_trainer_cls()
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(

@@ -437,7 +456,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             or self.cfg.micro_batch_size > 1
         ):
             return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
-        if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
+        if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
+            self.cfg.micro_batch_size == 1 and is_eval is False
+        ):
             return None
         if self.cfg.model_config_type == "mamba":

View File

@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
 )
 from axolotl.core.trainers.dpo import DPOStrategy
 from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
-from axolotl.core.trainers.grpo import GRPOStrategy
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
 from axolotl.utils.callbacks.qat import QATCallback

@@ -52,12 +51,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         trainer_cls = None
         trainer_cls_args = [self.model]

-        if self.cfg.rl is RLType.GRPO:
+        if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
+            from axolotl.core.trainers.grpo import GRPOStrategy
+
             trainer_cls = GRPOStrategy.get_trainer_class(
                 sequence_parallel=self.cfg.context_parallel_size > 1
             )
             trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
             trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:

@@ -134,19 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
-            # Handle when max_prompt_length == max_length from defaults
-            # CPOTrainer requires strictly less than
-            if (
-                training_args_kwargs["max_prompt_length"]
-                == training_args_kwargs["max_length"]
-            ):
-                training_args_kwargs["max_prompt_length"] -= 1
+            blocklist_args_kwargs.append("max_prompt_length")
         elif self.cfg.rl is RLType.ORPO:
             training_args_cls = AxolotlORPOConfig
+            blocklist_args_kwargs.append("max_prompt_length")
         elif self.cfg.rl is RLType.KTO:
             training_args_cls = AxolotlKTOConfig
+            # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
+            blocklist_args_kwargs.append("max_prompt_length")

             training_args_kwargs["desirable_weight"] = (
                 self.cfg.kto_desirable_weight or 1.0

@@ -155,10 +153,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kto_undesirable_weight or 1.0
             )

-        elif self.cfg.rl is RLType.GRPO:
+        elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
+            from axolotl.core.trainers.grpo import GRPOStrategy
+
             training_args_cls = GRPOStrategy.get_training_args_class()
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
             blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
+            if self.cfg.rl is RLType.GDPO:
+                training_args_kwargs.setdefault(
+                    "multi_objective_aggregation", "normalize_then_sum"
+                )
         elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
             training_args_cls = AxolotlDPOConfig

View File

@@ -25,7 +25,7 @@ from torch.utils.data import (
 from transformers import PreTrainedModel, Trainer
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
-from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
+from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
 from trl.trainer.utils import pad_to_length
 from typing_extensions import override
@@ -660,11 +660,10 @@ class AxolotlTrainer(
             logs["tokens/train_per_sec_per_gpu"] = round(
                 self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
             )
-            if (
-                hasattr(self.state, "total_tokens")
-                and self.state.total_tokens is not None
-            ):
-                logs["total_tokens"] = int(self.state.total_tokens.item())
+            if "total" in self.state.tokens:
+                logs["tokens/total"] = int(self.state.tokens["total"].item())
+            if "trainable" in self.state.tokens:
+                logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())

         del self._stored_metrics[train_eval]

@@ -720,6 +719,20 @@ class AxolotlTrainer(
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         LOG.info(f"Saving model checkpoint to {output_dir}")
        # fix for Context Parallel save: CP eval invalidates tensor storage
        # pointers, so clone to CPU to get fresh valid storage for safetensors
        if (
            state_dict is not None
            and self.axolotl_cfg
            and self.axolotl_cfg.context_parallel_size
            and self.axolotl_cfg.context_parallel_size > 1
        ):
            state_dict = {
                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                for k, v in state_dict.items()
            }
        supported_classes = (
            (PreTrainedModel,)
            if not is_peft_available()
@@ -730,6 +743,7 @@ class AxolotlTrainer(
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(
                self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
                supported_classes,
@@ -739,43 +753,35 @@ class AxolotlTrainer(
                ).save_pretrained(
                    output_dir,
                    state_dict=state_dict,
-                   safe_serialization=self.args.save_safetensors,
+                   is_main_process=self.accelerator.is_main_process,
                )
            else:
                LOG.info(
                    "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
                )
-               if self.args.save_safetensors:
-                   safetensors.torch.save_file(
-                       state_dict,
-                       os.path.join(output_dir, SAFE_WEIGHTS_NAME),
-                       metadata={"format": "pt"},
-                   )
-               else:
-                   torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+               safetensors.torch.save_file(
+                   state_dict,
+                   os.path.join(output_dir, SAFE_WEIGHTS_NAME),
+                   metadata={"format": "pt"},
+               )
        else:
            self.model.save_pretrained(
                output_dir,
                state_dict=state_dict,
-               safe_serialization=self.args.save_safetensors,
                is_main_process=self.accelerator.is_main_process,
            )
        if self.processing_class is not None:
            self.processing_class.save_pretrained(output_dir)
        elif (
            self.data_collator is not None
            and hasattr(self.data_collator, "tokenizer")
            and self.data_collator.tokenizer is not None
        ):
            LOG.info(
                "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
            )
-           save_jinja_files = True
-           if self.axolotl_cfg:
-               save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
-           self.data_collator.tokenizer.save_pretrained(
-               output_dir, save_jinja_files=save_jinja_files
-           )
+           self.data_collator.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))


@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
    def tokenize_row(
        features,
        processing_class,
-       max_prompt_length,
-       max_completion_length,
-       add_special_tokens,
+       max_prompt_length: int | None = None,
+       max_completion_length: int | None = None,
+       add_special_tokens: bool = True,
+       is_chat: bool = False,
    ) -> Dict:
        res = DPOTrainer.tokenize_row(
            features,
            processing_class,
-           max_prompt_length,
-           max_completion_length,
-           add_special_tokens,
+           max_prompt_length=max_prompt_length,
+           max_completion_length=max_completion_length,
+           add_special_tokens=add_special_tokens,
+           is_chat=is_chat,
        )
        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:


@@ -126,8 +126,10 @@ class GRPOStrategy:
        if trl.use_liger_loss is not None:
            grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
-       if trl.rollout_func:
-           grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
+       if trl.multi_objective_aggregation is not None:
+           grpo_args_kwargs["multi_objective_aggregation"] = (
+               trl.multi_objective_aggregation
+           )
        return grpo_args_kwargs
@@ -149,6 +151,8 @@ class GRPOStrategy:
            trainer_kwargs["reward_processing_classes"] = (
                cfg.trl.reward_processing_classes
            )
+       if cfg.trl and cfg.trl.rollout_func:
+           trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
        return trainer_kwargs
@@ -159,7 +163,12 @@ class GRPOStrategy:
    @classmethod
    def get_blocklist_args_kwargs(cls) -> list[str]:
-       return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
+       return [
+           "dataset_num_proc",
+           "max_length",
+           "include_tokens_per_second",
+           "max_prompt_length",
+       ]

    @classmethod
    def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:


@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

    def create_scheduler(
-       self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+       self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
    ) -> LRScheduler:
        """
        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
@@ -45,6 +45,13 @@ class SchedulerMixin(Trainer):
            and self.args.cosine_min_lr_ratio is not None
        )

+       if optimizer is None:
+           if self.optimizer is None:
+               raise ValueError(
+                   "Optimizer must be set before calling create_scheduler or passed as an argument."
+               )
+           optimizer = self.optimizer
+
        # fmt: off
        if self.lr_scheduler is None:  # type: ignore
            # fmt: on


@@ -1,12 +1,10 @@
"""Module for TRL RL trainers"""

-from trl import (
-    CPOTrainer,
-    KTOTrainer,
-    ORPOTrainer,
-    PRMTrainer,
-    RewardTrainer,
-)
+from trl import RewardTrainer
+from trl.experimental.cpo import CPOTrainer
+from trl.experimental.kto import KTOTrainer
+from trl.experimental.orpo import ORPOTrainer
+from trl.experimental.prm import PRMTrainer

from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin


@@ -8,7 +8,11 @@ from dataclasses import dataclass, field
from typing import Optional, Type

from transformers import TrainingArguments
-from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
+from trl import RewardConfig
+from trl.experimental.cpo import CPOConfig
+from trl.experimental.kto import KTOConfig
+from trl.experimental.orpo import ORPOConfig
+from trl.experimental.prm import PRMConfig

from axolotl.integrations.config import merge_training_args


@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip

```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
```

## Usage
@@ -31,11 +31,13 @@ plugins:
## Supported Models

+- afmoe
- apertus
- arcee
- cohere
- cohere2
- deepseek_v3
+- exaone4
- gemma
- gemma2
- gemma3
@@ -45,13 +47,17 @@ plugins:
- glm - glm
- glm4 - glm4
- glm4_moe - glm4_moe
- glm4_moe_lite
- glm46v
- glm4v - glm4v
- glm4v_moe - glm4v_moe
- glm_image
- glm_moe_dsa
- gpt_oss - gpt_oss
- granite - granite
- granitemoe - granitemoe
- granitemoeshared
- granitemoehybrid - granitemoehybrid
- granitemoeshared
- hunyuan_v1_dense - hunyuan_v1_dense
- hunyuan_v1_moe - hunyuan_v1_moe
- internvl - internvl
@@ -72,20 +78,26 @@ plugins:
- olmo
- olmo2
- olmo3
+- olmoe
- phi
- phi3
- phi4_multimodal
- qwen2
-- qwen2_vl
-- qwen2_moe
- qwen2_5_vl
+- qwen2_moe
+- qwen2_vl
- qwen3
+- qwen3_5
+- qwen3_5_text
+- qwen3_5_moe
+- qwen3_5_moe_text
- qwen3_moe
+- qwen3_next
- qwen3_vl
- qwen3_vl_moe
-- qwen3_next
-- smollm3
- seed_oss
+- smollm3
+- step3p5
- voxtral

## Citation


@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
    "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-   '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
+   '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
)
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
    def patch_llama_like(
        self,
-       model_type: str,
+       model_type_to_patch: str,
    ) -> None:
        """
        Generic patch for model architectures with causal lm similar to llama
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
        from cut_cross_entropy.transformers.patch import PATCH_FNS

        def patch_generic(
-           maybe_model, patch_options, model_type: str, remote_model_id: str | None
+           maybe_model,
+           patch_options,
+           remote_model_id: str | None,
+           model_type: str,
        ):
            import cut_cross_entropy.transformers.llama
            from cut_cross_entropy.transformers.llama import cce_forward
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
                f"Error: {str(e)}"
            ) from e

-       if model_type not in PATCH_FNS:
+       if model_type_to_patch not in PATCH_FNS:
            LOG.warning_once(
-               "Setting up generic cce patch for model type: %s", model_type
+               "Setting up generic cce patch for model type: %s", model_type_to_patch
            )
            LOG.warning_once(
-               f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
+               f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
            )
-           PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
+           PATCH_FNS[model_type_to_patch] = partial(
+               patch_generic, model_type=model_type_to_patch
+           )


@@ -0,0 +1,46 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is faster and more efficient than `grouped_mm`.
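As a toy illustration (shapes and variable names here are made up for this sketch, not taken from the transformers code), the essential difference between the batched and grouped approaches is whether every token is multiplied against every expert, or only against its assigned expert:

```python
import torch

# toy MoE dispatch: E experts, T tokens, hidden dim D, expert output dim H
E, T, D, H = 4, 8, 16, 32
W = torch.randn(E, D, H)            # one weight matrix per expert
x = torch.randn(T, D)
assign = torch.randint(0, E, (T,))  # top-1 expert per token, for simplicity

# batched_mm-style: run every token through every expert, then select
all_out = torch.einsum("td,edh->teh", x, W)   # (T, E, H)
batched = all_out[torch.arange(T), assign]    # (T, H)

# grouped-style: compute each token only against its assigned expert
grouped = torch.stack([x[t] @ W[assign[t]] for t in range(T)])

assert torch.allclose(batched, grouped, atol=1e-4)
```

ScatterMoE goes a step further than this grouped sketch: as in the upstream scattermoe project, tokens are sorted by expert and the gather/scatter is fused into the Triton matmul kernels, avoiding padded intermediate tensors.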
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
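A minimal sketch of the patching mechanism (the class and function names below are stand-ins for illustration, not the actual plugin code): rebinding a module instance's `forward` lets an optimized implementation take over without touching the rest of the model.

```python
import types

import torch
import torch.nn as nn

calls = []

class SparseMoeBlock(nn.Module):
    """Stand-in for a transformers MoE block."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

def scattermoe_forward(self, x):
    # a real patch would dispatch tokens to experts via the ScatterMoE kernels
    calls.append("patched")
    return self.proj(x)

block = SparseMoeBlock()
# bind the replacement as this instance's forward, as a patching plugin might
block.forward = types.MethodType(scattermoe_forward, block)
out = block(torch.randn(2, 4))  # the patched forward now runs
```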
## Limitations
ScatterMoE uses softmax -> top-k routing, so results may differ from the baseline for some model architectures (GPT-OSS, GLM_MOE_DSA).
ScatterMoE does not currently work for GLM4.7 Flash (glm4_moe_lite).
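The routing mismatch can be seen in a small sketch (hypothetical helper names): normalizing over all experts before selecting the top-k generally yields different gate weights than selecting first and normalizing over only the chosen experts, even though the same experts are picked.

```python
import torch

def softmax_then_topk(logits, k):
    # ScatterMoE-style: softmax over all experts, then keep the top-k
    probs = torch.softmax(logits, dim=-1)
    weights, idx = probs.topk(k, dim=-1)
    return weights, idx

def topk_then_softmax(logits, k):
    # some architectures select top-k first and normalize only over those
    vals, idx = logits.topk(k, dim=-1)
    weights = torch.softmax(vals, dim=-1)
    return weights, idx

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
w1, i1 = softmax_then_topk(logits, 2)
w2, i2 = topk_then_softmax(logits, 2)
# same experts are chosen, but the gate weights differ:
# w1 does not sum to 1 (mass went to unselected experts), w2 does
```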
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.


@@ -0,0 +1,7 @@
from .args import KernelsArgs
from .plugin import KernelsPlugin
__all__ = [
"KernelsArgs",
"KernelsPlugin",
]


@@ -0,0 +1,48 @@
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = True
@model_validator(mode="before")
@classmethod
def check_use_kernels(cls, data):
if data.get("use_kernels") is not True:
LOG.warning(
"`use_kernels` must be set to True to use this. Automatically setting it to True."
)
data["use_kernels"] = True
return data
@model_validator(mode="before")
@classmethod
def check_experts_implementation(cls, data):
experts_implementation = data.get("experts_implementation")
if experts_implementation is None:
# transformers may default to batched_mm when unset
data["experts_implementation"] = "eager"
elif experts_implementation != "eager":
LOG.warning(
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
)
data["experts_implementation"] = "eager"
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False
return data


@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]


@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import lora_ops, ops
__all__ = ["ops", "lora_ops"]

File diff suppressed because it is too large


@@ -0,0 +1,645 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import triton
import triton.language as tl
BLOCK_M = 128
ALLOW_TF32 = True
@triton.jit
def _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=True,
):
K_block = tl.arange(0, BLOCK_K)
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = (
W_ptr
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
+ E_idx * stride_we
)
iters = tl.cdiv(K, BLOCK_K)
for K_block_id in range(iters):
if no_k_mask:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
else:
K_mask = (K_block_id * BLOCK_K + K_block) < K
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
return acc
def _scatter2scatter_configs():
return [
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
]
@triton.autotune(
configs=_scatter2scatter_configs(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _scatter2scatter(
X_ptr,
stride_xm: tl.constexpr,
stride_xk: tl.constexpr,
W_ptr,
stride_we,
stride_wk: tl.constexpr,
stride_wn: tl.constexpr,
Y_ptr,
stride_ym: tl.constexpr,
stride_yn: tl.constexpr,
B_ptr,
stride_be: tl.constexpr,
stride_bn: tl.constexpr,
grouped_idx_ptr,
expert_idxs_ptr,
# block_start_idx_ptr,
FAN_OUT: tl.constexpr,
M,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
# OUT_M,
allow_tf32: tl.constexpr,
x_grouped: tl.constexpr,
y_grouped: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
M_boundary_mask = M_block < (FAN_OUT * M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
no_k_mask = K % BLOCK_K == 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
E_first_idx = tl.min(E_idxs)
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
for E_idx in range(E_first_idx, E_last_idx + 1):
E_mask = E_idxs == E_idx
E_M_idx = M_idx
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = E_M_idx // FAN_OUT
acc = _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=allow_tf32,
)
if B_ptr is not None:
B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def scatter2scatter(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
b=None,
x_grouped=False,
y_grouped=False,
out=None,
):
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
# Pre-kernel setup
y_dim = W.size(-1)
L_scattered = sorted_expert_idxs.size(0)
if out is None:
output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
else:
assert out.size(0) == L_scattered and out.size(1) == y_dim
output = out
scatter2scatter_compileable(
output,
W,
X,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
b,
x_grouped,
y_grouped,
)
return output
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
def scatter2scatter_compileable(
output: torch.Tensor,
W: torch.Tensor,
X: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
b: Optional[torch.Tensor],
x_grouped: bool,
y_grouped: bool,
) -> None:
def grid(META):
grid_num = (
triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
* triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid_num
if b is None:
b = None
stride_be = stride_bn = 0
else:
stride_be, stride_bn = b.stride()
_scatter2scatter[grid](
# X_ptr, stride_xm, stride_xk,
X,
X.stride(0),
X.stride(1),
# W_ptr, stride_we, stride_wk, stride_wn,
W,
W.stride(0),
W.stride(1),
W.stride(2),
# Y_ptr, stride_ym, stride_yn,
output,
output.stride(0),
output.stride(1),
# B_ptr, stride_be, stride_bn
b,
stride_be,
stride_bn,
grouped_idx_ptr=sorted_scattered_idxs,
expert_idxs_ptr=sorted_expert_idxs,
# block_start_idx_ptr=padded_block_idxs,
FAN_OUT=k,
M=X.size(0),
K=X.size(1),
N=output.size(1),
E=W.size(0),
BLOCK_M=BLOCK_M,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
x_grouped=x_grouped,
y_grouped=y_grouped,
)
def _config_XtY():
return [
triton.Config(
{"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4
),
]
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
DW = DWt.permute(0, 2, 1)
if has_bias:
Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
else:
Db = None
groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
return DW, Db
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
E: int,
DW: torch.Tensor,
Db: Optional[torch.Tensor],
DY: torch.Tensor,
X: torch.Tensor,
expert_offsets: torch.Tensor,
) -> None:
def grid(META):
grid = (
E * triton.cdiv(META["K"], META["BLOCK_K"]),
triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid
if Db is None:
stride_dbe = 0
stride_dbn = 0
else:
stride_dbe, stride_dbn = Db.stride()
_groupXtY[grid](
# DY_ptr, stride_dym, stride_dyk,
DY,
DY.stride(0),
DY.stride(1),
# X_ptr, stride_xm, stride_xn,
X,
X.stride(0),
X.stride(1),
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
DW,
DW.stride(0),
DW.stride(1),
DW.stride(2),
# Db_ptr, stride_dwe, stride_dbn,
Db,
stride_dbe,
stride_dbn,
# expert_offsets_ptr,
expert_offsets,
# K: tl.constexpr, N: tl.constexpr,
M=DY.size(0),
N=DY.size(-1),
K=X.size(-1),
# ACC_TYPE: tl.constexpr,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
@triton.autotune(
configs=_config_XtY(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _groupXtY(
DY_ptr,
stride_dym,
stride_dyk,
X_ptr,
stride_xm,
stride_xn,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
expert_offsets_ptr,
M,
K: tl.constexpr,
N: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
# pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
K_block_id = pid0 % K_BLOCK_COUNT
N_block_id = pid1
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
if end_idx > start_idx:
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = (
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
)
if (Db_ptr is not None) and (K_block_id == 0):
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=True,
)
else:
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=False,
)
@triton.jit
def _xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias: tl.constexpr,
):
if compute_bias:
db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
else:
db_acc = None
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
for i in range(0, iters):
M_mask = (i * BLOCK_M + M_block) < end_idx
if NO_K_MASK:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
if NO_N_MASK:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
if compute_bias:
db_acc += tl.sum(dy, axis=0)
DW_blk_ptrs = (
DW_ptr
+ E_idx * stride_dwe
+ K_block[:, None] * stride_dwk
+ N_block[None, :] * stride_dwn
)
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
if compute_bias:
Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
def _config_grouping():
return [
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
N = sorted_expert_idxs.size(0)
K = A.size(1)
assert A.size(0) * fan_out == N
if out is not None:
Y = out
else:
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
return Y
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
A: torch.Tensor,
K: int,
N: int,
Y: torch.Tensor,
coeff: Optional[torch.Tensor],
has_coeff: bool,
fan_out: int,
sorted_expert_idxs: torch.Tensor,
) -> None:
def grid(META):
grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),)
return grid_num
_group[grid](
# A_ptr, stride_an, stride_ai,
A,
A.stride(0),
A.stride(1),
has_coeff,
coeff,
fan_out,
# Y_ptr, stride_yn, stride_yk,
Y,
Y.stride(0),
Y.stride(1),
# grouped_idx_ptr,
sorted_expert_idxs,
# N: tl.constexpr, K: tl.constexpr,
N,
K,
)
@triton.autotune(configs=_config_grouping(), key=["K"])
@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0})
@triton.jit
def _group(
src_ptr,
stride_sn,
stride_sk,
has_coeff: tl.constexpr,
coeff_ptr,
FAN_OUT: tl.constexpr,
tgt_ptr,
stride_tn,
stride_ti,
grouped_idx_ptr,
N,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
NO_K_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_block_id = pid
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_blk < N
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = (
src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
)
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
iters = tl.cdiv(K, BLOCK_K)
for i in range(0, iters):
if NO_K_MASK or i < iters - 1:
block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
else:
K_mask = (i * BLOCK_K + K_blk) < K
mask = N_mask[:, None] & K_mask[None, :]
block = tl.load(src_blk_ptrs, mask=mask)
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=mask)
src_blk_ptrs += BLOCK_K * stride_sk
tgt_blk_ptrs += BLOCK_K * stride_ti


@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
import torch
import triton
import triton.language as tl
@triton.jit
def _single2scatter(
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
Y_ptr,
stride_ym,
stride_yn,
expert_idxs_ptr,
FAN_OUT: tl.constexpr,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
N_block_id = pid0
if FAN_OUT == 1:
in_idx = pid1
else:
in_idx = 0
out_idx = pid1
K_block = tl.arange(0, BLOCK_K)
N_block = tl.max_contiguous(
tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
BLOCK_N,
)
E_idx = tl.load(expert_idxs_ptr + pid1)
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
W_blk_ptrs = (
W_ptr
+ E_idx * stride_we
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
)
N_mask = N_block < N
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
K_mask = K_block < K
x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
K_block += BLOCK_K
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])
def single2scatter(X, W, expert_idxs):
E, xdim, ydim = W.size()
k = expert_idxs.size(1)
assert X.size(0) == k or X.size(0) == 1
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = triton.cdiv(ydim, BLOCK_N), k
_single2scatter[grid](
X,
X.stride(0),
X.stride(1),
W,
W.stride(0),
W.stride(1),
W.stride(2),
Y,
Y.stride(0),
Y.stride(1),
expert_idxs,
FAN_OUT=Y.size(0) // X.size(0),
K=xdim,
N=ydim,
E=E,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
ACC_TYPE=tl.float32,
)
return Y

Some files were not shown because too many files have changed in this diff