Compare commits

...

83 Commits

Author SHA1 Message Date
Wing Lian
b5198d8734 granite chat multipack support and example 2025-08-02 20:57:00 -04:00
Wing Lian
4ab6a1bd7e add support for granite chat templates 2025-08-02 11:29:03 -04:00
Wing Lian
5639552064 prevent usage of low bit ao optimizers with configurations that use parameter groups (#3003)
* prevent usage of low bit ao optimizers with configurations that use parameter groups

* use optimizer enum value

* fix validation
2025-08-01 17:54:04 -04:00
Wing Lian
cda3c82351 move ib/rdma libs into base image (#3002)
* move ib/rdma libs into base image

* use  --no-install-recommends
2025-08-01 16:10:37 -04:00
Wing Lian
7c3b428f23 Add validation for TP with models with tied embeddings (#2999)
* add validation for tp + tied embeddings models

* fix logic and messaging

* add additional guard for null tp size
2025-08-01 13:58:16 -04:00
Wing Lian
01a6bd1a0e use CCE fix for TP using vocab parallel for CEL (#3000) 2025-08-01 13:21:58 -04:00
NanoCode012
41709822a7 fix: move memory usage log to trainer.log (#2996) [skip ci] 2025-08-01 13:21:43 -04:00
Wing Lian
02a37199ee prevent empty value for vllm_mode (#2998) 2025-08-01 09:59:45 -04:00
NanoCode012
7026cd5e9e Feat: Add N-D parallelism docs (#2989)
* fix: remove non-existent file

* feat: add n-d parallel docs

* fix: comments

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-08-01 13:18:31 +07:00
NanoCode012
eb0a8a7775 feat: upgrade cce commit to include smollm3, granite, granitemoe (#2993) 2025-07-31 18:18:44 -04:00
salman
294c7fe7a6 Distributed/ND-Parallel (#2977) 2025-07-31 15:25:02 -04:00
Wing Lian
7b68dfafd7 jagged lr restart scheudler (#1680) [skip ci]
* jagged lr restart scheudler

var name fix
make sure to create scheduler first

* wire things together

* more fixes

* fix for nesting scheduler and first anneal phase

* no need for relora trainer anymore since we've generalized the relora scheduler

* remove redundant relora scheduler and lint

* update relora e2e test for updated params

* need restart steps for relora test

* update quarto docs for dropped relora trainer

* update example yaml

* drop verbose arg

* min lr scale support for jagged lr

* don't let min_lr be nonetype

* cleanup args
2025-07-31 13:50:03 -04:00
salman
32a7890231 Revert test update to index.qmd (#2995) [skip ci] 2025-07-31 11:46:31 -04:00
Wing Lian
563f5eed7a update dependencies - liger + trl (#2987)
* update dependencies

* set dataset processes for tests

* add support for GSPO
2025-07-31 11:17:17 -04:00
Wing Lian
6ec282094d actually call the register method on plugins (#2991) [skip ci] 2025-07-31 11:13:15 -04:00
salman
09dda462ab Fix don't preview docs for contributors (#2994) [skip ci]
* checking against fork vs. main repo

* force doc preview
2025-07-31 11:12:41 -04:00
Dan Saunders
bb1cae1a20 CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)
* add --launcher option; explicit True/False bool args; small cleanup

* refactor

* add torchrun, accelerate cli args

* add rdzv arg default + tests

* update _quarto

* coderabbit

* fix

* we can't set rdvz_id independently across nodes

* coderabbit

* fix tests
2025-07-30 15:46:56 -04:00
Wing Lian
22810c97b7 use warmup_ratio as a better default than warmup steps since it's data dependent (#2897) [skip ci]
* use warmup_ratio as a better default than warmup steps since it's data dependent

* replace remainder of warmup_steps
2025-07-30 06:44:06 -04:00
Vincenzo di Cicco
2eb7ff95af Use '<|finetune_right_pad|>' as padding token for LLama4 (#2988) [skip ci] 2025-07-30 06:38:13 -04:00
NanoCode012
90e5598930 Feat: Add voxtral, magistral small 1.1, and misc gemma3n fixes (#2979)
* fix: lock version in gemma3n docs

* feat: add sample configs and docs

* chore: move mistraltokenizer into mistral folder

* feat: update instructions

* feat: add dynamic load voxtral

* fix: remove incorrect vision config, add audio

* fix: support voxtral processing strategy and address none in data

* feat: patch mistraltokenizer subclass upstream and add missing

* feat: update cce commit to include voxtral

* fix: remove old comment

* fix: gemma3 patch not needed anymore

* fix: voxtral modeling code

* fix: remove incorrect ds path

* fix: adjust apply chat template parsing

* feat: enable voxtral patch

* fix: patch

* feat: update example datasets

* fix: target layer

* feat: update gemma3n docs

* feat: update voxtral docs

* feat: revert assistant parsing to rely on new upstream changes

* chore: skip test till next PR fix

* fix: override upstream decode due to missing handling

* feat: update readme

* fix: update

* feat: add magistral small think support

* feat: update mistral-common dep

* fix: lint

* fix: remove optional dep

* chore: typing

* chore: simply import

* feat(doc): update differences for 2507

* fix: coderrabbit comments

* feat: update clarify docs on new transformers
2025-07-30 15:57:05 +07:00
Wing Lian
1d2aa1e467 upgrade to support latest transformers release (#2984)
* upgrade to support latest transformers release

* bump mistral common too

* Fix dependencies
2025-07-27 17:05:12 -04:00
NICOLAS BZRD
430be216d8 add shuffle_before_merging_datasets option to allow independent shuffling of datasets before merging (#2981) [skip ci] 2025-07-27 17:04:56 -04:00
Wing Lian
28804b82e4 don't create a reference model if grpo beta is 0.0 (#2983) [skip ci] 2025-07-27 17:04:42 -04:00
Wing Lian
add3e5076b don't publish to netlify on contributor submissions since it requires auth tokens (#2985) [skip ci]
* don't publish to netlify on contributor submissions since it requires auth tokens

* fix no-tmux build and add contact to motd
2025-07-27 17:04:27 -04:00
NanoCode012
41434f0c28 feat(doc): add all providers to readme (#2972) [skip ci]
* feat(doc): add vastai link

* feat: add cloud providers to readme for more visibility

* add prime intellect, remove Modal as sponsor

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-27 17:03:50 -04:00
Wing Lian
f7ea140838 TiledMLP support for FSDP2 (#2950)
* make TiledMLP work with FSDP

* cleanup/gc at start of train to prevent large VRAM spike

* chore: lint

* generic function for non-deepspeed training

* unify patch to fix imports

* update readme for ALST and add examples

* make deepspeed attribute on params check more robust

* update with new info from PR review
2025-07-25 07:15:03 -04:00
Wing Lian
460e0f9ed9 improve handling of file lock when content is empty (#2959) 2025-07-24 16:10:38 -04:00
Wing Lian
e80faea0db garbage collect on the end of the step if we're going to save a checkpoint (#2971) [skip ci] 2025-07-24 16:10:23 -04:00
Wing Lian
0ff2f172ef Act offload lora fix (#2928) [skip ci]
* fix activation offloading with lora

* update w e2e test

* add docs for error
2025-07-24 16:10:04 -04:00
salman
1407aac779 Skip CI for draft PRs (#2970) 2025-07-24 09:11:46 +01:00
Dan Saunders
b34c3371ed upgrade torchao (#2968) 2025-07-23 10:27:28 -04:00
Wing Lian
5f1a4306b0 don't check dataset labels during preprocess for GRPO (#2952) [skip ci]
* don't check dataset labels during preprocess for GRPO

* use enum check per PR feedback
2025-07-22 20:40:44 -04:00
Wing Lian
93709eb5ce handle refactor upstream for flash attention (#2966) 2025-07-22 20:40:04 -04:00
Dan Saunders
208fb7b8e7 basic torchao fp8 mixed precision training (#2926)
* debug

* debug

* debug

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint

* accelerator constructor patch for single-GPU torch fp8

* lint

* re-using existing fp8 code

* lint

* remove accelerate patch now fix in latest release

* fix

* docs

* add fp8 + fsdp2 example

* remove unused config

* update config

* smoke tests

* add validator

* add 2.7.0 guard for fsdp2

* fix

* add config descriptions

* add FSDP doc link

* nit

* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather

* better cfg for smoke tests

* add test for accelerate patching

* update fp8 validator
2025-07-22 16:27:47 -04:00
Wing Lian
b86a1d47b0 we don't need to call check_dataset_labels when skip_prepare_dataset is set (#2962)
* we don't need to call check_dataset_labels when skip_prepare_dataset is set

* Fix actual bug and revert prior fix

* warn and early return instead of raising an error

* use error
2025-07-22 10:00:53 -04:00
NanoCode012
01d8175d48 fix: revert changing default optimizer to muon (#2965) [skip ci] 2025-07-22 10:00:30 -04:00
NanoCode012
631268a0ca revert renaming of deepspeed stage3 args that use auto (#2964) [skip ci]
* Revert "fix deprecate deepspeed stage3_gather_16bit_weights_on_model_save arg…"

This reverts commit e207762928.

* don't revert the values that don't use 'auto'

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-22 09:59:47 -04:00
Wing Lian
3a208cfd84 Autocomplete axolotl CLI (#2955)
* static autocomplete script for axolotl cli

* use list of commands that should autocomplete yaml files

* make sure to chmod the autocomplete script as executable

* shellcheck and fix autocompletion of directory/sub-dirs

* more shellcheck fixes
2025-07-22 08:30:31 -04:00
github-actions[bot]
7267edc168 chore: update pre-commit hooks (#2954) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-07-22 08:30:00 -04:00
NanoCode012
dfba881e99 Feat: add gemma3n support (#2852)
* feat: add gemma3n cce

* feat: add sample config

* feat: add gemma3n multimodal mode

* feat: add audio example

* feat: support audio and return pixel values in collator

* feat: support unmask only assistant region (gemma3n for now)

* feat(doc): add notes for audio loading

* feat: add audio support for gemma3n

* feat: update examples

* feat: add gemma3n to the docs

* fix: add link at top

* feat(doc): clarify additional requirements

* fix: mllama missing aspect ratio

* fix: mllama need attention fixes for fa2

* Partially Revert "fix: mllama need attention fixes for fa2"

This reverts commit a0bfdd1777.

* fix: disable FA2 for mllama in vision mode

* feat: update configs to use proper attention

* fix: support other vision features

* feat(doc): clarify requirements for gemma3n
2025-07-22 16:52:15 +07:00
Wing Lian
d32058e149 include torchvision in build for upstream changes requiring it now (#2953) [skip ci] 2025-07-22 04:19:16 -04:00
NanoCode012
bc1076d8a2 fix: suppress warning if we enabled skip prepare (#2958) 2025-07-21 11:42:04 -04:00
Wing Lian
b7e8f66e5a upstream fixes in cce for dora and tensor paralel support (#2960) [skip ci] 2025-07-21 11:41:53 -04:00
Wing Lian
e207762928 fix deprecate deepspeed stage3_gather_16bit_weights_on_model_save arg (#2956) [skip ci]
* fix deprecate deepspeed stage3_gather_16bit_weights_on_model_save arg

* replace the rest of the migrated deepspeed params
2025-07-21 11:41:31 -04:00
Wing Lian
fefb0797ee better handling for reward function checks for GRPO (#2933) [skip ci]
* better handling for reward function checks for GRPO

* consolidate msg copy
2025-07-21 11:41:15 -04:00
Wing Lian
af8d257aa2 make pad_to_sequence_len default to the same value as sample_packing (#2941) [skip ci]
* make pad_to_sequence_len default to the same value as sample_packing

* remove duplicate validation

* fix test

* update description meta

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-07-21 11:40:56 -04:00
Wing Lian
db5f6f4693 limit num_proc when saving datasets to disk (#2948) [skip ci]
* limit num_proc when saving datasets to disk

* enforce at least 1 in case it rounds down to 0, and sane divisor is at least 8 rows per worker to save

* update fixtures with dataset processes since that should never be NoneType

* improve reusability for tests
2025-07-21 11:39:38 -04:00
Wing Lian
8e5f146701 Fix cloud docker image build and remove apt files for optim (#2961)
* make sure to apt update to install sudo and tmux

* remove apt archives too
2025-07-21 11:05:00 -04:00
Wing Lian
31a15a49b6 add additional packages via apt for better multi-node support (#2949)
* cleanup in Dockerfile and add infiniband packages

* fixes for ci

* fix nightly too
2025-07-20 21:19:23 -04:00
NanoCode012
b986f7c7cb fix: return proper attention for llama4 lora kernel and fsdp2 llama4 example fix (#2943)
* fix: return proper attention for llama4 lora optim

* fix: update fsdp2 llama4 config
2025-07-19 13:54:43 -04:00
salman
e5734e5cf0 adding torchtitan link (#2945) [skip ci] 2025-07-19 13:54:14 -04:00
Wing Lian
109d9c7442 make the initial call to tokenizer.pad not spam the console (#2946) [skip ci]
* make the initial call to tokenizer.pad not spam the console

* add guard from feedback

* make another common console output less verbose

* more logging fixes
2025-07-19 13:53:35 -04:00
Wing Lian
170322a1f0 make sure log level is upper (#2934) 2025-07-17 15:32:55 -04:00
Wing Lian
5f5ae76213 add validation around cce + chunked_ce (#2932) [skip ci]
* add validation around cce + chunked_ce

* return on end of validation method
2025-07-17 15:32:38 -04:00
Wing Lian
a798975b7c coderabbit manual settings (#2940) [skip ci] 2025-07-17 15:32:16 -04:00
Wing Lian
d23f972602 use state for wandb in callbacks (#2930) [skip ci] 2025-07-17 15:31:56 -04:00
Wing Lian
8e41317250 don't use include_tokens_per_second for GRPO (#2931) [skip ci]
* don't use include_tokens_per_second for GRPO

* use blocklist instead
2025-07-17 15:31:21 -04:00
Varun Gumma
9f2bb188a4 Improve Dataset Processing Multiprocessing, Sharding, and Qwen Tokenizer Bug Fix. (#2918)
* Added a feature to save prepared dataset in specified shards, removed limiter on multiprocessing during tokenization, and a bug fix of qwen tokenizer

* removed limiters and fixed config variable name

* black lint

* chore: lint

* feat: update handling of dataset_processes

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-07-17 09:47:58 -04:00
Wing Lian
9dde9e1b71 misc fixes 202507 (#2937) [skip ci]
* misc fixes 202507

* manually handle attn class for llama4
2025-07-17 09:47:45 -04:00
Wing Lian
f2474ef941 bump accelerate to 1.9.0 (#2936) [skip ci] 2025-07-17 09:46:43 -04:00
Wing Lian
8a4bcacdb2 cu126-torch271 for cloud docker image should be tagged with main-latest (#2935) 2025-07-17 00:01:23 -04:00
Wing Lian
d2c3d5a954 run nightly-vs-upstream-main on 2.7.1 and multi-gpu also (#2929) [skip ci] 2025-07-16 21:45:42 -04:00
Wing Lian
36cbe13d18 activation offloading with cuda streams doesn't work with LoRA (#2927) 2025-07-16 11:59:20 -04:00
Wing Lian
2c408b5c5e Apply generic fused liger ce, cce, and tiledmlp for arbitrary models (#2908)
* Apply generic fused liger ce for unknown models

* fix deepseek liger modeling

* generic cce and config tiled mlp to use original mlp and auto detect compute params

* fix weight and lint

* update warnings

* address PR feedback

* use lookup for model class prefixes

* revert inadvertent change to flash attn verison

* remove un-needed pylint annotations

* fix import
2025-07-15 22:40:41 -04:00
Wing Lian
942005f526 use modal==1.0.2 for nightlies and for cli (#2925) [skip ci]
* use modal==1.0.2 for nightlies and for cli

* use latest cce fork for upstream changes

* increase timeout
2025-07-15 20:31:23 -04:00
Dan Saunders
10ba1622f7 checkpoint model on first step callback (#2906)
* checkpoint model on first step callback

* remove debug

* add test cases; update existing tests not to save on first step

* move test out of solo

* delete

* default to False

* typo
2025-07-15 15:00:48 -04:00
Wing Lian
d320ef6199 fix for upstream refactor of KwargsForCausalLM (#2911) 2025-07-15 11:28:41 -04:00
NanoCode012
354eaaf0d3 feat: add call method to mistral tokenizer wrapper (#2898) 2025-07-14 22:33:35 -04:00
greenhestu
a061446540 Fix: Prevents merging of tool arguments during preprocessing (#2909) 2025-07-14 22:33:10 -04:00
Wing Lian
cd079b5536 Tensor parallel w DeepSpeed AutoTP (#2574)
* support for deepspeed autotup

* bump to latest deepspeed that supports deepcompile too

* add deepcompile support too

* fix total steps calculation for TP

* setup fixture for tp

* update ds config to ensure weights are gathered for checkpoint

* fix duplicate validation names

* chore: lint
2025-07-14 21:33:48 -04:00
Wing Lian
5cc16040a8 move the plugin post trainer create to the setup trainer (#2907)
* move the plugin post trainer create to the setup trainer

* move post-train plugins to execute-training fn
2025-07-14 20:11:33 -04:00
Wing Lian
38359a8997 allow profiling in mid-training rather from the start (#2899) [skip ci]
* allow profiling in mid-training rather from the start

* simplify based on PR feedback

* fix logic, improve saving at end, add tests
2025-07-14 20:11:11 -04:00
Wing Lian
7dc3ac6cb3 update nightlies builds (#2921) [skip ci] 2025-07-14 20:10:43 -04:00
Wing Lian
99187cd208 Activation Offloading w CUDA Streams (#2900) [skip ci]
* use cuda streams for activation offloading

* use torch native ops

* update cfg schema for streams

* fix literal constructor for set

* use context for training step so it doesn't affect evals

* disable streams

* auto gc on eval steps

* use activation_offloading config arg

* add docs for gradient checkpointing

* handle validation for gc/ao

* use cuda streams for act offloading

* add more validation for AC w/o GC

* fix docs

* move activation_offloading lower in definition so it doesn't break args/kwargs

* fix kd due to import order
2025-07-14 20:10:20 -04:00
Wing Lian
aa684122f1 upgrade peft==0.16.0 and datasets==4.0.0 (#2917) [skip ci]
* upgrade peft to 0.16.0

* upgrade datasets to 4.0.0

* refactor dupes from merge/rebase

* fix check for fsdp1 + sharded_state_dict

* use full state dict for ci
2025-07-14 20:09:26 -04:00
Wing Lian
ca4d4ef793 don't init distributed for deepspeed if preprocessing (#2920)
* don't init distributed for deepspeed if preprocessing

* add e2e test to validate preprocess cli with deepspeed

* ignore duplicate code for cfg
2025-07-14 14:19:19 -04:00
Dan Saunders
37edbe4999 Remove extra torch.compile call (#2904)
* debug

* debug

* debug

* moving validation code to transformers

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint
2025-07-14 12:32:45 -04:00
Wing Lian
e581c15d40 refactor dupes from merge/rebase (#2919) [skip ci] 2025-07-14 10:05:26 -04:00
Wing Lian
af92151a7b FSDP2 fix validation and add tests (#2910)
* fix validation and add tests

* remove debugging and add more tests

* remove migrate_fsdp
2025-07-14 09:25:44 -04:00
Wing Lian
80dc4c261a fix xformers version for python 2.6 (#2916) [skip ci] 2025-07-14 09:24:29 -04:00
Wing Lian
7ccbbd8e77 upgrade liger to 0.6.0 (#2893) [skip ci] 2025-07-14 09:24:07 -04:00
Wing Lian
5081db7f8a upgrade trl==0.19.1 (#2892) [skip ci]
* upgrade trl==0.19.1

* add vllm for tests for grpo

* fixes to work with latest trl

* need data_parallel_size config too

* support for vllm_mode for server / colocate

* vllm settings for colocate

* relax vllm version

* bump min hf hub for latest vllm support

* add hints on string literal for vllm mode

* use latest transformers 4.53.2

* tweak acceptable loss on flaky test_ds_zero3_packed test

* don't run flaky vllm/grpo tests for now
2025-07-14 09:23:42 -04:00
Wing Lian
41664c7c4c fix ddp for incorrect steps (#2915)
* fix ddp for incorrect steps

* add test
2025-07-14 07:51:16 -04:00
376 changed files with 7444 additions and 2967 deletions

41
.axolotl-complete.bash Normal file
View File

@@ -0,0 +1,41 @@
#!/bin/bash
_axolotl_completions() {
local cur prev
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
# If we're completing the first argument (the command)
if [[ $COMP_CWORD -eq 1 ]]; then
mapfile -t COMPREPLY < <(compgen -W "delinearize-llama4 fetch lm-eval merge-sharded-fsdp-weights quantize vllm-serve evaluate inference merge-lora preprocess train" -- "$cur")
return 0
fi
# Commands that should complete with directories and YAML files
local -a yaml_commands=("merge-sharded-fsdp-weights" "quantize" "vllm-serve" "evaluate" "inference" "merge-lora" "preprocess" "train")
# Check if previous word is in our list
if [[ " ${yaml_commands[*]} " =~ (^|[[:space:]])$prev($|[[:space:]]) ]]; then
# Use filename completion which handles directories properly
compopt -o filenames
mapfile -t COMPREPLY < <(compgen -f -- "$cur")
# Filter to only include directories and YAML files
local -a filtered=()
for item in "${COMPREPLY[@]}"; do
if [[ -d "$item" ]] || [[ "$item" == *.yaml ]] || [[ "$item" == *.yml ]]; then
filtered+=("$item")
fi
done
COMPREPLY=("${filtered[@]}")
return 0
fi
# Default: no completion
return 0
}
# Remove the -o nospace option - let filenames handle it
complete -F _axolotl_completions axolotl

16
.coderabbit.yaml Normal file
View File

@@ -0,0 +1,16 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
review_status: true
collapse_walkthrough: true
poem: false
sequence_diagrams: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true

View File

@@ -17,7 +17,7 @@ on:
jobs:
build-base:
if: github.repository_owner == 'axolotl-ai-cloud'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
@@ -108,7 +108,7 @@ jobs:
PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
build-base-uv:
if: github.repository_owner == 'axolotl-ai-cloud'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
strategy:

View File

@@ -3,6 +3,7 @@ on:
# check on PRs, and manual triggers
merge_group:
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
@@ -16,6 +17,7 @@ jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5

View File

@@ -87,7 +87,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -98,6 +97,7 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

View File

@@ -21,7 +21,7 @@ concurrency:
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
strategy:
fail-fast: false
matrix:
@@ -36,10 +36,17 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -12,11 +12,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -60,15 +65,15 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -2,7 +2,7 @@ name: Preview
on:
workflow_dispatch:
pull_request:
types: [opened, synchronize, reopened]
types: [opened, synchronize, reopened, ready_for_review]
# Run the workflow only when one of these files changes
paths:
@@ -25,6 +25,7 @@ permissions:
jobs:
preview:
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- name: Check out repository
uses: actions/checkout@v4
@@ -52,6 +53,7 @@ jobs:
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
id: netlify
with:
publish-dir: './_site'

View File

@@ -52,7 +52,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Update requirements.txt
run: |
@@ -92,7 +92,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
timeout-minutes: 120
needs: [pre-commit, pytest]
strategy:
@@ -106,6 +106,13 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -116,7 +123,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -130,3 +137,45 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, docker-e2e-tests]
strategy:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 2
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

View File

@@ -13,6 +13,7 @@ on:
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
@@ -34,6 +35,7 @@ jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
@@ -47,6 +49,7 @@ jobs:
pytest:
name: PyTest
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
# needs: [preload-cache]
strategy:
fail-fast: false
@@ -78,7 +81,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
@@ -121,6 +124,7 @@ jobs:
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
strategy:
fail-fast: false
matrix:
@@ -151,7 +155,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
@@ -185,7 +189,7 @@ jobs:
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
@@ -235,7 +239,7 @@ jobs:
modal run cicd.e2e_tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
@@ -289,6 +293,7 @@ jobs:
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [docker-e2e-tests]
if: ${{ !github.event.pull_request.draft }}
strategy:
fail-fast: false

View File

@@ -27,7 +27,7 @@ repos:
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.1
rev: v1.17.0
hooks:
- id: mypy
additional_dependencies:

View File

@@ -119,14 +119,15 @@ datasets:
## Dataset Processing
| Option | Default | Description |
| ----------------------------- | -------------------------- | --------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
| Option | Default | Description |
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
## LoRA Configuration

View File

@@ -25,6 +25,8 @@
## 🎉 Latest Updates
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
@@ -79,6 +81,20 @@ docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
#### Cloud Providers
<details>
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=github&utm_medium=developer_community&utm_campaign=template_launch_axolotl&utm_content=readme)
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
- [Novita](https://novita.ai/gpus-console?templateId=311)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
</details>
### Your First Fine-tune
```bash
@@ -120,12 +136,6 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
## ❤️ Sponsors
Thank you to our sponsors who help make Axolotl possible:
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📜 License

View File

@@ -35,25 +35,30 @@ quartodoc:
- cli.train
- cli.evaluate
- cli.args
- cli.art
- cli.checks
- cli.config
- cli.delinearize_llama4
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.quantize
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- cli.utils
- cli.utils.args
- cli.utils.fetch
- cli.utils.load
- cli.utils.sweeps
- cli.utils.train
- title: Trainers
desc: Training implementations
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.mamba
- core.trainers.relora
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
@@ -268,6 +273,7 @@ website:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- section: "Advanced Features"
contents:
@@ -276,6 +282,8 @@ website:
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
pytest -v --durations=10 -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
@@ -19,5 +19,7 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
# Upload coverage to Codecov if CODECOV_TOKEN is available
if [ -n "$CODECOV_TOKEN" ]; then
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
fi

View File

@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -22,6 +22,7 @@ coverage:
only_pulls: true
flags: null
paths: null
informational: true
patch:
default:
# basic

View File

@@ -7,9 +7,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": "auto"

View File

@@ -7,9 +7,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true

View File

@@ -17,9 +17,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true

View File

@@ -13,9 +13,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true

View File

@@ -10,7 +10,9 @@ ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
@@ -23,17 +25,17 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
fi && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
git config --get remote.origin.fetch && \
git config --global credential.helper store
# helper for huggingface-login cli
RUN git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -16,12 +16,19 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y --no-install-recommends \
wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
ibverbs-providers ibverbs-utils infiniband-diags \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
@@ -31,12 +38,14 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \

View File

@@ -22,18 +22,22 @@ RUN apt-get update \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge

View File

@@ -14,7 +14,10 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -9,13 +9,15 @@ ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \

View File

@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.
### Launcher Arguments
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
```bash
# Pass torchrun arguments
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
# Pass accelerate arguments
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
```
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
## Command Reference
### fetch
@@ -80,7 +94,11 @@ axolotl train config.yml \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --no-accelerate
axolotl train config.yml --launcher python
# Pass launcher-specific arguments using -- separator
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
```bash
# Basic evaluation
axolotl evaluate config.yml
# Evaluation with launcher arguments
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
```
### lm-eval
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```

View File

@@ -136,3 +136,7 @@ description: Frequently asked questions
> dynamic: false
> mode: max-autotune-no-cudagraphs
> ```
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.

View File

@@ -0,0 +1,29 @@
---
title: Gradient Checkpointing and Activation Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
models by reducing the memory footprint and improving computational efficiency.
### Enabling Gradient Checkpointing
```yaml
gradient_checkpointing: true
```
### Enabling Activation Offloading
```yaml
gradient_checkpointing: true # required for activation offloading
activation_offloading: true
```
Activation offloading variants:
The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
to overlap the communications and computations when offloading.
The `activation_offloading: legacy` naively offloads activations to CPU and without additional optimizations.
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.

View File

@@ -124,10 +124,13 @@ For providers supporting Docker:
- Use `axolotlai/axolotl-cloud:main-latest`
- Available on:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Novita](https://novita.ai/gpus-console?templateId=311)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
- [Novita](https://novita.ai/gpus-console?templateId=311)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
### Google Colab {#sec-colab}

149
docs/mixed_precision.qmd Normal file
View File

@@ -0,0 +1,149 @@
---
title: "Mixed Precision Training"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats:
- **FP16** - Half precision 16-bit (Pascal generation+)
- **BF16** - Brain Float 16-bit (Ampere generation+)
- **FP8** - 8-bit floating point (Hopper generation+)
## FP16 Mixed Precision {#sec-fp16}
### Overview {#sec-fp16-overview}
FP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16.
### Configuration {#sec-fp16-config}
```{.yaml}
fp16: true
```
### FP16 Considerations {#sec-fp16-considerations}
- May require gradient scaling to prevent underflow
- Less numerically stable than BF16
- Can cause training instability with some model architectures
- Consider using BF16 if your hardware supports it
## BF16 Mixed Precision {#sec-bf16}
### Overview {#sec-bf16-overview}
BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory.
### Configuration {#sec-bf16-config}
```{.yaml}
# Automatic BF16 detection (recommended)
bf16: auto
# Or explicitly enable
bf16: true
# For evaluation with BF16
bf16: full # Equivalent to bf16_full_eval in the HF trainer
```
## FP8 Mixed Precision {#sec-fp8}
::: {.callout-note}
FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO.
:::
### What is FP8? {#sec-fp8-overview}
FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with "tensorwise" scaling strategy.
### Requirements {#sec-fp8-software}
- Hopper+ GPUs (H100/H200)
- PyTorch 2.7+ (+ compatible TorchAO version)
- CUDA 12.4+
### Configuration {#sec-fp8-config}
Add to your YAML config:
```{.yaml}
# Enable FP8 mixed precision
fp8: true
# Optional: Enable FP8 for FSDP all-gather operations
fp8_enable_fsdp_float8_all_gather: true
# Enable torch.compile (almost always necessary for FP8 speedups)
torch_compile: true
```
::: {.callout-important}
**torch.compile is critical for FP8 performance**
FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16.
:::
### Advanced FP8 Configs {#sec-fp8-advanced}
For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training:
```{.yaml}
fp8: true
fp8_enable_fsdp_float8_all_gather: true
torch_compile: true
# FSDP configuration
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
```
## Best Practices {#sec-best-practices}
### Choosing Precision Format {#sec-choosing-format}
- **Start with automatic detection**: `bf16: auto`
- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed
- **For Ampere (A100/RTX 30/40)**: Use BF16
- **For older Pascal/Turing GPUs**: Use FP16 with caution
- **For very old or unsupported GPUs**: Use FP32
### Validation and Testing {#sec-validation}
Always validate your mixed precision setup:
- **Start with a small dataset** to verify stability
- **Monitor loss curves** for irregularities
- **Compare with FP32 baseline** when possible
- **Test evaluation metrics** match expectations
### FP8 Particulars {#sec-fp8-details}
- Use cases
- Single GPU training
- Multi GPU training with FSDP2 or Deepspeed
- Speedups
- Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings
- Concrete number for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks)
- Known issues:
- FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4))
- FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the BF16 equivalent training
- Flash Attention 2 does not play nicely with `torch.compile`
See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model
For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd).

View File

@@ -98,8 +98,8 @@ fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
For example, if you were using the following FSDP1 config:
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config:
```{.yaml}
fsdp_version: 1

View File

@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152
Run the following on each node:
### Option 1: New Axolotl CLI with launcher args (Recommended)
```bash
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
```
### Option 2: Direct torchrun (Legacy)
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables.
Please make sure to substitute the placeholder variables:
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

View File

@@ -14,6 +14,7 @@ format:
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
@@ -110,6 +111,22 @@ base_model: google/gemma-3-4b-it
chat_template: gemma3
```
### Gemma-3n {#sec-gemma-3n}
::: {.callout-warning}
The model's initial loss and grad norm will be very high. We suspect this to be due to the Conv in the vision layers.
:::
::: {.callout-tip}
Please make sure to install `timm` via `pip3 install timm==1.0.17`
:::
```yaml
base_model: google/gemma-3n-E2B-it
chat_template: gemma3n
```
### Qwen2-VL {#sec-qwen2-vl}
```yaml
@@ -132,7 +149,9 @@ For multi-modal datasets, we adopt an extended `chat_template` format similar to
- A message is a list of `role` and `content`.
- `role` can be `system`, `user`, `assistant`, etc.
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
- `content` is a list of `type` and (`text`, `image`, `path`, `url`, `base64`, or `audio`).
### Image
::: {.callout-note}
For backwards compatibility:
@@ -141,15 +160,29 @@ For backwards compatibility:
- If `content` is a string, it will be converted to a list with `type` as `text`.
:::
::: {.callout-tip}
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
- `"path": "/path/to/image.jpg"`
- `"url": "https://example.com/image.jpg"`
- `"base64": "..."`
- `"image": PIL.Image`
### Audio
For audio loading, you can use the following keys within `content` alongside `"type": "audio"`:
- `"path": "/path/to/audio.mp3"`
- `"url": "https://example.com/audio.mp3"`
- `"audio": np.ndarray`
::: {.callout-tip}
You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::
### Example
Here is an example of a multi-modal dataset:
```json
[
@@ -178,3 +211,9 @@ Here is an example of a multi-modal dataset:
}
]
```
## FAQ
1. `PIL.UnidentifiedImageError: cannot identify image file ...`
`PIL` could not retrieve the file at `url` using `requests`. Please check for typo. One alternative reason is that the request is blocked by the server.

102
docs/nd_parallelism.qmd Normal file
View File

@@ -0,0 +1,102 @@
# N-D Parallelism
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
- A model's weights are too large to fit on a single GPU's memory.
- A model's activations, especially with very long contexts, are too large for a single GPU.
- You want to accelerate training by using multiple GPUs or nodes.
or combinations of the above!
## Core Concepts
Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
### Data Parallelism {#sec-dp}
Data Parallelism focuses on splitting the global data batch across GPUs.
- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.
- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`).
- FSDP maps to ZeRO stages:
- ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU.
- ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
### [Experimental] Tensor Parallelism (TP) {#sec-tp}
Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.
- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info.
- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
### Context Parallelism (CP) {#sec-cp}
Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.
- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.
- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this.
- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step.
### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}
HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.
- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.
- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).
- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes).
## Usage
```yaml
# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
fsdp_version: 2
fsdp_config:
# ...
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 4
# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 2
# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)
# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)
```
Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.
## Examples
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
- You want FSDP within each node and DDP across nodes.
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
2. FSDP + TP on a single 8-GPU node:
- You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
- Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.
3. FSDP + CP on a single 8-GPU node for long context:
- You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
- Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.
## Support Matrix
This matrix describes how different parallelism methods can be combined in Axolotl.
| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |
| --- | :---: | :---: |:---:|:---:|---|
| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
- `tp_size` refers to `tensor_parallel_size`
- `cp_size` refers to `context_parallel_size`

View File

@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
context_parallel_size: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
@@ -66,7 +66,7 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

9
examples/alst/README.md Normal file
View File

@@ -0,0 +1,9 @@
# Arctic Long Sequence Training (ALST)
Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
techniques. It is a combination of:
- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage
- Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).

View File

@@ -0,0 +1,53 @@
base_model: meta-llama/Llama-3.1-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: togethercomputer/Long-Data-Collections
type: completion
field: text
data_files:
- pretrain/rp_sub.jsonl.zst
- path: princeton-nlp/TextbookChapters
type: completion
field: chapter
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 500_000
min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1
evals_per_epoch: 2
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,59 @@
base_model: meta-llama/Llama-3.1-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: togethercomputer/Long-Data-Collections
type: completion
field: text
data_files:
- pretrain/rp_sub.jsonl.zst
- path: princeton-nlp/TextbookChapters
type: completion
field: chapter
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 500_000
min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1
evals_per_epoch: 2
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false # offloading is currently not compatible with SP + torchao optimizer
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -66,7 +66,7 @@ flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 32
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
save_total_limit:

View File

@@ -43,7 +43,7 @@ xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.1

View File

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -47,7 +47,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
@@ -48,7 +48,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -47,7 +47,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
@@ -48,7 +48,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -47,7 +47,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
@@ -48,7 +48,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -54,7 +54,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1

View File

@@ -57,7 +57,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1

View File

@@ -41,7 +41,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1

View File

@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -51,7 +51,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -47,7 +47,7 @@ xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -77,7 +77,7 @@ xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.000001

View File

@@ -44,7 +44,7 @@ xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -25,7 +25,7 @@ lora_target_linear: true
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:

View File

@@ -40,7 +40,7 @@ xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.1

View File

@@ -41,7 +41,7 @@ xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.1

View File

@@ -42,7 +42,7 @@ logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0001

View File

@@ -42,7 +42,7 @@ logging_steps: 1
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.1

View File

@@ -50,7 +50,7 @@ logging_steps: 1
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.1

View File

@@ -43,7 +43,7 @@ logging_steps: 1
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.1

View File

@@ -49,7 +49,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -49,7 +49,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -45,7 +45,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -48,7 +48,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -43,7 +43,7 @@ logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0001

View File

@@ -41,7 +41,7 @@ logging_steps: 1
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0

View File

@@ -16,7 +16,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
@@ -50,7 +50,7 @@ flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1

View File

@@ -19,7 +19,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -51,7 +51,7 @@ flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -19,7 +19,7 @@ lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
@@ -48,7 +48,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 20
warmup_ratio: 0.1
evals_per_epoch: 4
eval_steps:
saves_per_epoch: 4

View File

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
adapter: lora
@@ -49,7 +49,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: false
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -47,7 +47,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -38,7 +38,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -21,7 +21,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
@@ -49,7 +49,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -75,7 +75,7 @@ xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -20,7 +20,7 @@ special_tokens:
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
warmup_steps: 10
warmup_ratio: 0.1
# Iterations
num_epochs: 1

View File

@@ -27,7 +27,7 @@ lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -35,7 +35,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -56,3 +55,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
]
},
{

View File

@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -51,8 +51,10 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -51,8 +51,10 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -12,7 +12,7 @@ output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -37,7 +37,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
@@ -55,3 +55,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -30,7 +30,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -61,7 +61,7 @@ resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
@@ -79,3 +79,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -25,7 +25,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -46,7 +46,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -69,3 +68,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -31,7 +31,7 @@ lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -60,3 +60,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -18,7 +18,7 @@ remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -50,3 +50,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -35,7 +35,7 @@ lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
@@ -66,3 +66,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -25,7 +25,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
@@ -60,3 +60,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -62,3 +62,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,65 @@
# Finetune Gemma-3n with Axolotl
Gemma-3n is a family of multimodal models from Google found on [HuggingFace](https://huggingface.co/collections/google/gemma-3n-685065323f5984ef315c93f4). This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:
```bash
pip3 install timm==1.0.17
# for loading audio data
pip3 install librosa==0.11.0
```
3. Run the finetuning example:
```bash
# text only
axolotl train examples/gemma3n/gemma-3n-e2b-qlora.yml
# text + vision
axolotl train examples/gemma3n/gemma-3n-e2b-vision-qlora.yml
# text + vision + audio
axolotl train examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml
```
Let us know how it goes. Happy finetuning! 🚀
WARNING: The loss and grad norm will be much higher than normal. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
### TIPS
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The multimodal dataset format follows the OpenAI multi-content Messages format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Gemma 3n Blog](https://ai.google.dev/gemma/docs/gemma-3n)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,74 @@
base_model: google/gemma-3n-E2B-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
load_in_8bit: false
load_in_4bit: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
split: train[:1%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# lora_target_linear: # Does not work with gemma3n currently
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,78 @@
base_model: google/gemma-3n-E2B-it
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
# sample dataset below requires downloading audio/image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
datasets:
- path: Nanobit/text-vision-audio-2k-test
type: chat_template
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,75 @@
base_model: google/gemma-3n-E2B-it
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

Some files were not shown because too many files have changed in this diff Show More