Compare commits

...

86 Commits

Author SHA1 Message Date
Wing Lian
a4a3b618e7 force torch to match when installing fa and deepspeed using uv 2026-03-04 10:00:08 -05:00
Wing Lian
b6b8db805a fix python version typo for building 3.11 (#3454) 2026-03-04 09:53:35 -05:00
Wing Lian
653f90be25 Add torch 2.10.0 to unit tests and use python 3.14 (#3450)
* Add torch 2.10.0 to unit tests and use python 3.14

* hold on python 3.14 checks due to mistral common

* add base option to matrix
2026-03-03 13:01:52 -05:00
NanoCode012
945c8aeb10 Fix: quantize and target moe layers in transformers v5 for adapters and many misc fixes (#3439)
* fix: saving clones state dict

* fix: apply fix for only CP mode

* fix: add dropout check when using lora target param

* fix: re-add patch from transformers PR #39866

* feat: add moe quant to test by ved

* fix: try match target param properly end with

* fix: clear cache per param quant

* fix: attempt on-load quantize experts instead of post-load

* fix: attempt disable async load

* chore: add log

* chore: adjust log

* fix: remove cuda alloc for moe and enable async load

* chore: remove leftover logs

* chore: add extra empty cache

* fix(doc): clarify support

* fix: handle fsdp2 for paramwrapper dtensor

* feat: attempt to quant experts in 8bit mode too

* feat: attempt to release bf16 experts from vram

* feat: upgrade cce

* fix: fsdp2 init_sharded_param load int8/uint4 dtensor as
require_grad=true on init

* fix: remove unnecessary gc and empty cache

* Revert "fix: remove unnecessary gc and empty cache"

This reverts commit 1d54518990.

* fix: do not call full_tensor on non-dtensors

* fix: attempt to address fsdp2 with quant exp high loss

* fix: attempt lora quant experts wrong dim

* fix: ensure require_grad patch applied for lora 8bit

* fix: attempt lora 8bit fsdp2

* fix: attribute access on save for lora 8bit fsdp2

* fix: wrong weight attrib access

* chore(refactor): add config, re-arrange position of patches, clean
comments

* feat: add example docs

* chore: cherry pick trinity fixes from PR 3399

* chore: comments refactor; add guards

* fix: guard using wrong key

* fix: mamba save does not accept main process param

* fix: guard prevent double hook

* fix: move gc to upper scope

* chore: add comment on proxy forward patch

* fix: add comment to clarify

* feat: add test idempotency

* fix: AttributeError: `e_score_correction_bias` is not an nn.Parameter

* fix: AttributeError: 'NoneType' object has no attribute 'to'

* fix: update docs on cpu_ram_efficient_loading
2026-03-03 10:06:23 -05:00
NanoCode012
e672d37f33 fix: qwen3-next to use fla causal-conv1d to support packing (#3437
* fix: qwen3-next to use fla causal-conv1d to support packing

* fix: causal import and update doc for v5

* fix: hard fail for packing without fla
2026-03-03 09:26:46 -05:00
Wing Lian
77828d3559 uv cloud image should use uv w pip (#3449) 2026-03-02 16:39:26 -05:00
Wing Lian
4272817109 don't install torch ao on arm64 (#3448) 2026-03-02 14:24:54 -05:00
Manas Vardhan
474208b794 fix: Save de-duplicated dataset during pre-processing (#3427)
* fix: run deduplication before saving dataset during preprocessing

Move deduplicate_and_log_datasets call before save_preprocessed_dataset
in both SFT and RL data loading pipelines. This ensures the saved
preprocessed dataset is already de-duplicated, so subsequent loads
from cache don't contain duplicates.

Fixes #2719

* fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup

- Add dataset_exact_deduplication to the hash string in
  generate_dataset_hash_from_config so cached datasets are invalidated
  when the dedup setting changes.
- Log a warning when skip_prepare_dataset=True and
  dataset_exact_deduplication=True, since dedup will be silently
  skipped in that configuration (both SFT and RL paths).

* fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting

- Add config validator (check_deduplication_with_skip_prepare) that raises
  ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True
- Replace runtime warnings in sft.py/rl.py with the validator check
- Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of
  axolotl.loaders.load_tokenizer to properly mock the imported reference
- Fix ruff lint (remove unused imports) and formatting issues

* refactor: inline deduplicate function per review feedback

* fix test fixture, lint

---------

Co-authored-by: ManasVardhan <manasvardhan@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-03-02 12:55:59 -05:00
Wing Lian
444020b332 mark slow tests that are timing out in CI (#3428) [skip ci] 2026-03-02 12:26:30 -05:00
Wing Lian
aa88c2e30b fix uv cache subcommand (#3447) 2026-03-02 12:26:08 -05:00
NanoCode012
f447bce1db fix: do not push telemetry on non-master rank (#3438) 2026-03-02 15:31:20 +07:00
kallewoof
7f23b302d1 bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler() (#3435) [skip ci]
* bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler()

* nit: raise if self.optimizer is also unset

* optimizer properly optional in create_scheduler()
2026-03-02 15:30:07 +07:00
Wing Lian
18f26c19ef add uv axolotl builds (#3431) 2026-02-25 14:46:02 -05:00
Robert Ronan
2b6f4a6c9b Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying

* Update the handle_long_sequences function to correctly use Map instead of filter for the truncation strategy. Also remove the minimal length filtering from the truncate_long_samples function, and run it separately and before.

* fix: refactor and add test truncate for non-input id fields

* fix: refactor long seq handling fn

* fix: refactor duplicate fn and simplify route

* add additional tests and make them work on mac

* handle logging exception on empty datasets

---------

Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-02-25 11:31:11 +07:00
madScientist10
8f54b4eb25 fix: pass revision parameter to tokenizer and processor loaders (#3388) [skip ci]
* fix: pass revision parameter to tokenizer and processor loaders

* fix: address revision=None passed to .from_pretrained

* add tests and address review feedback for revision parameter

- Reformat modify_tokenizer_files signature and from_pretrained call
- Use kwargs pattern for modify_tokenizer_files call to avoid passing None revision
- Add 6 unit tests for revision parameter in tokenizer/processor loaders

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-02-25 11:11:20 +07:00
VED
a131e4d0e5 sample gen support sft (#3240) [skip ci]
* add:parameters + callback

* sft core + logging

* indentation fix

* logger fix

* loger fix in sft

* gen sample on eval

* lint

* deprecation
2026-02-25 11:10:57 +07:00
Wing Lian
1791d87b6f build axolotl images with torch 2.10.0 (#3430) 2026-02-24 22:35:25 -05:00
Wing Lian
b40803da51 build base images for torch 2.10.0 (#3429) 2026-02-24 20:32:34 -05:00
Wing Lian
68f1b7004c ScatterMoE LoRA support (#3410)
* scattermoe lora support

* fsdp, bf16, dim fixes

* expert weights aren't needed in save for bwd since they are frozen

* use sonicmoe optim options

* update save model from upstream

* fixes per code review feedback and add tests

* revert removal of CP fix

* misc fixes
2026-02-24 14:59:55 -05:00
NanoCode012
08441fed17 fix: set allowed values for adapter config (#3415) 2026-02-23 11:39:53 -05:00
NanoCode012
86ca1e27c0 fix: update MistralProcessor to be v5 compat (#3423)
* fix: update MistralProcessor to be v5 compat

* feat: add test for mistral3 processor

* chore: comment
2026-02-23 11:39:13 -05:00
Manas Vardhan
5ed455715e feat: support dot-notation CLI args for nested config options (#3419)
* feat: support dot-notation CLI args for nested config options

Add support for overriding nested config fields (like TRL config) via
CLI using dot-notation, e.g.:
  axolotl train grpo.yaml --trl.vllm-server-host=10.0.0.1 --trl.beta=0.1

Changes:
- args.py: Detect BaseModel subclass fields and generate dot-notation
  CLI options (--parent.child) that map to double-underscore kwargs
  (parent__child). Also fix _strip_optional_type for Python 3.10+
  union syntax (X | None).
- config.py: Handle double-underscore kwargs in load_cfg by setting
  nested dict values on the config.
- Add tests for nested option handling.

Fixes #2702

* Address CodeRabbit review: fix string parent bug, add type hints and docstring

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>

* Add type coercion for CLI kwargs and fix pre-commit issues

- Add _coerce_value() for YAML-style type inference on string CLI args
- When existing config value has a type (int/float/bool), cast to match
- When no existing value, infer type from string (true/false, ints, floats, null)
- Apply coercion to both flat and nested (dot-notation) kwargs
- Fix unused pytest import (pre-commit/ruff)
- Update tests to pass string values (matching real CLI behavior)
- Add dedicated TestCoerceValue test class

Addresses maintainer feedback on type casting for nested kwargs.

---------

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
2026-02-23 10:10:06 -05:00
Lorenzo Baraldi
3f30572d4a Fix typo in dataset_processes field (#3426)
* Fix typo in dataset_processes field

* fix: use updated config name

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-02-23 14:18:37 +07:00
NanoCode012
43d60c7439 bump cut-cross-entropy to 58d6572 (#3424) 2026-02-20 14:24:51 -05:00
Wing Lian
0ea252d392 update to trackio 0.16.1 (#3425) [skip ci] 2026-02-20 14:24:33 -05:00
Wing Lian
29722dec60 use bunnycdn for CI assets (#3422) [skip ci] 2026-02-20 00:09:25 -05:00
NanoCode012
7fbedbd300 fix(doc): add limitation for unfrozen_parameters (#3416) 2026-02-19 18:32:26 -05:00
Wing Lian
145ffc9be1 upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0

* upgrade trl for parity

* handle trl api changes

* orpo doesn't have max_prompt_len to check anymore

* cpoconfig doesn't take max_prompt_length and fix cpu offload

* slow fsdp1 test

* triton min 3.4.0 and liger to 0.7.0

* use transformers main for now for zero3 fix

* handle group_by_length change

* fix changes upstream

* mark skip flaky test

* use transformers latest release 5.2.0
2026-02-19 18:27:27 -05:00
NanoCode012
4f1b5ad29f fix: clarify how to use lm_eval plugin (#3404) [skip ci] 2026-02-15 07:52:30 -05:00
NanoCode012
d6a2532dd7 feat(doc): clarify how to use scattermoe (#3408) [skip ci]
* feat(doc): clarify how to use scattermoe

* chore: fix wording
2026-02-15 07:51:28 -05:00
Wing Lian
5eb265513c fix generic patch for cce (#3405) 2026-02-12 08:58:04 -05:00
NanoCode012
06ac407b92 feat: improve telemetry log (#3398)
* fix: redact trackio and data_files

* fix: add new orgs to whitelist

* feat: add run id to logs for users to easily share

* fix: update to add more metrics

* fix: add missed experiment tracker

* chore: formatting in main
2026-02-10 23:01:34 +07:00
NanoCode012
4e22cf0651 fix: remove telemetry warning (#3397) [skip ci] 2026-02-10 23:01:16 +07:00
VED
a4ee56c315 fix: set rollout in GRPO training_kwargs (#3392) 2026-02-10 18:06:15 +07:00
NanoCode012
c67cbcb0f5 fix: ignore add_special_tokens and use test mode for generation for mistral tokenizer (#3396) [skip ci]
* fix: ignore add_special_tokens and use test mode for generation

* fix: incorrectly setting kwarg
2026-02-10 18:03:26 +07:00
NanoCode012
a2da852576 fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code

* chore: re-order model guides
2026-02-10 17:58:40 +07:00
madScientist10
37e9da7a53 add hub_revision support for specifying branch when pushing checkpoints (#3387) [skip ci] 2026-02-10 17:53:09 +07:00
NanoCode012
ed7105dba7 fix: GRPO config not accept max_prompt_length (#3390) [skip ci] 2026-02-10 17:52:09 +07:00
NanoCode012
b6d3653f74 feat: add step3p5 for cce (#3384) [skip ci]
* feat: add step3p5 for cce

* chore: reorder model
2026-02-10 17:51:43 +07:00
NanoCode012
fcc4cfdb63 feat: add sageattention (#2823) [skip ci]
* feat: add sageattention

* feat: call path on pre model load

* fix: patch to use register to correct var

* fix: add strict check import at start

* chore: fix comments

* chore: refactor

* feat: add capability check

* fix: missed underscore

* fix: let sageattention use FA backend in transformers

* feat: update sage attention for attention mask and position ids

* feat: allow sample packing but add warning without packing

* fix: loss hitting 0 with packing and attention mask note

* feat: downcast embeds if sage attention too

* feat: add config validation

* feat: add attention docs

* chore: docs
2026-02-10 17:49:21 +07:00
VED
97a4f28511 fix: saving state dict and eval for Context Parallel (#3382) [skip ci]
* clone state_dict if none

* patch calculating  eval loss for cp
2026-02-10 17:47:26 +07:00
VED
86a5803212 train_per_sec_per_gpu metric (#3364) [skip ci]
* fix token count

* guard for none n zero
2026-02-10 17:44:55 +07:00
tgoab
530a0c0bf0 Changes from dataset_processes to dataset_num_proc (#3352) [skip ci]
* changes from dataset_processes to dataset_num_proc

* deprecation message improved

---------

Co-authored-by: Juliana Nieto Cárdenas <jnietoca@purdue.edu>
2026-02-10 17:44:17 +07:00
VED
0343a72cc9 add glm support + patch (#3329) [skip ci]
* add glm support + patch

* lint

* lint

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/processing_strategies.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patch removed

* lint

* lint2

* docs + rename

* rmv moe

* docs

* removed processor

* sdpa T_T"

* ddp_find_unused_parameters: true

* muti gpu yaml tested both

* muti gpu yaml tested both

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* rmv text only section + v5 comments

* rename

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-02-10 17:43:53 +07:00
Wing Lian
236dad3bb7 set 0.15.0.dev0 version (#3380) 2026-01-30 21:28:01 -05:00
Wing Lian
be00978bc2 tag for v0.14.0 release (#3379)
Some checks failed
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 129, 12.9.1, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 129, 12.9.1, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2026-01-30 14:10:27 -05:00
Wing Lian
3738978394 Add support for batched_mm, grouped_mm and scattermoe for MoE models (#3377)
* kernels plugin for moe for v5

* add support for native batched_mm or grouped_mm
2026-01-29 14:25:47 -05:00
Wing Lian
6132a30cda handle warnings from v5 upgrade (#3376) 2026-01-28 06:45:01 -05:00
NanoCode012
3dd86d35b8 feat: add new cce support for glm series and exaone4 (#3373) [skip ci] 2026-01-28 06:44:44 -05:00
salman
dd9ebaeba1 EAFT (#3366) [skip ci]
* wip eaft

* fix eaft loss fn

* adding ref

---------

Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
2026-01-28 06:44:15 -05:00
Wing Lian
fc4e37920b transformers v5 upgrade (#3272)
* Prepare for transformers v5 upgrade

* fix hf cli

* update for hf hub changes

* fix tokenizer apply_chat_template args

* remap include_tokens_per_second

* fix tps

* handle migration for warmup

* use latest hf hub

* Fix scan -> ls

* fix import

* fix for renaming of mistral common tokenizer -> backend

* update for fixed tokenziation for llama

* Skip phi35 tests for now

* remove mistral patch fixed upstream in huggingface/transformers#41439

* use namespacing for patch

* don't rely on sdist for e2e tests for now

* run modal ci without waiting too

* Fix dep for ci

* fix imports

* Fix fp8 check

* fsdp2 fixes

* fix version handling

* update fsdp version tests for new v5 behavior

* Fail multigpu tests after 3 failures

* skip known v5 broken tests for now and cleanup

* bump deps

* unmark skipped test

* re-enable test_fsdp_qlora_prequant_packed test

* increase multigpu ci timeout

* skip broken gemma3 test

* reduce timout back to original 120min now that the hanging test is skipped

* fix for un-necessary collator for pretraining with bsz=1

* fix: safe_serialization deprecated in transformers v5 rc01 (#3318)

* torch_dtype deprecated

* load model in float32 for consistency with tests

* revert some test fixtures back

* use hf cache ls instead of scan

* don't strip fsdp_version

more fdsp_Version fixes for v5
fix version in fsdp_config
fix aliasing
fix fsdp_version check
check fsdp_version is 2 in both places

* Transformers v5 rc2 (#3347)

* bump dep

* use latest fbgemm, grab model config as part of fixture, un-skip test

* import AutoConfig

* don't need more problematic autoconfig when specifying config.json manually

* add fixtures for argilla ultrafeedback datasets

* download phi4-reasoning

* fix arg

* update tests for phi fast tokenizer changes

* use explicit model types for gemma3

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>

* fix: AutoModelForVision2Seq -> AutoModelForImageTextToText

* chore: remove duplicate

* fix: attempt fix gemma3 text mode

* chore: lint

* ga release of v5

* need property setter for name_or_path for mistral tokenizer

* vllm not compatible with transformers v5

* setter for chat_template w mistral too

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2026-01-27 17:08:24 -05:00
Wing Lian
a531e9d946 upgrade vllm to v0.14.0 (#3345) 2026-01-21 20:00:18 -05:00
Wing Lian
04328aeb97 cu129 targets for ci builds (#3369)
* cu129 targets for ci builds

* remove copy-paste is_latest
2026-01-21 17:24:44 -05:00
VED
d0d26d5064 feat: Add GDPO Support (#3353)
* gdpo support - test left

* lint

* fixxes for vllm serv

* test advantages

* docss

* lint

* lint =

* gdpo simple + lint

* lint nit

* example

* lint

* trl 0.27.0

* blocklist

* test assert rmv

* add validation check for GDPO + sum_then_normalize

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-21 17:22:45 -05:00
Wing Lian
8623dd8a72 strip only starting 'v' char; e.g don't strip from '.dev' (#3368) [skip ci] 2026-01-21 14:19:03 -05:00
Wing Lian
8cd75cff9f use cuda 12.9.1 and add python 3.12 to base images (#3367) 2026-01-21 13:34:14 -05:00
Wing Lian
8ab9d9ea88 Version dev (#3365) 2026-01-20 22:58:29 -05:00
Wing Lian
6e42def14b set version to v0.13.1 (#3363)
Some checks failed
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2026-01-20 08:58:32 -05:00
Wing Lian
c413480b35 upgrade transformers to 4.57.6 and peft to 0.17.1 and datasets to 4.5.0 (#3361) 2026-01-16 11:48:50 -05:00
Wing Lian
8f25124269 upgrade transformers to 4.57.5 (#3358)
* upgrade transformers to 4.57.5

* explicitly set versions for fbgemm-gpu

* handle index url for cuda version

* explicitly set cu version for fbgemm deps, skip for 130

* cu suffix not needed on version if using whl subpath
2026-01-16 11:17:43 -05:00
Wing Lian
790df757cb don't install xformers in for arm64 (#3359)
* install xformers in the base docker image

* install numba and numpy first

* set CUDA_HOME for xformers install

* Set cuda  home env

* don't install xformers by default on aarch64/arm64
2026-01-16 09:02:37 -05:00
Wing Lian
d282f32481 don't install deepspeed in arm64 images (#3357) 2026-01-14 12:03:55 -05:00
Wing Lian
6331e4a130 fix amd64 and set 2.9.1 as latest cloud image (#3356) 2026-01-14 11:56:36 -05:00
salman
1410e4474e update PR template (#3349) [skip ci] 2026-01-14 09:39:21 -05:00
Wing Lian
dc77b5bf42 fix arm64 builds (#3355)
* fix syntax  for secrets in gha yaml

* setup env for uv too

* arm64 for base  uv too

* don't build causal-conv1d or mamba for arm64 and use arm64 wheels

* fix dockerfile syntax

* fix shell syntax
2026-01-14 09:38:48 -05:00
NanoCode012
359b7ad85e fix: gemma3_text model loading vision config (#3354)
* fix: gemma3-text mode loading vision config

* fix: improve defaults to use lora kernels
2026-01-13 09:49:23 -05:00
VED
258ce8d4fa feat : scaled softmax support (#3338)
* scaled softmax

* comment

* lint

* remove egear

* validation for flash

* lint

* val imporve + neet

* fix correct softmax scale val(learned)

* learned scale val 4 ssm

* lint

* fix model_type rmv

* sdpa_atten

* test fix + lint

* test fix

* sdp_a val rmv

* flex fix

* main flash

* lint

* flex attn

* lint comment

* fix score_mod

* Update src/axolotl/utils/schemas/validation.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-01-13 14:33:11 +07:00
@TT
3e0bbd33ec feat: add ARM64/AArch64 build support to Dockerfile-base (#3346)
* Add support for capability to build arm64 image

* Fixing wrong variable TARGETPLATFORM bug

* Adding missing semicolons

* skip docker hub login if PR (no push) or no credentials

* Enabling arm64 builds for Dockerfile-base in Github actions

* TARGETARCH automatically default to platform arch under build

* Enabling arm64 builds for axolotl docker builds

* Enabling arm64 builds for axolotl-cloud docker build Github actions

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-12 12:00:02 -05:00
salman
4ae6f766ad bump bnb to v0.49.1 (#3351) 2026-01-12 09:42:04 -05:00
VED
e7f0d4ba5b Increased test coverage for lora/qlora (#3147)
* config_val tests

* remove config val(not needed)

* config validation

* parameter freeze validation

* merge/unmerge tests

* removal unwanted

* rename

* lint

* updated lint

* Update tests/utils/lora/test_config_validation_lora.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* pytest skip + mock fix

* nitpicks

* revert some nitpicks

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-01-06 11:44:48 -05:00
VED
7bf6f70e96 fix total/trainable tokens log (#3344)
* fix total/trainable tokens log

* fix total/trainable tokens log
2026-01-06 09:25:17 -05:00
PraMamba
8aab807e67 feat: Add SwanLab integration for experiment tracking (#3334)
* feat(swanlab): add SwanLab integration for experiment tracking

SwanLab integration provides comprehensive experiment tracking and monitoring for Axolotl training.

Features:
- Hyperparameter logging
- Training metrics tracking
- RLHF completion logging
- Performance profiling
- Configuration validation and conflict detection

Includes:
- Plugin in src/axolotl/integrations/swanlab/
- Callback in src/axolotl/utils/callbacks/swanlab.py
- Tests in tests/integrations/test_swanlab.py
- Examples in examples/swanlab/

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(swanlab): address PR #3334 review feedback from winglian and CodeRabbit

- Change use_swanlab default to True (winglian)
- Clear buffer after periodic logging to prevent duplicates (CodeRabbit Major)
- Add safe exception handling in config fallback (CodeRabbit)
- Use context managers for file operations (CodeRabbit)
- Replace LOG.error with LOG.exception for better debugging (CodeRabbit)
- Sort __all__ alphabetically (CodeRabbit)
- Add language specifiers to README code blocks (CodeRabbit)
- Fix end-of-file newline in README (pre-commit)

Resolves actionable comments and nitpicks from CodeRabbit review.
Addresses reviewer feedback from @winglian.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* only run swanlab integration tests if package is available

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-06 09:19:18 -05:00
Wing Lian
ee59e4de97 add cu130 + torch 2.9.1 to test matrices (#3343)
* add cu130 + torch 2.9.1 to test matrices

* uv can't use pip3 directly
2026-01-05 15:24:29 -05:00
Wing Lian
4e61b8aa23 use updated version of prebuilt wheels for flash attention for cu130 (#3342)
* use updated version of prebuilt wheels for flash attention for cu130

* use elif

* fix the uv base installs of FA also

* make wget less verbose
2026-01-05 13:48:12 -05:00
Wing Lian
b26ba3a5cb don't build images w cuda 130 since we don't have flash attention wheels (#3341) 2026-01-03 18:08:28 -05:00
Wing Lian
afe18ace35 deprecate torch 2.7.1 (#3339) 2026-01-01 06:52:45 -05:00
github-actions[bot]
2b199f9915 chore: update pre-commit hooks (#3340) [skip ci]
Co-authored-by: SalmanMohammadi <25081738+SalmanMohammadi@users.noreply.github.com>
2026-01-01 06:52:28 -05:00
Wing Lian
e73dab6df9 support pydantic 2.12 (#3328)
* upgrade pydantic to 2.12

* use latest modal version

* upgrade modal

* update modal in requirements and loosen pydantic

* upgrade modal too
2025-12-30 12:41:07 -05:00
VED
f45a97a9ff docs for checkpiont saving (#3335) [skip ci]
Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-30 12:40:32 -05:00
Wing Lian
11c0b5b256 bartch upgrade dependencies (#3299)
* upgrade dependencies

* don't use reset sessions

* downgrade transformers, upgrade other deps

* upgrade bnb to 0.49.0

* restore s3 cache

* explicit use local files w hub

* decompress and strip top level dir

* use 2 levels for strip components

* try to preserve permissions for symlinks

* use updated tar

* fix #3293 for distributed

* downgrade bnb

* fast fail after 4

* fix total tokens device

* patch accelerate CP/SP (#3309)

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-12-30 09:02:49 -05:00
Wing Lian
66a3de3629 build examples readmes with quarto (#3046)
* build examples readmes with quarto

* chore: formatting

* feat: dynamic build docs

* feat: add more model guides

* chore: format

* fix: collapse sidebar completely to have space for model guides

* fix: security protection for generated qmd

* fix: adjust collapse level, add new models, update links

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-25 19:17:25 +07:00
VED
a6080df73c compute loss only if training and update token metric naming (#3293) [skip ci]
* compute loss only if training

* save total_tokens for checkpiont

* check if string

* refactor total_tokens/ num_tokens

* refactor 2

* rplc trainable_step/trian_per_sec_per_gpu

* lint + log trainable/tokens

* consolidate it in the callback.

* test for total_tokes aftr remuse

* check if tokenstate exist after ckpt

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-25 18:38:17 +07:00
NanoCode012
4f5e8a328a Feat: add MiMo and Plano (#3332) [skip-ci]
* feat: add xiaomi's mimo 7b

* fix: pin revision

* fix: update trinity docs and pin revision

* fix: wrong config name

* feat: add vram usage

* feat: add plano

* feat: update plano vram usage

* chore: comments
2025-12-25 18:09:03 +07:00
NanoCode012
418933f0d1 feat: add internvl3_5 (#3141) [skip-ci]
* feat: add internvl3_5

* fix: add timm instructions

* chore: add kimi-linear to cce doc

* feat: update internvl example

* chore: pin revision

* chore: remove from multipack

* fix: add to multimodal array

* fix: internvl use hf version

* feat: update cce

* chore: lint

* fix: list for image_size

* chore: add docs vram usage

* feat: enable cce

* fix: no need trust remote code

* fix: inconsistent timm version
2025-12-25 18:07:59 +07:00
NanoCode012
372f664c63 feat: cleanup old flex mask patch, suppress Matmul bnb warn, and misc (#3330) [skip-ci]
* feat: add pos id to flex attention for packing part 1

* feat: update to include sliding window mask patch

* fix: suppress MatMul8bitLt: inputs will be cast from warnings

* fix: remove redundant flex attention patch

* chore: update olmo docs

* feat: add validator patch for cross entropy
2025-12-25 17:56:20 +07:00
NanoCode012
97f1b1758d Feat: add kimi linear support (#3257)
* feat: add custom kimi linear patch [skip ci]

* feat: add configuration file and fix import [skip ci]

* fix: hijack tokenizer temporarily [skip ci]

* chore: remove accidental commit

* fix: attempt patch kimi remote

* fix: kwargs passsed

* fix: device for tensor

* fix: aux loss calculation

* feat: cleaned up patches order

* fix: remove duplicate tokenizer patch

* chore: add debug logs

* chore: add debug logs

* chore: debug

* Revert "chore: add debug logs"

This reverts commit da372a5f67.

* Revert "chore: add debug logs"

This reverts commit 97d1de1d7c.

* fix: KeyError: 'tokenization_kimi'

* fix: support remote_model_id in cce patch

* feat: add config preload patch

* fix: use standard aux loss calc and updated modeling

* fix: import

* feat: add kimi-linear docs and example

* chore: add note about moe kernels

* feat: update cce to include kimi-linear

* chore: lint

* chore: update main readme

* fix: patch mechanism to address comments

* chore: lint

* fix: tests

* chore: cleanup comment
2025-12-25 17:53:52 +07:00
256 changed files with 19745 additions and 1109 deletions

View File

@@ -15,6 +15,11 @@
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate)
## Types of changes

View File

@@ -21,31 +21,12 @@ jobs:
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -53,6 +34,15 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -60,6 +50,31 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -67,6 +82,23 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -93,6 +125,7 @@ jobs:
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -103,6 +136,7 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -117,24 +151,12 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -142,6 +164,7 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -149,6 +172,39 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -156,6 +212,23 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -167,6 +240,7 @@ jobs:
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -177,6 +251,7 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -15,37 +15,49 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
platforms: "linux/amd64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -71,6 +83,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
@@ -85,6 +98,77 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -92,43 +176,49 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
platforms: "linux/amd64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -153,6 +243,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
@@ -163,6 +254,73 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -170,22 +328,16 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.9.1
axolotl_extras:
is_latest: true
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
@@ -212,6 +364,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}

View File

@@ -19,6 +19,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -26,27 +29,32 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras: fbgemm-gpu
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
# axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -59,7 +67,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==1.3.0.post1 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -68,8 +76,8 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu
modal run -m cicd.multigpu

View File

@@ -12,16 +12,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -64,16 +64,16 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging==23.2
pip3 install wheel packaging==26.0
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
@@ -48,9 +48,9 @@ jobs:
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in setup.py
- name: Update version in VERSION file
run: |
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
- name: Build a source dist
run: |

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
timeout-minutes: 20
steps:
@@ -37,7 +37,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -48,7 +48,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -99,17 +99,17 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
@@ -123,7 +123,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==1.3.0.post1 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -148,10 +148,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.9.1
num_gpus: 2
axolotl_extras:
nightly_build: "true"
@@ -165,7 +165,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==1.3.0.post1 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -54,8 +54,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -66,12 +71,13 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -81,7 +87,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -109,7 +115,10 @@ jobs:
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
run: hf cache ls
- name: Run tests
run: |
@@ -122,6 +131,9 @@ jobs:
df -h
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
run: hf cache ls
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
@@ -137,8 +149,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -149,12 +166,13 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -164,7 +182,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install PyTorch
run: |
@@ -192,7 +210,7 @@ jobs:
axolotl --help
- name: Show HF cache
run: hf cache scan
run: hf cache ls
- name: Run tests
run: |
@@ -200,8 +218,11 @@ jobs:
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: Show HF cache
run: hf cache ls
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
needs: [pre-commit]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
@@ -237,16 +258,16 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -260,7 +281,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==1.3.0.post1 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -292,18 +313,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
# - cuda: 128
# cuda_version: 12.8.1
# python_version: "3.11"
# pytorch: 2.7.1
# num_gpus: 1
# axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -314,7 +323,19 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
steps:
@@ -327,7 +348,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==1.3.0.post1 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -354,10 +375,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 129
cuda_version: 12.9.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
steps:
@@ -370,7 +391,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==1.3.0.post1 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.7
rev: v0.14.10
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.0
rev: v1.19.1
hooks:
- id: mypy
additional_dependencies:

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |

View File

@@ -39,7 +39,6 @@
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
@@ -107,7 +106,7 @@
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set
# dataset_num_proc: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
@@ -224,9 +223,6 @@
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
@@ -352,8 +348,6 @@
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
@@ -412,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES}
dataset_num_proc: ${DATASET_NUM_PROC}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
@@ -512,7 +506,6 @@ profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}

View File

@@ -29,15 +29,15 @@
## 🎉 Latest Updates
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
@@ -46,8 +46,8 @@
<summary>Expand older updates</summary>
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
@@ -77,7 +77,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.7.1
- PyTorch ≥2.8.0
### Google Colab
@@ -88,7 +88,7 @@ Features:
#### Using pip
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.15.0.dev0

View File

@@ -1,6 +1,8 @@
project:
type: website
pre-render: docs/scripts/generate_config_docs.py
pre-render:
- docs/scripts/generate_config_docs.py
- docs/scripts/generate_examples_docs.py
quartodoc:
dir: docs/api
@@ -240,6 +242,46 @@ website:
- docs/getting-started.qmd
- docs/installation.qmd
- docs/inference.qmd
- section: "Model Guides"
contents:
- docs/models/kimi-linear.qmd
- docs/models/plano.qmd
- docs/models/mimo.qmd
- docs/models/internvl3_5.qmd
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
- docs/models/ministral3/think.qmd
- docs/models/ministral3/vision.qmd
- section: "Magistral"
contents:
- docs/models/magistral.qmd
- docs/models/magistral/think.qmd
- docs/models/magistral/vision.qmd
- docs/models/ministral.qmd
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
- docs/models/qwen3.qmd
- docs/models/gemma3n.qmd
- docs/models/apertus.qmd
- docs/models/gpt-oss.qmd
- docs/models/seed-oss.qmd
- docs/models/phi.qmd
- docs/models/smolvlm2.qmd
- docs/models/granite4.qmd
- docs/models/LiquidAI.qmd
- docs/models/hunyuan.qmd
- docs/models/jamba.qmd
- docs/models/orpheus.qmd
- docs/cli.qmd
- docs/telemetry.qmd
- docs/config-reference.qmd
@@ -278,6 +320,7 @@ website:
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features"
contents:

View File

@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -17,7 +17,8 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
@@ -27,8 +28,11 @@ df_args = {
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
}
dockerfile_contents = df_template.render(**df_args)

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 \
pytest -v --durations=10 -n2 --maxfail=3 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -6,6 +6,7 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -20,13 +21,17 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
python scripts/unsloth_install.py | sh && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -2,14 +2,16 @@ ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10"
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="118"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -22,11 +24,17 @@ RUN apt-get update \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& if [ "$TARGETARCH" = "amd64" ]; then \
MINICONDA_ARCH="x86_64"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
MINICONDA_ARCH="aarch64"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi \
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
@@ -35,7 +43,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip cache purge
@@ -51,8 +59,18 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \

View File

@@ -0,0 +1,30 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

47
docker/Dockerfile-uv Normal file
View File

@@ -0,0 +1,47 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -2,9 +2,11 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
@@ -31,12 +33,25 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

2
docs/.gitignore vendored
View File

@@ -3,3 +3,5 @@ _site/
/api/*.qmd
/api/*.html
config-reference.qmd
models/**/*.qmd
models/**/*.html

View File

@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
Download a base model using the Hugging Face CLI:
```bash
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration

140
docs/attention.qmd Normal file
View File

@@ -0,0 +1,140 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -0,0 +1,86 @@
---
title: "Checkpoint Saving"
format:
html:
toc: true
toc-depth: 2
number-sections: true
execute:
enabled: false
---
## Overview
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
## File-Based Checkpoint Trigger
### Configuration
Enable in your config:
```yaml
dynamic_checkpoint:
enabled: true
check_interval: 100 # Optional: check every N steps (default: 100)
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
```
**Options:**
- `enabled`: `true` to enable (required)
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
### How It Works
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
2. When detected, file is deleted and checkpoint is saved
3. In distributed training, rank 0 broadcasts to synchronize all ranks
### Usage
**Command line:**
```bash
touch /path/to/output_dir/axolotl_checkpoint.save
```
**Programmatic:**
```python
from pathlib import Path
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
```
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
**Custom filename:**
```yaml
dynamic_checkpoint:
enabled: true
trigger_file_path: "my_trigger.save"
```
```bash
touch /path/to/output_dir/my_trigger.save
```
## Control+C (SIGINT) Checkpoint
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
## Best Practices
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
## Example
```yaml
output_dir: ./outputs/lora-out
save_steps: 500 # Scheduled checkpoints
dynamic_checkpoint:
enabled: true
check_interval: 50
```
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).

View File

@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
### delinearize-llama4

View File

@@ -32,11 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu128-2.8.0`
- `main-base-py3.11-cu128-2.9.1`
## Main
@@ -74,15 +71,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu128-2.7.1`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu128-2.8.0`
- `main-py3.11-cu128-2.9.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu126-2.6.0`
- `0.10.1`
- `0.12.0`
## Cloud

View File

@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
:::
::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
:::
### PyPI Installation (Recommended) {#sec-pypi}
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
```
4. (Optional) Login to Hugging Face:
```{.bash}
huggingface-cli login
hf auth login
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -89,6 +89,10 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -19,8 +19,10 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
## Usage
@@ -182,6 +184,18 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
@@ -202,6 +216,16 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
base_model: LiquidAI/LFM2-VL-450M
```
### Intern-VL {#sec-intern-vl}
::: {.callout-tip}
Please make sure to install `timm` via `pip3 install timm==1.0.19`
:::
```yaml
base_model: OpenGVLab/InternVL3_5-8B
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.

View File

@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
## RLHF using Axolotl
@@ -720,6 +721,102 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
::: {.callout-tip}
Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results.
:::
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
rl: gdpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: true
num_generations: 4
reward_funcs:
- rewards.format_reward
- rewards.correctness_reward
reward_weights: [1.0, 2.0]
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
multi_objective_aggregation: normalize_then_sum # GDPO behavior
# or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO
| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```
GDPO normalizes each reward independently, preserving their relative differences.
#### Reward Functions
GDPO uses the same reward function format as GRPO:
```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
return [1.0 if len(c) > 10 else 0.0 for c in completions]
def correctness_reward(completions, answers, **kwargs) -> list[float]:
rewards = []
for completion, answer in zip(completions, answers):
# Your scoring logic here
rewards.append(score)
return rewards
```
#### Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.

View File

@@ -0,0 +1,90 @@
examples:
# December 2025
- name: kimi-linear
title: Kimi Linear
- name: plano
title: Plano Orchestrator
- name: mimo
title: MiMo
- name: internvl3_5
title: InternVL 3.5
# AllenAI
- name: olmo3
title: OLMo 3
# ArceeAI
- name: trinity
title: Trinity
- name: arcee
title: Arcee AFM
# MistralAI
- name: ministral3/think
title: Ministral 3 Thinking
- name: ministral3/vision
title: Ministral 3 Vision
- name: magistral/think
title: Magistral Thinking
- name: magistral/vision
title: Magistral Vision
- name: ministral
title: Ministral
- name: mistral-small
title: Mistral Small 3.1/3.2
- name: voxtral
title: Voxtral
- name: devstral
title: Devstral
- name: mistral
title: Mistral 7B
# Meta
- name: llama-4
title: Llama 4
- name: llama-2
title: Llama 2
# Alibaba
- name: qwen3-next
title: Qwen 3 Next
- name: qwen3
title: Qwen 3
# Google
- name: gemma3n
title: Gemma 3n
# Swiss AI
- name: apertus
title: Apertus
# GPT-OSS
- name: gpt-oss
title: GPT-OSS
- name: seed-oss
title: Seed-OSS
# Microsoft
- name: phi
title: Phi
# SmolVLM
- name: smolvlm2
title: SmolVLM 2
# IBM
- name: granite4
title: Granite 4
# LiquidAI
- name: LiquidAI
title: Liquid Foundation Models 2
# Other
- name: hunyuan
title: Hunyuan
- name: jamba
title: Jamba
- name: orpheus
title: Orpheus

View File

@@ -0,0 +1,424 @@
"""
auto generate example docs from allowlist
"""
import re
import shutil
import sys
from pathlib import Path
import yaml
# Paths
THIS = Path(__file__).resolve()
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
EXAMPLES_DIR = ROOT / "examples"
OUTPUT_DIR = ROOT / "docs" / "models"
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
def slugify(name: str) -> str:
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
s = re.sub(r"\s+", "-", s).strip("-").lower()
return s or "example"
def read_allowlist():
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
items = data.get("examples", [])
if not isinstance(items, list):
raise ValueError("`examples` must be a list in examples-allowlist.yml")
return items
def find_readme(folder: Path) -> Path | None:
for name in ("README.md", "Readme.md", "readme.md"):
p = folder / name
if p.exists():
return p
return None
def remove_first_h1(md: str) -> tuple[str, str | None]:
"""
Remove the first H1 from markdown and return (modified_md, h1_title).
The H1 is removed since we use the frontmatter title instead.
"""
lines = md.splitlines()
result = []
h1_title = None
skipped_first = False
for line in lines:
if not skipped_first and line.startswith("# "):
h1_title = line[2:].strip()
skipped_first = True
continue
result.append(line)
return "\n".join(result), h1_title
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
"""
Copy local image assets referenced in markdown to
docs/examples/assets/... and rewrite the links.
"""
dest_assets = dest_assets_root / "assets"
def repl(m):
url = m.group(1).strip()
if re.match(r"^(https?:)?//", url):
return m.group(0) # leave remote URLs
src_path = (src_dir / url).resolve()
if not src_path.exists():
return m.group(0) # leave as-is if not found
rel = src_path.relative_to(src_dir)
# Create a unique asset path based on source directory name
asset_name = src_dir.name.replace("/", "-")
dest_path = dest_assets / asset_name / rel
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
return m.group(0).replace(url, new_rel)
return IMG_RE.sub(repl, md)
def rewrite_readme_links(
md: str,
src_dir: Path,
examples_dir: Path,
parent_index_only: set,
current_src_path: str,
allowlist_entries: set,
current_output_path: str,
) -> str:
"""
Rewrite links between README.md files to point to the correct .qmd files.
"""
def repl(m):
text = m.group(1)
url = m.group(2).strip()
# Skip remote URLs and anchor links
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
return m.group(0)
# Skip non-markdown files
if not url.lower().endswith(".md"):
return m.group(0)
# Resolve the target path
try:
target_path = (src_dir / url).resolve()
# Check if target is outside examples_dir
try:
rel_path = target_path.relative_to(examples_dir)
except ValueError:
# Target is outside examples_dir, leave as-is
return m.group(0)
parts = list(rel_path.parts)
# Determine the output path for the target
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
# This is a README link
if len(parts) == 1:
# Link to root README -> index.qmd
target_output = "index.qmd"
elif len(parts) == 2:
if parts[0] == ".":
# Current directory README
target_output = "index.qmd"
else:
# subdir/README.md
parent_dir = parts[0]
if parent_dir in parent_index_only:
target_output = f"{parent_dir}/index.qmd"
else:
target_output = f"{parent_dir}.qmd"
else:
# Deeper nesting: parent/subdir/README.md
# Build the full path like "parent/subdir"
full_path = "/".join(parts[:-1]) # Remove README.md
# Check if this exact path is in allowlist
if full_path in allowlist_entries:
# This is a sub-entry with its own entry -> use .qmd
target_output = f"{full_path}.qmd"
elif parts[0] == ".":
# ./subdir/README.md -> check if subdir has own entry
subdir = parts[1]
if subdir in parent_index_only:
target_output = f"{subdir}/index.qmd"
else:
target_output = f"{subdir}.qmd"
else:
# parent/subdir where parent doesn't have own entry
target_output = f"{full_path}/index.qmd"
else:
# Regular .md file -> convert to .qmd, keep path structure
target_output = "/".join(parts)[:-2] + "qmd"
# Compute relative path from current output file to target
current_parts = current_output_path.split("/")
target_parts = target_output.split("/")
# Special case: if current is a subdir file and target is a single-component file at root
# Example: current="magistral/vision", target="magistral.qmd"
if len(current_parts) > 1 and len(target_parts) == 1:
# Current is in subdir, target is at root level
# Go up to root: ../ for each level
up_count = len(current_parts) - 1
rel_parts = [".."] * up_count + [target_parts[0]]
new_url = "/".join(rel_parts)
else:
# Find common prefix
i = 0
while (
i < min(len(current_parts) - 1, len(target_parts))
and current_parts[i] == target_parts[i]
):
i += 1
# Build relative path: go up (../) then down to target
up_count = len(current_parts) - 1 - i
rel_parts = [".."] * up_count + target_parts[i:]
if not rel_parts or rel_parts == [".."]:
# Points to same directory or parent
new_url = "/".join(rel_parts) if rel_parts else "."
else:
new_url = "/".join(rel_parts)
return f"[{text}]({new_url})"
except (ValueError, IndexError):
return m.group(0)
return LINK_RE.sub(repl, md)
def write_qmd(out_path: Path, title: str, body_md: str):
out_path.parent.mkdir(parents=True, exist_ok=True)
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
out_path.write_text(fm + body_md, encoding="utf-8")
def update_quarto_yml(generated: list[tuple[str, str, str]]):
"""
Update _quarto.yml with the generated example files in the correct order.
This keeps the sidebar in sync with the allowlist.
Model Guides is now nested under "Getting Started" section.
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
"""
quarto_yml = ROOT / "_quarto.yml"
if not quarto_yml.exists():
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
return
content = quarto_yml.read_text(encoding="utf-8")
# First pass: find all parents that have sub-entries
parents_with_subs = set()
for path, _name, _title in generated:
if "/" in path:
parent = path.split("/")[0]
parents_with_subs.add(parent)
# Build the YAML contents while preserving allowlist order
lines = []
processed_sections = set()
for path, _name, title in generated:
# Check if this is a parent page that has sub-pages
if path in parents_with_subs:
# This is a parent page with sub-pages - create a nested section
if path not in processed_sections:
processed_sections.add(path)
section_title = (
title or path.replace("-", " ").replace("_", " ").title()
)
lines.append(f' - section: "{section_title}"')
lines.append(" contents:")
# Add the parent page first
lines.append(f" - docs/models/{path}.qmd")
# Then add all sub-pages
for sub_path, _sub_name, _sub_title in generated:
if "/" in sub_path and sub_path.split("/")[0] == path:
lines.append(
f" - docs/models/{sub_path}.qmd"
)
elif "/" not in path:
# This is a flat item with no sub-pages
# Skip if it was already included as part of a parent section
if path not in processed_sections:
lines.append(f" - docs/models/{path}.qmd")
yaml_content = "\n".join(lines) + "\n"
# Pattern to match only the Model Guides contents, stopping at the next item
# in Getting Started (lines starting with 12 spaces: same level as the section)
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
def replacement(match):
prefix = match.group(1)
return prefix + "\n" + yaml_content
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
if new_content != content:
quarto_yml.write_text(new_content, encoding="utf-8")
print(f"Updated {quarto_yml}")
else:
print(f"No changes needed for {quarto_yml}")
def main():
allow = read_allowlist()
if not EXAMPLES_DIR.exists():
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
return
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
# First pass: identify which parents have their own entry vs only sub-entries
parent_entries = set() # Parents that have their own entry
parent_with_subs = set() # Parents that have sub-entries
allowlist_entries = set() # All entries in allowlist
for item in allow:
if isinstance(item, str):
name = item
else:
name = item.get("name")
allowlist_entries.add(name)
if "/" in name:
parent = name.split("/")[0]
parent_with_subs.add(parent)
else:
parent_entries.add(name)
# Parents with subs that DON'T have their own entry -> use index.qmd
parent_index_only = parent_with_subs - parent_entries
generated = []
seen_dirs = set() # Track which parent directories we've created index for
for item in allow:
if isinstance(item, str):
name = item
title = None
else:
name = item.get("name")
title = item.get("title")
if not name:
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
continue
src_dir = EXAMPLES_DIR / name
if not src_dir.exists() or not src_dir.is_dir():
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
continue
readme = find_readme(src_dir)
if not readme:
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
continue
md = readme.read_text(encoding="utf-8")
# Determine output path first (needed for link rewriting)
parts = name.split("/")
if len(parts) == 1:
# Simple case: no subdirectory
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
sidebar_path = parts[0]
else:
# Has subdirectory: e.g., magistral/think
parent = parts[0]
child = "-".join(parts[1:]) # handle nested subdirs
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
sidebar_path = f"{parent}/{child}"
# Remove the first H1 (we use frontmatter title instead)
md, _ = remove_first_h1(md)
# Rewrite links between README files
md = rewrite_readme_links(
md,
src_dir,
EXAMPLES_DIR,
parent_index_only,
name,
allowlist_entries,
sidebar_path,
)
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
# Handle parent page generation for sub-entries
if len(parts) > 1:
# Has subdirectory: e.g., magistral/think
parent = parts[0]
# Create parent.qmd if not already done and parent doesn't have own entry
if parent not in seen_dirs and parent in parent_index_only:
parent_readme = find_readme(EXAMPLES_DIR / parent)
if parent_readme:
parent_md = parent_readme.read_text(encoding="utf-8")
parent_md, _ = remove_first_h1(parent_md)
parent_md = rewrite_readme_links(
parent_md,
EXAMPLES_DIR / parent,
EXAMPLES_DIR,
parent_index_only,
parent,
allowlist_entries,
parent,
)
parent_md = rewrite_and_copy_assets(
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
)
parent_title = parent.replace("-", " ").replace("_", " ").title()
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
generated.append((parent, parent, parent_title))
seen_dirs.add(parent)
if not title:
title = name.replace("/", " ").replace("-", " ").title()
write_qmd(out_path, title, md)
generated.append((sidebar_path, name, title))
# Index page - preserve allowlist order
if generated:
listing = "\n".join(
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
)
index_md = (
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
+ listing
+ "\n"
)
index_fm = (
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
)
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
# Auto-update _quarto.yml to keep sidebar in sync
update_quarto_yml(generated)
if __name__ == "__main__":
main()

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
]
},
{

View File

@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -52,6 +52,7 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
scaling_softmax: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

View File

@@ -0,0 +1,77 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/eaft-gemma-3-1b
use_eaft: true
eaft_alpha: 1.0
eaft_k: 20
sequence_len: 1024
sample_packing: false
adapter:
lora_model_dir:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
max_steps: 1000
evaluation_strategy: "no"
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
debug:
deepspeed:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_linear: true
sequence_len: 2048

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -29,7 +30,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_linear: true
sequence_len: 2048

View File

@@ -2,6 +2,7 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
@@ -32,8 +33,8 @@ sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_dropout: 0
lora_target_linear: true
wandb_project:
wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:

View File

@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -0,0 +1,77 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm4.7-flash/qlora.yaml
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
```
```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
axolotl train examples/glm4.7-flash/lora.yaml
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
```
### Expert LoRA
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router, untested but should work, not normally targeted
```
## Limitations
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2 — causes hang otherwise.
- **lora_target_linear**: Incompatible for this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
### TIPS
- For inference, the official Z.ai team recommends these default settings (most tasks):
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

44
examples/glm46v/README.md Normal file
View File

@@ -0,0 +1,44 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
glm-4-6v-flash(9B)
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,53 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,50 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -0,0 +1,43 @@
# Finetune OpenGV's InternVL with Axolotl
[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGV. It features a ViT-style vision encoder and strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install `timm` for vision model support:
```bash
pip install timm==1.0.19
```
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
4. Run the finetuning example:
```bash
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
```
This config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### Tips
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [InternVL Paper](https://huggingface.co/papers/2508.18265)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,61 @@
base_model: OpenGVLab/InternVL3_5-8B-HF
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -19,7 +19,6 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true

View File

@@ -0,0 +1,47 @@
# Finetune MoonshotAI's Kimi Linear with Axolotl
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install CCE via [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
3. Run the finetuning example:
```bash
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
```
This config uses about 98.7GiB VRAM.
Let us know how it goes. Happy finetuning!
### TIPS
- Kimi Linear requires `trust_remote_code: true`.
- You can run a full finetuning by removing the `adapter: lora` and `load_in_8bit: true`.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)
## Optimization Guides
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
This is not yet compatible with MoE kernels from transformers v5.
## Related Resources
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,81 @@
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
trust_remote_code: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.2
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,68 @@
base_model: meta-llama/Llama-3.2-1B-Instruct
chat_template: llama3
rl: gdpo
trl:
beta: 0.001
max_completion_length: 128
num_generations: 2
temperature: 0.7
top_p: 0.95
use_vllm: false
multi_objective_aggregation: normalize_then_sum
reward_funcs:
- rwd.format_reward
- rwd.correctness_reward
reward_weights: [1.0, 2.0]
log_completions: true
num_completions_to_print: 3
scale_rewards: true
datasets:
- path: openai/gsm8k
name: main
split: train[:1000]
type: rwd.gsm8k_transform
val_set_size: 0.0
output_dir: ./outputs/llama3-gdpo-out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
max_steps: 100
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
weight_decay: 0.01
warmup_steps: 10
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
flash_attention: true
logging_steps: 1
save_steps: 50
save_safetensors: true
special_tokens:
pad_token: "<|end_of_text|>"
seed: 42

View File

@@ -12,7 +12,6 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora

View File

@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -5,6 +5,7 @@ This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mist
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md))
## Getting Started

View File

@@ -5,7 +5,8 @@ This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mist
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
- Installed Axolotl from source (see [main README](../README.md))
## Getting started

View File

@@ -47,6 +47,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

39
examples/mimo/README.md Normal file
View File

@@ -0,0 +1,39 @@
# Finetune Xiaomi's MiMo with Axolotl
[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Run the finetuning example:
```bash
axolotl train examples/mimo/mimo-7b-qlora.yaml
```
This config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### Tips
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.
## Related Resources
- [MiMo Paper](https://arxiv.org/abs/2505.07608)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,67 @@
base_model: XiaomiMiMo/MiMo-7B-RL
trust_remote_code: true
revision_of_model: 6299b5a
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# CCE - N/A as of now
# plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -59,6 +59,7 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
scaling_softmax: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -5,6 +5,7 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md))
## Getting Started

View File

@@ -5,7 +5,8 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
- Installed Axolotl from source (see [main README](../README.md))
## Getting started

View File

@@ -5,6 +5,7 @@ This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
## Getting Started

View File

@@ -16,7 +16,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### TIPS

View File

@@ -42,10 +42,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

42
examples/plano/README.md Normal file
View File

@@ -0,0 +1,42 @@
# Finetune Katanemo's Plano-Orchestrator with Axolotl
[Plano-Orchestrator](https://huggingface.co/collections/katanemo/plano-orchestrator) is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/plano/plano-4b-qlora.yaml
```
This config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### Orchestration Prompt
Plano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the [official model card](https://huggingface.co/katanemo/Plano-Orchestrator-4B) for proper prompt formatting and the `ORCHESTRATION_PROMPT` template.
### Tips
- To use the larger [Plano-Orchestrator-30B-A3B](https://huggingface.co/katanemo/Plano-Orchestrator-30B-A3B) MoE model, simply change `base_model: katanemo/Plano-Orchestrator-30B-A3B` in the config and enable multi-GPU training if needed.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Plano GitHub](https://github.com/katanemo/plano)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,65 @@
base_model: katanemo/Plano-Orchestrator-4B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
chat_template: qwen3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install Qwen3-Next transformers commit
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
4. Run the finetuning example:
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about 45.62 GiB VRAM.
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -9,6 +9,8 @@ plugins:
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
@@ -25,7 +27,7 @@ sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_dropout: 0
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
@@ -34,12 +36,19 @@ lora_target_modules:
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:

285
examples/swanlab/README.md Normal file
View File

@@ -0,0 +1,285 @@
# SwanLab Integration Examples
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
## Examples Overview
### 1. DPO with Completion Logging
**File**: `dpo-swanlab-completions.yml`
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
**Features**:
- Basic SwanLab experiment tracking
- Completion table logging (prompts, chosen/rejected responses, rewards)
- Memory-bounded buffer for long training runs
- Cloud sync configuration
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
```
---
### 2. LoRA with Performance Profiling
**File**: `lora-swanlab-profiling.yml`
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
**Features**:
- SwanLab experiment tracking
- Automatic profiling of trainer methods
- Profiling metrics visualization
- Performance optimization guidance
**Best for**: Engineers optimizing training performance and comparing different configurations
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
```
---
### 3. Full-Featured DPO Production Setup
**File**: `dpo-swanlab-full-featured.yml`
Comprehensive production-ready configuration with ALL SwanLab features enabled.
**Features**:
- Experiment tracking with team workspace
- RLHF completion logging
- Performance profiling
- Lark (Feishu) team notifications
- Private deployment support
- Production checklist and troubleshooting
**Best for**: Production RLHF training with team collaboration
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
```
---
### 4. Custom Trainer Profiling (Python)
**File**: `custom_trainer_profiling.py`
Python code examples showing how to add SwanLab profiling to custom trainers.
**Features**:
- `@swanlab_profile` decorator examples
- Context manager profiling for fine-grained timing
- `ProfilingConfig` for advanced filtering and throttling
- Multiple profiling patterns and best practices
**Best for**: Advanced users creating custom trainers
**Usage**:
```python
from custom_trainer_profiling import CustomTrainerWithProfiling
# See file for detailed examples and patterns
```
---
## Feature Matrix
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|---------|----------|-------------------|-----------|-------------------|----------------|
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | (commented) | (commented) |
| lora-swanlab-profiling.yml | ✅ | (disabled) | ✅ (auto) | (commented) | (commented) |
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
---
## Configuration Quick Reference
### Basic SwanLab Setup
```yaml
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # cloud, local, offline, disabled
```
### RLHF Completion Logging
```yaml
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
```
### Lark Team Notifications
```yaml
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret # Required for production
```
### Team Workspace
```yaml
swanlab_workspace: my-research-team
```
### Private Deployment
```yaml
swanlab_web_host: https://swanlab.yourcompany.com
swanlab_api_host: https://api.swanlab.yourcompany.com
```
---
## Authentication
### Recommended: Environment Variable
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
```
### Alternative: Config File (less secure)
```yaml
swanlab_api_key: your-api-key
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret
```
---
## Common Use Cases
### Use Case 1: Migrate from WandB to SwanLab
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
```yaml
use_swanlab: true
use_wandb: false
```
### Use Case 2: Analyze DPO Model Outputs
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
```yaml
swanlab_completion_log_interval: 50 # More frequent for short training
swanlab_completion_log_interval: 200 # Less frequent for long training
```
### Use Case 3: Optimize Training Performance
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
- Baseline: `flash_attention: false, gradient_checkpointing: false`
- Flash Attention: `flash_attention: true`
- Gradient Checkpointing: `gradient_checkpointing: true`
- Both: `flash_attention: true, gradient_checkpointing: true`
Compare profiling metrics in SwanLab dashboard.
### Use Case 4: Production RLHF with Team Collaboration
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
```yaml
swanlab_workspace: ml-team
swanlab_lark_webhook_url: ...
swanlab_lark_secret: ...
```
---
## Viewing Your Experiments
### Cloud Mode
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
**Dashboard sections**:
- **Metrics**: Training loss, learning rate, profiling metrics
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
- **Config**: Hyperparameters and configuration
- **System**: Resource usage (GPU, memory, CPU)
- **Files**: Logged artifacts
### Local Mode
```bash
swanlab watch ./swanlog
# Open browser to http://localhost:5092
```
---
## Troubleshooting
### SwanLab not initializing
```bash
# Check API key
echo $SWANLAB_API_KEY
# Verify SwanLab is installed
pip show swanlab
# Check config
grep -A 5 "use_swanlab" your-config.yml
```
### Completions not appearing
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
- Check `swanlab_log_completions: true`
- Wait for `swanlab_completion_log_interval` steps
- Look for "Registered SwanLab RLHF completion logging" in logs
### Lark notifications not working
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
- Verify `SWANLAB_LARK_SECRET` is set correctly
- Check bot is added to Lark group chat
- Look for "Registered Lark notification callback" in logs
### Profiling metrics not appearing
- Verify `use_swanlab: true`
- Check SwanLab is initialized (look for init log message)
- Profiling metrics are under "profiling/" namespace
- Profiling auto-enabled when SwanLab is enabled
---
## Performance Notes
### Overhead Comparison
| Feature | Overhead per Step | Memory Usage |
|---------|------------------|--------------|
| Basic tracking | < 0.1% | ~10 MB |
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
| Profiling | < 0.1% | ~1 KB |
| **Total** | **< 0.7%** | **~10 MB** |
### Best Practices
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
2. Adjust completion log interval based on training length (100-200 steps)
3. Keep completion buffer size reasonable (128-512)
4. Profile critical path methods first (training_step, compute_loss)
5. Use ProfilingConfig to throttle high-frequency operations
---
## Further Reading
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
---
## Contributing
Found an issue or have an improvement? Please submit a PR or open an issue:
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)

View File

@@ -0,0 +1,299 @@
"""Example: Custom Trainer with SwanLab Profiling
This example demonstrates how to add SwanLab profiling to your custom trainer.
Features:
- @swanlab_profile decorator for automatic profiling
- swanlab_profiling_context for fine-grained profiling
- ProfilingConfig for advanced filtering and throttling
Usage:
1. Create your custom trainer extending AxolotlTrainer
2. Add @swanlab_profile decorators to methods you want to profile
3. Use swanlab_profiling_context for fine-grained profiling within methods
4. Enable SwanLab in your config (use_swanlab: true)
See also:
- examples/swanlab/lora-swanlab-profiling.yml for config
- src/axolotl/integrations/swanlab/profiling.py for implementation
"""
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
ProfilingConfig,
swanlab_profile,
swanlab_profiling_context,
swanlab_profiling_context_advanced,
)
class CustomTrainerWithProfiling(AxolotlTrainer):
"""Custom trainer with SwanLab profiling enabled.
This trainer demonstrates three profiling patterns:
1. Decorator-based profiling (@swanlab_profile)
2. Context manager profiling (swanlab_profiling_context)
3. Advanced profiling with filtering (ProfilingConfig)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Create custom profiling config for high-frequency operations
self.fast_op_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.5, # Only log if duration > 0.5ms
log_interval=50, # Log every 50th call
)
# ========================================================================
# Pattern 1: Decorator-based Profiling
# ========================================================================
# Best for: Methods you always want to profile
# Overhead: ~2-5 microseconds per call (negligible)
@swanlab_profile
def training_step(self, model, inputs):
"""Main training step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
"""
return super().training_step(model, inputs)
@swanlab_profile
def compute_loss(self, model, inputs, return_outputs=False):
"""Loss computation - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
"""
return super().compute_loss(model, inputs, return_outputs)
@swanlab_profile
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
"""Prediction step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
"""
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# ========================================================================
# Pattern 2: Fine-grained Context Manager Profiling
# ========================================================================
# Best for: Profiling specific code blocks within a method
# Use case: When you want to profile forward vs backward separately
def complex_training_step(self, model, inputs):
"""Training step with fine-grained profiling.
Profiling metrics:
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
"""
# Profile just the forward pass
with swanlab_profiling_context(self, "forward_pass"):
outputs = model(**inputs)
loss = outputs.loss
# Profile just the backward pass
with swanlab_profiling_context(self, "backward_pass"):
loss.backward()
# Profile optimizer step
with swanlab_profiling_context(self, "optimizer_step"):
self.optimizer.step()
self.optimizer.zero_grad()
return outputs
# ========================================================================
# Pattern 3: Advanced Profiling with Filtering
# ========================================================================
# Best for: High-frequency operations where you want to throttle logging
# Use case: Methods called 100+ times per step
def _prepare_inputs(self, inputs):
"""Prepare inputs - throttled profiling.
This method is called frequently (once per batch), so we throttle
profiling to reduce overhead:
- Only log if duration > 0.5ms (skip very fast operations)
- Only log every 50th call (reduce logging frequency)
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
"""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_op_config
):
return super()._prepare_inputs(inputs)
def _prepare_input_for_model(self, input_ids):
"""Another high-frequency operation - throttled profiling.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
"""
with swanlab_profiling_context_advanced(
self, "prepare_input_for_model", config=self.fast_op_config
):
# Your custom input preparation logic
return input_ids
# ========================================================================
# Pattern 4: Exception-safe Profiling
# ========================================================================
# Profiling is exception-safe: duration is logged even if method raises
@swanlab_profile
def potentially_failing_method(self):
"""This method may raise an exception.
SwanLab profiling will still log the duration before re-raising.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
"""
# Do some work
result = self._do_risky_computation()
# If this raises, profiling duration is still logged
if result < 0:
raise ValueError("Invalid result")
return result
def _do_risky_computation(self):
"""Placeholder for risky computation."""
return 42
# ============================================================================
# Advanced Example: Custom ProfilingConfig Per Method
# ============================================================================
class AdvancedProfilingTrainer(AxolotlTrainer):
"""Trainer with method-specific profiling configurations."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Different profiling configs for different method types
self.critical_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything on critical path
log_interval=1, # Log every call
)
self.fast_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=1.0, # Only log if > 1ms
log_interval=100, # Log every 100th call
)
self.debug_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything
log_interval=1, # Log every call
)
def training_step(self, model, inputs):
"""Critical path - log everything."""
with swanlab_profiling_context_advanced(
self, "training_step", config=self.critical_path_config
):
return super().training_step(model, inputs)
def _prepare_inputs(self, inputs):
"""Fast path - throttle logging."""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_path_config
):
return super()._prepare_inputs(inputs)
def _debug_method(self, data):
"""Debug-only method - verbose logging."""
with swanlab_profiling_context_advanced(
self, "debug_method", config=self.debug_config
):
# Your debug logic
pass
# ============================================================================
# How to Use This Custom Trainer
# ============================================================================
"""
To use this custom trainer:
1. Save this file to your project (e.g., my_custom_trainer.py)
2. Create a config file that uses your custom trainer:
# config.yml
base_model: NousResearch/Llama-3.2-1B
# ... other config ...
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-profiling-experiment
# Optional: Specify custom trainer
# (Or modify axolotl to use your custom trainer class)
3. Run training:
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train config.yml
4. View profiling metrics in SwanLab dashboard:
- profiling/Time taken: CustomTrainerWithProfiling.training_step
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- etc.
5. Compare profiling metrics across runs:
- Run baseline without optimizations
- Run with flash_attention enabled
- Run with gradient_checkpointing enabled
- Compare profiling metrics to see performance impact
"""
# ============================================================================
# Tips for Effective Profiling
# ============================================================================
"""
1. Profile the critical path first:
- training_step, compute_loss, prediction_step
- These methods are called most frequently and have biggest impact
2. Use throttling for high-frequency operations:
- Methods called 100+ times per step
- Use log_interval=50 or log_interval=100
- Reduces profiling overhead and dashboard clutter
3. Filter noise with min_duration_ms:
- Set min_duration_ms=1.0 to skip very fast operations
- Focus on operations that actually take time
4. Compare across runs:
- Run same config multiple times to check consistency
- Compare different optimization strategies
- Track profiling trends over time
5. Monitor distributed training:
- Check for per-rank timing differences
- Look for stragglers (slower ranks)
- Identify synchronization bottlenecks
6. Disable profiling in production:
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
- DEFAULT_PROFILING_CONFIG.enabled = False
7. Exception handling:
- Profiling is exception-safe
- Duration logged even if method raises
- Useful for debugging methods that fail intermittently
"""

View File

@@ -0,0 +1,168 @@
# SwanLab DPO Training Example with Completion Logging
#
# This example demonstrates DPO (Direct Preference Optimization) training
# with SwanLab integration for experiment tracking and completion table logging.
#
# Features enabled:
# - SwanLab experiment tracking
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
# - Lark (Feishu) team notifications (optional)
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
# Model Configuration
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
# Quantization
load_in_8bit: true
load_in_4bit: false
# LoRA Configuration
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# DPO Configuration
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
# Dataset and Output
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-out
# Training Configuration
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# Optimization
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: dpo-training
swanlab_experiment_name: llama-3-dpo-completions-demo
swanlab_description: "DPO training with completion table logging"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-research-team
# ============================================================================
# RLHF Completion Table Logging
# ============================================================================
#
# Automatically logs model completions to SwanLab for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View the table in SwanLab dashboard under "rlhf_completions"
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 training steps
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
# Memory Usage Notes:
# - Buffer size 128: ~64 KB (default, recommended)
# - Buffer size 512: ~256 KB (for more historical completions)
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
# Performance Notes:
# - Completion logging overhead: < 0.5% per training step
# - Only logs every N steps to minimize impact
# - Memory-bounded buffer prevents memory leaks
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get real-time training notifications in your team chat
# Uncomment to enable:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # Recommended for production
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# ============================================================================
# Optional: Private SwanLab Deployment
# ============================================================================
#
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# wandb_entity:
# use_wandb: false

View File

@@ -0,0 +1,329 @@
# SwanLab Full-Featured DPO Training Example
#
# This example demonstrates ALL SwanLab integration features:
# - Experiment tracking with cloud sync
# - RLHF completion table logging
# - Performance profiling
# - Lark (Feishu) team notifications
# - Team workspace collaboration
#
# Use this as a reference for production RLHF training setups.
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
# ============================================================================
# Model Configuration
# ============================================================================
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
# Quantization for efficient training
load_in_8bit: true
load_in_4bit: false
# ============================================================================
# LoRA Configuration
# ============================================================================
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true # Target all linear layers
# ============================================================================
# DPO (Direct Preference Optimization) Configuration
# ============================================================================
chat_template: llama3
rl: dpo # Enable DPO trainer
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
# ============================================================================
# Dataset and Output Configuration
# ============================================================================
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-full-featured-out
# ============================================================================
# Training Configuration
# ============================================================================
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# ============================================================================
# Optimization
# ============================================================================
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# ============================================================================
# Precision and Performance
# ============================================================================
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true
# ============================================================================
# Checkpointing and Logging
# ============================================================================
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration - Full Configuration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# ------------------------------------------------------------------------------
# Basic SwanLab Configuration
# ------------------------------------------------------------------------------
use_swanlab: true
swanlab_project: dpo-production
swanlab_experiment_name: llama-3-dpo-full-featured-v1
swanlab_description: |
Production DPO training with all SwanLab features enabled:
- Completion table logging for qualitative analysis
- Performance profiling for optimization
- Lark notifications for team collaboration
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# ------------------------------------------------------------------------------
# Team Collaboration
# ------------------------------------------------------------------------------
# Workspace for team collaboration (shared experiments)
swanlab_workspace: ml-research-team
# Authentication (recommended: use environment variable)
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# ------------------------------------------------------------------------------
# RLHF Completion Table Logging
# ------------------------------------------------------------------------------
# Automatically logs model completions for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View in SwanLab dashboard under "rlhf_completions" table
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
# Buffer size recommendations:
# - 128: Default, ~64 KB memory (recommended for most cases)
# - 256: ~128 KB memory (this config, good for longer training)
# - 512: ~256 KB memory (maximum for very long runs)
# ------------------------------------------------------------------------------
# Lark (Feishu) Team Notifications
# ------------------------------------------------------------------------------
# Get real-time training notifications in your team chat
#
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# Recommended: Set via environment variables
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# Or set in config (less secure):
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
# unauthorized parties from sending fake notifications to your team chat.
# ------------------------------------------------------------------------------
# Performance Profiling
# ------------------------------------------------------------------------------
# Profiling is automatically enabled when SwanLab is enabled.
# Metrics logged to SwanLab under "profiling/" namespace:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# Use these metrics to:
# - Identify bottlenecks in training loop
# - Compare performance across different configurations
# - Monitor performance regressions over time
# - Debug unexpected slowdowns
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# ------------------------------------------------------------------------------
# Optional: Private SwanLab Deployment
# ------------------------------------------------------------------------------
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ------------------------------------------------------------------------------
# Optional: Model Checkpointing to SwanLab
# ------------------------------------------------------------------------------
# Log model checkpoints to SwanLab (coming soon)
swanlab_log_model: false
# ============================================================================
# Disable Other Logging Tools (Recommended)
# ============================================================================
# Using multiple logging tools simultaneously can impact performance:
# - Expected overhead: ~1-2% per logger
# - Potential config/callback conflicts
#
# For production training, use ONLY SwanLab:
# wandb_project:
# use_wandb: false
#
# use_mlflow: false
#
# use_comet: false
# ============================================================================
# Expected Training Behavior
# ============================================================================
# With this configuration, you should see:
#
# 1. SwanLab Initialization (rank 0 only):
# INFO: SwanLab initialized for project: dpo-production
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
# INFO: SwanLab mode: cloud
# INFO: SwanLab workspace: ml-research-team
#
# 2. Completion Logging (rank 0 only):
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
# (log_interval=100, max_buffer=256)
#
# 3. Lark Notifications (rank 0 only):
# INFO: Registered Lark notification callback with HMAC authentication
#
# 4. Distributed Training Detection (if multi-GPU):
# INFO: Distributed training detected (world_size=N)
# INFO: Only rank 0 will initialize SwanLab
# INFO: Other ranks will skip SwanLab to avoid conflicts
#
# 5. Training Start Notification (Lark):
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
#
# 6. Periodic Completion Logging:
# Every 100 steps, completion table is updated in SwanLab dashboard
#
# 7. Training Complete Notification (Lark):
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
# With link to SwanLab dashboard and final metrics
#
# 8. SwanLab Dashboard Shows:
# - Training metrics (loss, learning rate, etc.)
# - Completion table (rlhf_completions)
# - Profiling metrics (profiling/Time taken: ...)
# - Hyperparameters and configuration
# - System resource usage
# ============================================================================
# Production Checklist
# ============================================================================
# Before deploying to production, verify:
# ✅ SwanLab API key is set via environment variable (not in config)
# ✅ Lark webhook secret is set (required for HMAC authentication)
# ✅ Workspace is set to your team's workspace
# ✅ Experiment name is descriptive and unique
# ✅ Only SwanLab is enabled (other loggers disabled)
# ✅ Completion logging buffer size is appropriate for your training duration
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
# ✅ Test run completes successfully and shows up in SwanLab dashboard
# ✅ Lark notifications are received in team chat
# ✅ Profiling metrics are logged correctly
# ============================================================================
# Troubleshooting
# ============================================================================
# If SwanLab initialization fails:
# 1. Check SWANLAB_API_KEY environment variable is set
# 2. Verify swanlab_project is set in config
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
# 4. Verify internet connectivity (for cloud mode)
# If Lark notifications not received:
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
# 4. Check training logs for "Registered Lark notification callback"
# 5. Verify bot is added to the target Lark group chat
# If completions not appearing in SwanLab:
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
# 2. Check swanlab_log_completions is true
# 3. Wait for log_interval steps (default: 100)
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
# If profiling metrics not appearing:
# 1. Verify use_swanlab is true
# 2. Check SwanLab is initialized (check logs)
# 3. Look under "profiling/" namespace in dashboard
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
# For more help:
# - SwanLab docs: https://docs.swanlab.cn
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues

View File

@@ -0,0 +1,178 @@
# SwanLab LoRA Training Example with Performance Profiling
#
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
# for performance profiling and optimization.
#
# Features enabled:
# - SwanLab experiment tracking
# - Performance profiling (training step, forward/backward pass timing)
# - Real-time metrics visualization
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
# Model Configuration
base_model: NousResearch/Llama-3.2-1B
# Dataset Configuration
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.1
output_dir: ./outputs/lora-swanlab-profiling-out
# LoRA Configuration
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
# Training Configuration
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
# Optimization
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# Loss Monitoring
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
special_tokens:
pad_token: "<|end_of_text|>"
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: lora-profiling
swanlab_experiment_name: llama-3.2-1b-profiling-demo
swanlab_description: "LoRA fine-tuning with performance profiling"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-ml-team
# ============================================================================
# Performance Profiling
# ============================================================================
#
# SwanLab automatically profiles trainer methods when enabled.
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
#
# Built-in profiling:
# - Minimal overhead (< 0.1% per step)
# - High-precision timing (microsecond accuracy)
# - Exception-safe (logs duration even if method fails)
#
# View profiling metrics in SwanLab dashboard:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# Completion logging is disabled for non-RLHF trainers
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
# ============================================================================
# Optional: Compare with Multiple Runs
# ============================================================================
#
# To compare profiling metrics across different configurations:
#
# 1. Run baseline without flash attention:
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
# flash_attention: false
#
# 2. Run with gradient checkpointing:
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
# gradient_checkpointing: true
#
# 3. Run with both:
# swanlab_experiment_name: llama-3.2-1b-optimized
# flash_attention: true
# gradient_checkpointing: true
#
# Then compare profiling metrics in SwanLab dashboard to see performance impact
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get notified when profiling experiments complete:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret
# ============================================================================
# Profiling Best Practices
# ============================================================================
#
# 1. Run multiple epochs to see profiling trends over time
# 2. Ignore first ~10 steps (warmup period, slower)
# 3. Look for outliers (steps that take significantly longer)
# 4. Compare profiling metrics before/after optimization changes
# 5. Monitor per-rank profiling in distributed training
#
# Common bottlenecks to profile:
# - training_step: Overall step time (should be consistent)
# - compute_loss: Loss computation (scales with sequence length)
# - prediction_step: Evaluation time (can be slow for large val sets)
#
# If you see inconsistent timing:
# - Check for data loading bottlenecks
# - Monitor GPU utilization (may be CPU-bound)
# - Check for gradient accumulation effects
# - Verify CUDA kernel synchronization
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# use_wandb: false

View File

@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Run the finetuning example:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM.
This config uses about 24.9 GiB VRAM (w/o CCE).
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -1,5 +1,5 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -12,7 +12,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
build-backend = "setuptools.build_meta"
[project]
@@ -24,6 +24,9 @@ Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.dynamic]
version = { file = "VERSION" }
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
@@ -57,3 +60,8 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]
flash-attn = [{ requirement = "torch", match-runtime = true }]
deepspeed = [{ requirement = "torch", match-runtime = true }]

View File

@@ -1,35 +1,35 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.48.2
triton>=3.0.0
bitsandbytes==0.49.1
triton>=3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.6.4
liger-kernel==0.7.0
# END section
packaging==23.2
huggingface_hub>=0.36.0
peft>=0.18.0
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==4.57.1
accelerate==1.11.0
datasets==4.4.1
deepspeed>=0.17.0
trl==0.25.0
transformers==5.2.0
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.28.0
hf_xet==1.2.0
kernels>=0.9.0
trackio>=0.13.0
typing_extensions>=4.14.0
kernels==0.12.1
trackio>=0.16.1
typing-extensions>=4.15.0
optimum==1.16.2
hf_transfer
sentencepiece
gradio>=6.2.0,<7.0
modal==1.0.2
pydantic>=2.10.6,<2.12
modal==1.3.0.post1
pydantic>=2.10.6
addict
fire
PyYAML>=6.0
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
torchao==0.16.0
openenv-core==0.1.0
schedulefree==1.4.1
@@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.8.6
mistral-common==1.8.8

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
)

View File

@@ -1,6 +1,5 @@
"""setup.py for axolotl"""
import ast
import os
import platform
import re
@@ -26,6 +25,12 @@ def parse_requirements(extras_require_map):
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip torchao on ARM64
_install_requires = [
req for req in _install_requires if "torchao" not in req
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
@@ -62,44 +67,68 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:
extras_require_map["vllm"] = ["vllm==0.14.0"]
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
if install_xformers:
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
if install_xformers:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
if install_xformers:
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
if install_xformers:
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
if install_xformers:
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
@@ -110,15 +139,11 @@ def parse_requirements(extras_require_map):
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
"r",
encoding="utf-8",
) as fin:
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
version_ = fin.read().strip()
return version_
@@ -156,7 +181,7 @@ extras_require = {
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]",
"ray[train]>=2.52.1",
],
"vllm": [
"vllm==0.10.0",

View File

@@ -1,7 +1,11 @@
"""Axolotl - Train and fine-tune large language models"""
import pkgutil
from importlib.metadata import PackageNotFoundError, version
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.13.0.dev"
try:
__version__ = version("axolotl")
except PackageNotFoundError:
__version__ = "unknown"

View File

@@ -5,6 +5,6 @@ import os
from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
configure_logging()

View File

@@ -44,7 +44,7 @@ def check_user_token() -> bool:
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:

View File

@@ -24,8 +24,7 @@ if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu128-2.9.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu126-2.7.1"
docker_tag = "main-py3.11-cu128-2.9.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"

View File

@@ -5,7 +5,7 @@ import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union
from typing import Any, Optional, Union
from urllib.parse import urlparse
import requests
@@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
"""Coerce a string CLI value to its most likely Python type.
If an existing value is present in the config, its type is used to guide
casting. Otherwise, YAML-style inference is applied: booleans, ints,
floats, and None literals are recognised automatically.
Args:
value: The raw value (typically a string from the CLI).
existing: An optional existing config value whose type guides coercion.
Returns:
The value cast to the inferred or expected type.
"""
if not isinstance(value, str):
return value
# If the config already has a typed value, cast to match
if existing is not None:
if isinstance(existing, bool):
return value.lower() in ("true", "1", "yes")
if isinstance(existing, int):
try:
return int(value)
except (ValueError, TypeError):
return value
if isinstance(existing, float):
try:
return float(value)
except (ValueError, TypeError):
return value
# For other types (str, list, dict, etc.), return as-is
return value
# No existing value -- use YAML-style inference
lower = value.lower()
if lower in ("true", "yes"):
return True
if lower in ("false", "no"):
return False
if lower in ("null", "none", "~"):
return None
# Try int then float
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -208,13 +265,37 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
# Separate nested (dot-notation) kwargs from flat kwargs
nested_kwargs: dict[str, dict[str, Any]] = {}
flat_kwargs: dict[str, Any] = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
# Apply flat kwargs
for key, value in flat_kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[key] = value
cfg[key] = _coerce_value(value, cfg.get(key))
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
for parent, children in nested_kwargs.items():
if parent not in cfg_keys and cfg.strict:
continue
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
LOG.warning(
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
)
cfg[parent] = {}
for child_key, child_value in children.items():
existing_child = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
try:
device_props = torch.cuda.get_device_properties("cuda")

Some files were not shown because too many files have changed in this diff Show More