Commit Graph

228 Commits

Author SHA1 Message Date
Wing Lian
b3289fd190 feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]
* feat: LoRA kernel support for bias, dropout, dora, embeddings

* chore: lint

* chore: lint

* address PR feedback, add regression tests, add fsdp2 tests for lora kernels

* update tests for new sigs

* update tests now that bias and dropout are supported
2026-03-22 13:53:19 -04:00
Wing Lian
0ee98a0309 fix token state json and mistral tokenizer issue (#3522) [skip ci]
* fix token state json and mistral tokenizer issue

* centralize constants

* forgot to commit constants file

* Fix weakref in pickling relora state dict

* make curl a bit quieter so it doesn't log 2K lines

* fix path traversal for olmoe test

* more test fixes that weren't flagged previously

* chore: lint

* skip tests that fail b/c of OutOfResources

* scattermoe as slow tests

* update fbgemm-genai for torch 2.10
2026-03-21 22:46:10 -04:00
Avaya Aggarwal
1bcfc08c90 feat: add support and end-to-end tests for multiple custom optimizers… (#3457) [skip ci]
* feat: add support and end-to-end tests for multiple custom optimizers including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.

* feat: Add standalone flashoptim integration test and E2E tests for various custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.

* feat: introduce Pydantic schema validation for dataset, attention, and training configurations.

* feat: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: fix assertion in flash optimizers test to compare class names directly

* fix: address PR review - reuse require_torch_2_7_0 decorator, remove fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 08:24:44 -04:00
Wing Lian
1fc86d5295 Scattermoe LoRA optimizations (#3513)
* optimize moe + lora

* more scattermoe optims

* selective dequant

* add correctness unit tests and benchmarks for scattermoe + lora

* handle base+lora split kernel for older moe models

* chore: lint

* fix casting for H200 and B200

* register pressure estimation and pruning for h200/b200

* use soft limit for pruning

* qkv patch for qwen3.5moe

* support text_model for qwen3.5 moe

* nesting of qwen3

* use updated cce with zero3 support

* Fix decomposed backward for QKV and O projections

eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
2026-03-19 23:07:42 -04:00
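
The last note in 1fc86d5295 describes avoiding the materialized B @ A product in the LoRA backward pass. A minimal sketch of that identity follows, assuming the usual LoRA layout (B is [out, R], A is [R, in]) and showing only the input-gradient path. The shapes, values, and the `matmul` helper are illustrative assumptions, not the kernels from the commit.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows (pure Python, for illustration)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# Hypothetical tiny shapes: T=2 tokens, rank R=1, in=3 features, out=4 features.
A = [[1, 2, 3]]                       # LoRA A: [R, in]
B = [[1], [0], [-1], [2]]             # LoRA B: [out, R]
dY = [[1, 0, 2, 1], [0, 1, 1, -1]]    # upstream gradient: [T, out]

# Naive path: build the full [out, in] update first, then multiply.
delta_w = matmul(B, A)                # [out, in], grows with layer size
dx_naive = matmul(dY, delta_w)        # [T, in]

# Decomposed path: go through the rank-R bottleneck instead.
tmp = matmul(dY, B)                   # [T, R], small intermediate
dx_decomposed = matmul(tmp, A)        # [T, in]

# Matrix multiplication is associative, so both paths agree.
assert dx_naive == dx_decomposed
```

The decomposed path never allocates an [out, in] matrix; its only intermediate is [T, R], which stays small because the rank R is small.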
Wing Lian
8f3fb517b3 consolidate behaviour of routing in scattermoe kernels (#3475)
* consolidate behaviour of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
2026-03-16 23:47:40 -04:00
Wing Lian
830e9f7eaf automatically enable tf32 if supported (#3473) [skip ci]
* automatically enable tf32 if supported

* update fixtures

* handle only when True

* Address CR comments

* address readability from pr comment

* simplify
2026-03-16 23:47:00 -04:00
Wing Lian
a36aaa70ce add gpu tests for scattermoe (#3474) [skip ci] 2026-03-07 00:00:48 -05:00
Wing Lian
876941ffd0 install flash-linear-attention (#3466)
* install flash-linear-attention

* handle prequant weights for fsdp2 and ensure loss is not zero

* fix type for cu_seqlen, uninstall causal_conv1d

* chore: lint

* uv pip uninstall doesn't need confirmation
2026-03-06 12:40:57 -05:00
NanoCode012
6a8baf8fa7 feat: add sonicmoe (#3411)
* feat: add sonicmoe

* feat: add torch compile for routing

* feat: add routing smoke test

* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe

* fix: disable mlp kernel for sonicmoe too

* feat: update to sonicmoe release

* chore: update import following new sonicmoe changes

* feat: update handling for blackwell

* feat: add sonicmoe e2e test

* fix: installation for updated sonicmoe

* fix: git commit

* fix: ignore py req and fix metadata

* fix: increase min hidden size to match sonicmoe kernel min

* fix: attempt properly interleave and handle unpatch mid-test

* chore: refactor teardown better

* chore: refactor to re-use rearrange

* fix: add idempotency guard

* fix: address comments on CI memory and interleave

* fix: tests grad, param doublewrapped
2026-03-05 13:43:31 -05:00
VED
1eaf4d7418 add: support mxfp4 axo (#3375)
* mxfp4 axo

* import lint

* test for qat mxfp4

* config for mxfp4

* add qat:

* pass base config

* MXFakeQuantizeConfig

* lint

* tune config so it fits in 32GB VRAM

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-03-05 13:40:45 -05:00
Wing Lian
444020b332 mark slow tests that are timing out in CI (#3428) [skip ci] 2026-03-02 12:26:30 -05:00
Wing Lian
145ffc9be1 upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0

* upgrade trl for parity

* handle trl api changes

* orpo doesn't have max_prompt_len to check anymore

* cpoconfig doesn't take max_prompt_length and fix cpu offload

* slow fsdp1 test

* triton min 3.4.0 and liger to 0.7.0

* use transformers main for now for zero3 fix

* handle group_by_length change

* fix changes upstream

* mark skip flaky test

* use transformers latest release 5.2.0
2026-02-19 18:27:27 -05:00
tgoab
530a0c0bf0 Changes from dataset_processes to dataset_num_proc (#3352) [skip ci]
* changes from dataset_processes to dataset_num_proc

* deprecation message improved

---------

Co-authored-by: Juliana Nieto Cárdenas <jnietoca@purdue.edu>
2026-02-10 17:44:17 +07:00
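
The rename in 530a0c0bf0 (dataset_processes to dataset_num_proc) is the kind of change usually handled with a small config shim. The sketch below is one way it could look, assuming the old key is still accepted and mapped onto the new one with a warning. The function name `migrate_dataset_processes` and the warning text are hypothetical and not taken from the repository.

```python
import warnings


def migrate_dataset_processes(cfg: dict) -> dict:
    """Return a copy of cfg with the deprecated key mapped to the new one.

    Hypothetical helper: only the two key names come from the commit message.
    """
    cfg = dict(cfg)  # avoid mutating the caller's config
    if "dataset_processes" in cfg:
        warnings.warn(
            "`dataset_processes` is deprecated; use `dataset_num_proc` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        old_value = cfg.pop("dataset_processes")
        # An explicitly set new key wins over the deprecated one.
        cfg.setdefault("dataset_num_proc", old_value)
    return cfg
```

Letting the new key take precedence when both are present is a design choice made here for the sketch; a stricter variant could raise an error instead.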
Wing Lian
fc4e37920b transformers v5 upgrade (#3272)
* Prepare for transformers v5 upgrade

* fix hf cli

* update for hf hub changes

* fix tokenizer apply_chat_template args

* remap include_tokens_per_second

* fix tps

* handle migration for warmup

* use latest hf hub

* Fix scan -> ls

* fix import

* fix for renaming of mistral common tokenizer -> backend

* update for fixed tokenization for llama

* Skip phi35 tests for now

* remove mistral patch fixed upstream in huggingface/transformers#41439

* use namespacing for patch

* don't rely on sdist for e2e tests for now

* run modal ci without waiting too

* Fix dep for ci

* fix imports

* Fix fp8 check

* fsdp2 fixes

* fix version handling

* update fsdp version tests for new v5 behavior

* Fail multigpu tests after 3 failures

* skip known v5 broken tests for now and cleanup

* bump deps

* unmark skipped test

* re-enable test_fsdp_qlora_prequant_packed test

* increase multigpu ci timeout

* skip broken gemma3 test

* reduce timeout back to original 120min now that the hanging test is skipped

* fix for un-necessary collator for pretraining with bsz=1

* fix: safe_serialization deprecated in transformers v5 rc01 (#3318)

* torch_dtype deprecated

* load model in float32 for consistency with tests

* revert some test fixtures back

* use hf cache ls instead of scan

* don't strip fsdp_version

more fsdp_version fixes for v5
fix version in fsdp_config
fix aliasing
fix fsdp_version check
check fsdp_version is 2 in both places

* Transformers v5 rc2 (#3347)

* bump dep

* use latest fbgemm, grab model config as part of fixture, un-skip test

* import AutoConfig

* don't need more problematic autoconfig when specifying config.json manually

* add fixtures for argilla ultrafeedback datasets

* download phi4-reasoning

* fix arg

* update tests for phi fast tokenizer changes

* use explicit model types for gemma3

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>

* fix: AutoModelForVision2Seq -> AutoModelForImageTextToText

* chore: remove duplicate

* fix: attempt fix gemma3 text mode

* chore: lint

* ga release of v5

* need property setter for name_or_path for mistral tokenizer

* vllm not compatible with transformers v5

* setter for chat_template w mistral too

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2026-01-27 17:08:24 -05:00
VED
d0d26d5064 feat: Add GDPO Support (#3353)
* gdpo support - test left

* lint

* fixes for vllm serve

* test advantages

* docs

* lint

* lint =

* gdpo simple + lint

* lint nit

* example

* lint

* trl 0.27.0

* blocklist

* test assert rmv

* add validation check for GDPO + sum_then_normalize

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-01-21 17:22:45 -05:00
VED
a6080df73c compute loss only if training and update token metric naming (#3293) [skip ci]
* compute loss only if training

* save total_tokens for checkpoint

* check if string

* refactor total_tokens/ num_tokens

* refactor 2

* replace trainable_step/train_per_sec_per_gpu

* lint + log trainable/tokens

* consolidate it in the callback.

* test for total_tokens after resume

* check if tokenstate exist after ckpt

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-25 18:38:17 +07:00
salman
bbd3486f57 Distributed Muon Optimizer (#3264)
* init

* working

* updating configs

* removing unneeded files

* lint

* comments

* lint

* fix regex match

* bump contribs version

* comments

* fixing tests and imports

* muon imports in test v2

* test cleanup

* bump contribs version

---------

Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-12-19 10:43:47 -05:00
VED
dcf24fd24e feat: save checkpoint after training started (#3233)
* add:config parameters for checkpoint

* callback main

* test file_type fix

* lint

* unit

* simplify dict/obj handling

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Delete tests/e2e/integrations/__init__.py

* remove hard code path in test

* device check

* lint

* Update src/axolotl/utils/callbacks/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/utils/callbacks/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint-2

* remove: signal-based checkpoints

* lint

* remove signal tests

* add:is_main_process

* lint

* add: is_distributed() for tests

* remove nested is_main_process

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* add user_defined_filename

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-11-13 10:21:05 -05:00
xzuyn
dd78f2e0cc Fix: warmup_steps: 0 & warmup_ratio: 0 not disabling warmup (#3254)
* fix unintentional falsy checks

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-11-11 10:32:06 +07:00
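
The "unintentional falsy checks" fixed in dd78f2e0cc are a common Python pitfall: a truthiness test treats an explicit 0 the same as "not set". A minimal sketch of the bug class, assuming a fallback default is applied when the value is missing. The function names and the default of 100 are hypothetical and not the project's actual code.

```python
def resolve_warmup_buggy(warmup_steps, default=100):
    # Truthiness check: 0 is falsy, so an explicit 0 silently becomes the default.
    return warmup_steps if warmup_steps else default


def resolve_warmup_fixed(warmup_steps, default=100):
    # Identity check: only a missing value (None) falls back to the default.
    return warmup_steps if warmup_steps is not None else default


# An explicit 0 should disable warmup, but only the fixed version honors it.
assert resolve_warmup_buggy(0) == 100
assert resolve_warmup_fixed(0) == 0
assert resolve_warmup_fixed(None) == 100
```

The same reasoning applies to a ratio of 0.0, which is also falsy in Python.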
NanoCode012
11eb36585a feat: add arg to enable dft in liger (#3125)
* feat: add arg to enable dft in liger

* feat: add tests use_token_scaling

* fix: test

* fix: move check to args
2025-11-10 21:37:47 +07:00
Lê Nam Khánh
80270a92fa Fix typos in some files (#3250) [skip ci] 2025-11-07 08:21:20 -05:00
Wing Lian
4cdfdfebb5 upgrade transformers==4.57.1 and peft==0.23.1 (#3214) 2025-10-14 15:54:05 -04:00
VED
cd856b45b1 feat:add support dataset_num_processes (#3129) [skip ci]
* feat:add support dataset_num_processes

* chore

* required changes

* requested changes

* required changes

* required changes

* required changes

* elif get_default_process_count()

* add:del data

* Update cicd/Dockerfile.jinja

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update cicd/single_gpu.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2025-10-13 17:18:12 +07:00
salman
143dea4753 FSDPConfig (#3170) 2025-10-10 14:44:25 +01:00
Hitesh Sagtani
bc2ffb8204 fix: Enable KD plugin support for PEFT/LoRA adapters (#3207)
- Fix _loss_function attribute not found on base model with PEFT
- Fix mismatched attribute name (loss_function vs _loss_function)
- Set _loss_function on unwrapped base model for PEFT
- Enable previously skipped test_llama_lora_kd test
- Add test config fixes for LoRA kernel compatibility

Fixes https://github.com/axolotl-ai-cloud/axolotl/issues/3206
2025-10-10 08:57:00 -04:00
Wing Lian
d0e9c3c1c5 When using Ray use prepare for dataloader fixes (#3198)
* make sure to use ray prepare for dataloader fixes

* ray tests use 2.7.0+

* don't call init_distributed w ray and deepspeed

* handle dict deepspeed config

* better handling of dict deepspeed config

* use json.dumps

* guard to_dict

* wrap import for optional ray
2025-10-08 10:43:41 -04:00
Wing Lian
130637a3fa upgrade transformers to 4.57.0 (#3201)
* upgrade transformers to 4.57.0

* remove deprecated autoawq and use latest peft

* remove autoawq from setuptools script

* fix imports

* make sure torchvision is installed

* remove support for BetterTransformer

* skip fsdp_qlora_prequant test

* more robust error reporting
2025-10-08 08:43:46 -04:00
VED
a6bfbe3400 torch_dtype -> dtype (#3177)
* torch_dtype -> dtype

* torch_dtype -> dtype
2025-10-01 15:02:51 +07:00
Wing Lian
86d6ee7c05 upgrade trl and accelerate (#3161)
* upgrade trl==0.23.0

* upgrade accelerate patch fix

* add hints when using gradient_checkpointing with DPO

* set gradient-checkpointing properly
2025-09-16 14:53:01 -04:00
salman
58d67bf98d Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107) 2025-09-12 10:55:50 +01:00
Dan Saunders
1b53c49e1a text diffusion training plugin (#3067)
* diffusion training plugin

* cleanup

* nits

* fixes + improvements

* add back in reinit_weights (clobbered?); masking / pretrain fixes

* nits

* cleanup; tests draft

* sample generation, tests fixes

* fixes

* nits

* add inference support; add auto-mask token support

* nits

* nits

* progress

* simplify logging

* lint

* prefix args with diffusion_

* coderabbito

* tests fix

* nit

* nits

* cleanup + nits

* nits

* fix SFT sample gen

* fixes

* fix

* comments

* comments

* lint

* reward model lora fix

* cleanup; fix pretraining_dataset case

* gradio inference

* update cfgs

* update cfgs

* train, generation parity, cleanup

* fix

* simplify

* test

* test fix
2025-09-10 20:27:00 -04:00
Dan Saunders
231a67e70b Streaming SFT support (#3101)
* working

* fixes

* deprecate --iterable; cleanup

* pretrain_multipack_buffer_size -> streaming_multipack_buffer_size

* improvements

* tests

* remove unused

* docs, examples

* nit

* nit

* add val_set_size validation

* val

* nit

* min

* coderabbito

* cleanup

* nit

* add depr warning, cleanup

* nit

* fix test, fix quarto

* fix

* review comments

* review comments

* fix
2025-09-02 12:08:44 -04:00
Wing Lian
c4c4b90638 add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
2025-08-26 09:30:04 -04:00
Dan Saunders
79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00
Wing Lian
130ef7c51a Various fixes for VLMs (#3063)
* fix to not use batch feature indexing

* more vlm fixes

* use AutoModelForImageTextToText

* add example yaml and need num2words for chat template

* improve handling of adding image tokens to conversation

* add lfm2-vl support

* update the lfm readme

* fix markdown and add rtol for loss checks

* feat: add smolvlm2 processing strat

* fix: check for causal-conv1d in lfm models

* feat: add docs for lfm2

* feat: add new models and tips to docs

* feat: add smolvlm2 docs and remove extra dep

* chore: update docs

* feat: add video instructions

* chore: cleanup

* chore: comments

* fix: typo

* feat: add usage stats

* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-15 10:52:57 -04:00
Wing Lian
09145de8fa upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)
* upgrade transformers==4.55.1

* also upgrade bnb

* remove bnb params4bit patch (upstreamed)

* use latest causal-conv1d

* fix patching ring-flash-attn with now missing imports

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-08-13 19:41:07 -04:00
Wing Lian
d4d84d48af fix ray train and add fsdp2 smoke test for ray trainer (#3053)
* add fsdp2 smoke test for ray trainer

* fix raytrain with fsdp2
2025-08-11 09:31:54 -04:00
Dan Saunders
d09290f2f4 Lora kernels bias support (#3025)
* lora kernels bias support

* revert rename

* nit

* lint, tests

* satisfying the rabbit
2025-08-06 20:20:08 -04:00
Wing Lian
97e86c6d47 drop old patches and code that are no longer needed (#3007) [skip ci] 2025-08-06 08:02:39 -04:00
Wing Lian
ab49d16e34 Dion optimizer support (#3014)
* Add support for Dion optimizer

* dion training kwargs

* fix var names

* no dion 8bit for now

* use updated axolotl-contribs-mit for dion optimizer

* add smoke test for dion optimizer

* add docs

* fix typo during edits

* fix test to not remove load in 8bit
2025-08-04 16:33:30 -04:00
Dan Saunders
e758343cac FSDP2 + LoRA kernels (#2992)
* impl fix

* smoke tests

* patches for fsdp2 + qlora compat

* nit

* working fix

* working fix

* fix merge

* minifying patches; update bnb dep

* renaming; adding tests

* remove duplicate test, add dora guard

* generalize __torch_function__

* revert generalization

* update comments
2025-08-03 20:05:17 -04:00
salman
294c7fe7a6 Distributed/ND-Parallel (#2977) 2025-07-31 15:25:02 -04:00
Wing Lian
7b68dfafd7 jagged lr restart scheduler (#1680) [skip ci]
* jagged lr restart scheduler

var name fix
make sure to create scheduler first

* wire things together

* more fixes

* fix for nesting scheduler and first anneal phase

* no need for relora trainer anymore since we've generalized the relora scheduler

* remove redundant relora scheduler and lint

* update relora e2e test for updated params

* need restart steps for relora test

* update quarto docs for dropped relora trainer

* update example yaml

* drop verbose arg

* min lr scale support for jagged lr

* don't let min_lr be nonetype

* cleanup args
2025-07-31 13:50:03 -04:00
Wing Lian
563f5eed7a update dependencies - liger + trl (#2987)
* update dependencies

* set dataset processes for tests

* add support for GSPO
2025-07-31 11:17:17 -04:00
Wing Lian
0ff2f172ef Act offload lora fix (#2928) [skip ci]
* fix activation offloading with lora

* update w e2e test

* add docs for error
2025-07-24 16:10:04 -04:00
Dan Saunders
208fb7b8e7 basic torchao fp8 mixed precision training (#2926)
* debug

* debug

* debug

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint

* accelerator constructor patch for single-GPU torch fp8

* lint

* re-using existing fp8 code

* lint

* remove accelerate patch now fixed in latest release

* fix

* docs

* add fp8 + fsdp2 example

* remove unused config

* update config

* smoke tests

* add validator

* add 2.7.0 guard for fsdp2

* fix

* add config descriptions

* add FSDP doc link

* nit

* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather

* better cfg for smoke tests

* add test for accelerate patching

* update fp8 validator
2025-07-22 16:27:47 -04:00
Dan Saunders
10ba1622f7 checkpoint model on first step callback (#2906)
* checkpoint model on first step callback

* remove debug

* add test cases; update existing tests not to save on first step

* move test out of solo

* delete

* default to False

* typo
2025-07-15 15:00:48 -04:00
Wing Lian
38359a8997 allow profiling in mid-training rather than from the start (#2899) [skip ci]
* allow profiling in mid-training rather than from the start

* simplify based on PR feedback

* fix logic, improve saving at end, add tests
2025-07-14 20:11:11 -04:00
Wing Lian
aa684122f1 upgrade peft==0.16.0 and datasets==4.0.0 (#2917) [skip ci]
* upgrade peft to 0.16.0

* upgrade datasets to 4.0.0

* refactor dupes from merge/rebase

* fix check for fsdp1 + sharded_state_dict

* use full state dict for ci
2025-07-14 20:09:26 -04:00
Wing Lian
ca4d4ef793 don't init distributed for deepspeed if preprocessing (#2920)
* don't init distributed for deepspeed if preprocessing

* add e2e test to validate preprocess cli with deepspeed

* ignore duplicate code for cfg
2025-07-14 14:19:19 -04:00