Compare commits

...

42 Commits

Author SHA1 Message Date
Wing Lian
b7ec06b8a1 Add optional Axolotl MoRA/ReMoRA integration (#3647) [skip ci]
* Add optional Axolotl MoRA/ReMoRA integration

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Isolate MoRA adapter behavior in plugin

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Constrain MoRA variants to supported enum values

* Keep MoRA validation out of core config

---------

Co-authored-by: Swarm <swarm@localhost>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-12 07:19:55 -04:00
Wing Lian
e2f01de0e8 Fix Axolotl ReLoRA optimizer reset scope (#3646)
* Fix Axolotl ReLoRA optimizer reset scope
* fix: make relora reset method honor relora_prune_ratio

When relora_prune_method='reset' and relora_prune_ratio is explicitly
set, the ratio was silently ignored and replaced with the hardcoded
_FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to
ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset
and 0.9 for other methods. reset_optimizer now uses the same random
pruning path for both 'random' and 'reset'.

Also consolidate three-layer default mismatch: schema default for
relora_prune_method is now 'magnitude' (single canonical source);
dataclass defaults for both fields changed to None to eliminate the
conflicting fallback layer.

Tests updated: removed the test case that verified the old broken
behavior (reset ignoring ratio), added two cases proving reset honors
the passed ratio. E2E reset fixture now uses ratio=0.5 to make it
unambiguous that the ratio is honored.

* Fix ReLoRA uint8 pruning regression

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-09 17:52:35 -04:00
thad0ctor
5352d41d32 feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries` (#3625)
* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs+types: address CodeRabbit nitpicks on PR #7

- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mm-mask): address two CodeRabbit findings on PR #7

1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* doc cleanup

* fix(mm-mask): CodeRabbit findings + lint fix on PR #3625

Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
  under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.
2026-05-05 11:25:39 -04:00
VED
c15f6cffe2 fix: FSDP FULL_STATE_DICT oom from memory leak (#3635)
* memory clean up for fsdp full state dict

* Update src/axolotl/monkeypatch/accelerate/fsdp2.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2026-05-05 11:22:35 -04:00
Wing Lian
e4032fc90f Refactor separate attention flags with attn_implementation and capability/concerns feature flags (#3602)
* upgrade to torchao 0.17.0

* chore: lint

* refactor attention handling

* replace legacy attention boolean flags with capability properties

Replace checks with capability-based properties derived from attn_implementation

This separates three concerns that were conflated under flash_attention:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property

* compute attn capability flags in normalizer instead of properties

* make attn_implementation the single source of truth

* move attention-dependent validators to mode=after

* migrate remaining consumers to canonical attn_implementation

* expand attention tests + rewrite docs

* migrate example configs to canonical attn_implementation

* update doc snippets + reject gemma4-hybrid with non-FA2 backend

* remove dead gemma4 branch in _set_attention_config

* fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests

* drop "Phase 2" naming from attn-implementation tests

* regroup attn_implementation tests by feature concern

* clean up verbose comments and remove MD

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x

In transformers 5.x, ProcessorMixin.apply_chat_template gained its own
`return_dict` parameter (defaulting to False).  When return_dict=False
and tokenize=True the method returns out["input_ids"] directly — a 2-D
tensor — rather than the full BatchFeature dict.

The old code placed `return_dict=True` inside processor_kwargs.  In
transformers 5.x those kwargs are forwarded to the underlying processor
call self(...) where _merge_kwargs silently ignores any key not present
in MllamaProcessorKwargs (emitting a warning).  The outer return_dict
therefore stayed False, apply_chat_template returned the raw input_ids
tensor, and the subsequent `batch["input_ids"]` attempted to index a
2-D tensor with the 9-character string "input_ids", producing:

  IndexError: too many indices for tensor of dimension 2

The fix is to pass return_dict=True as a top-level keyword argument to
apply_chat_template (where it is actually consumed) and remove it from
processor_kwargs (where it was silently dropped).  No version guard is
needed: transformers is pinned to ==5.5.4 in pyproject.toml.

Adds a unit-level regression test (tests/test_mm_chat_collator.py) that
mocks the processor to return a raw tensor when apply_chat_template is
called without top-level return_dict=True, verifying the four invariants:
process_rows returns a dict, input_ids is 2-D, labels is 2-D, and
apply_chat_template receives return_dict=True as a top-level kwarg.

Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset
Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): process_rows returns dict (BatchFeature) shape

Two related changes for the multimodal chat collator under transformers 5.x:

1. Wrap apply_chat_template result in dict(...) so process_rows returns
   a plain dict rather than a BatchFeature instance. BatchFeature is a
   Mapping but not a dict; downstream code that did
     batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
   would index on a tensor when the result wasn't dict-shaped, raising
     IndexError: too many indices for tensor of dimension 2

2. Soften the regression test's contract from `dict` to `Mapping` so it
   exercises the actual semantic guarantee (key/value access) rather
   than the implementation detail (dict vs BatchFeature). Test guards
   against the original transformers 5.x breakage where apply_chat_template's
   return_dict default went from True to False.

Includes regression test under tests/test_mm_chat_collator.py.

Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against
attn-implementation-refactor; squash-merged from agent commits 4de886fd
+ dc9fcf4f.

Signed-off-by: Wing Lian <wing@axolotl.ai>

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-05 10:15:18 -04:00
Younes B
6136ae627b Fix: add bitnet config (#3636)
* add bitnet config

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-30 12:30:56 -04:00
Younes B
e662972a29 Feat: Add bitnet integration (#3634)
* add bitnet

* switch to uv

* chore: liint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-30 11:25:02 -04:00
NanoCode012
ebbd7fa847 feat: Add Mistral Medium 3.5 (#3633)
* fix: clarify incompat

* fix: transformers api change upstream

* fix: add pre prop

* feat: add examples

* chore: cleanup

* chore: update readme
2026-04-29 22:46:51 +07:00
Wing Lian
ac77da96da use smaller pretrained models for ci (#3620) [skip ci]
* use smaller pretrained models for ci

* more steps for loss check

* fix tests

* more train steps

* fix losses
2026-04-27 13:22:56 -04:00
NanoCode012
798c8fba89 chore: update docker docs (#3623)
Some checks failed
Publish Docs / build-deploy (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / pre-commit (push) Has been cancelled
Tests Nightly against upstream main / Prefetch S3 once to prime the CDN cache (push) Has been cancelled
Tests Nightly against upstream main / PyTest (3.12, 2.10.0) (push) Has been cancelled
Tests Nightly against upstream main / PyTest (3.12, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-tests (<nil>, 128, 12.8.1, 1, 3.11, 2.10.0) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-tests (<nil>, 128, 12.8.1, true, 1, 3.11, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-tests (<nil>, 130, 13.0.0, true, 1, 3.12, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-multigpu-tests (<nil>, 128, 12.8.1, true, 2, 3.11, 2.9.1) (push) Has been cancelled
docker-nightlies / build-axolotl (<nil>, 128, 12.8.1, 3.11, 2.9.1) (push) Has been cancelled
docker-nightlies / build-axolotl-cloud (<nil>, 128, 12.8.1, 3.11, 2.9.1) (push) Has been cancelled
docker-multigpu-tests-biweekly / test-axolotl-multigpu (<nil>, 130, 13.0.0, 2, 3.11, 2.9.1) (push) Has been cancelled
docker-multigpu-tests-biweekly / test-axolotl-multigpu (fbgemm-gpu, 128, 12.8.1, 2, 3.11, 2.10.0) (push) Has been cancelled
Pre-commit auto-update / auto-update (push) Has been cancelled
2026-04-24 16:03:21 +07:00
NanoCode012
17fc747f99 fix: docker build failing (#3622)
* fix: uv leftover docs

* fix: docker build failing

* chore: doc

* fix: remove old pytorch build

* fix: stop recommend flash-attn optional, let transformers pull

* fix: remove ring flash attention from image

* fix: quotes [skip ci]

* chore: naming [skip ci]
2026-04-24 14:23:09 +07:00
Wing Lian
901f2356bc dpo collation/padding (#3601) [skip ci]
* fix dpo collation/padding

* fix DPO collator encoder-decoder pixel_values dtype and is_encoder_decoder detection

- Use float32 instead of LongTensor for _pixel_values in encoder-decoder branch
- Add missing padding_value case for _pixel_values in encoder-decoder branch
- Derive is_encoder_decoder from model config instead of hardcoding False
2026-04-23 14:49:52 -04:00
thad0ctor
1bf65c500e feat: add processor_kwargs YAML field forwarded to from_pretrained (#3612) 2026-04-23 00:26:34 -04:00
brightwind26
bcbe049c21 Feat: add support for datasets with str saved messages field (#3607)
* feat: support datasets saved in str format

* add also str for tools

* format

* fix: address comments + add unit test

* format
2026-04-23 00:25:48 -04:00
Andrew Wu
90090fa9e8 DPO support loss types (#3566)
* Support loss_type/loss_weights DPO

* Validate dpo loss type/weights only set for dpo

* Tests: Update ipo tests to use new path

* Docs: Update docs for new ipo path

* PR fixes - typo/validation

* PR nit - warning

* chore: fix warnings arg

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-04-23 00:25:28 -04:00
Wing Lian
7420fd4de6 fix async prefetch with nemogym (#3606) 2026-04-22 09:05:46 -04:00
Wing Lian
05113bc91a train on remote compute using Tinker compatible APIs (#3614)
* train on remote compute using Tinker compatible APIs

* chore: lint

* fixes with latest hatchery changes

* chore: lint
2026-04-22 01:14:41 -04:00
thad0ctor
e562e149ce fix: [gemma4] fix VRAM leak in hybrid FA2+SDPA (hybrid attentiuon) path under activation check… (#3611)
* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing

Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.

HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.

Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.

Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.

Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
  * threading.local() store with _get/_set_shared_kv_states helpers
  * _patch_decoder_layer_call(): monkeypatches
    Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
    and stash it in TLS before delegating to GradientCheckpointingLayer.
    The partial formed downstream no longer references the dict.
  * fused_forward reads TLS first, falls back to kwarg for callers that
    bypass the patched __call__ (e.g. direct attention invocation).
  * wired into patch_gemma4_fused_attn; idempotent via a sentinel.

TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.

Introduced by PR #3598 (commit b8358aa5).

* Coderabbit comment: gemma4: clear TLS unconditionally in decoder-layer patched __call__

  Overwrite the thread-local shared_kv_states store on every invocation
  (including with None) instead of only when the kwarg is present.

  The previous conditional write left stale dicts in TLS on any path that
  reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
  kwarg — e.g. generation, eval hooks, or future HF refactors that make
  the kwarg optional. fused_forward would then silently consume a prior
  step's K/V dict instead of falling back to its own kwarg path.

  Unconditional write makes the invariant in the surrounding comment
  ("TLS is overwritten on each new step's first decoder-layer call, so
  the previous step's dict is released promptly") actually hold.

  No behavior change for the training happy path, which always passes
  the kwarg. Addresses CodeRabbit review on PR #3611

* fix: swap threading.local() for module-level store so autograd worker   threads see shared_kv_states during backward recompute

Previous commits fixed memory leak on 31B but caused type error with MOE Gemma4 variants - this fixes that:

PR 3611's TLS variant only works when recompute runs on the same thread
  that set TLS during forward. PyTorch's C++ autograd engine
  (_engine_run_backward) spawns per-device worker threads to dispatch
  backward, and HF-Trainer gradient_checkpointing (stock
  torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
  unpack_hook -> recompute_fn on those worker threads. TLS set on the main
  thread during forward is invisible there, so _get_shared_kv_states()
  returns None and the consumer-layer lookup crashes with
  "'NoneType' object is not subscriptable" at
  fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).

  A plain module-level dict is visible to all threads in the process.
  Lifecycle is identical: the slot is overwritten each forward, releasing
  the previous step's dict and allowing its K/V tensors to be GC'd, so
  the original VRAM-leak fix still holds under FSDP2 AC too.

* scope gemma4 shared_kv_states side channel to checkpointed training

Update PR #3611 with gate for checkpointed training to avoid regressions across async flows.

Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments

* add gemma4 cross-thread visibility test for shared_kv_states store

Additional regression test for MoE gemma4 variants - asserts the module-level store is readable from threads other than the one that set it in response to previously observed 'NoneType' error

* fix logger

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-21 17:49:58 -04:00
NanoCode012
9de5b76336 feat: move to uv first (#3545)
* feat: move to uv first

* fix: update doc to uv first

* fix: merge dev/tests into uv pyproject

* fix: update docker docs to match current config

* fix: migrate examples to readme

* fix: add llmcompressor to conflict

* feat: rec uv sync with lockfile for dev/ci

* fix: update docker docs to clarify how to use uv images

* chore: docs

* fix: use system python, no venv

* fix: set backend cpu

* fix: only set for installing pytorch step

* fix: remove unsloth kernel and installs

* fix: remove U in tests

* fix: set backend in deps too

* chore: test

* chore: comments

* fix: attempt to lock torch

* fix: workaround torch cuda and not upgraded

* fix: forgot to push

* fix: missed source

* fix: nightly upstream loralinear config

* fix: nightly phi3 long rope not work

* fix: forgot commit

* fix: test phi3 template change

* fix: no more requirements

* fix: carry over changes from new requirements to pyproject

* chore: remove lockfile per discussion

* fix: set match-runtime

* fix: remove unneeded hf hub buildtime

* fix: duplicate cache delete on nightly

* fix: torchvision being overridden

* fix: migrate to uv images

* fix: leftover from merge

* fix: simplify base readme

* fix: update assertion message to be clearer

* chore: docs

* fix: change fallback for cicd script

* fix: match against main exactly

* fix: peft 0.19.1 change

* fix: e2e test

* fix: ci

* fix: e2e test
2026-04-21 10:16:03 -04:00
Wing Lian
323da791eb bump transformers to 5.5.4 and trl to latest 1.1.0 (#3603)
* bump transformers to 5.5.4 and trl to latest 1.1.0

* more upgrades

* update peft too

* adapt lora_merge to peft 0.19 layer config API

PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv
layer constructors and moved use_rslora, use_dora, fan_in_fan_out,
lora_dropout, and lora_bias into that config. Build the config
per branch in _build_peft_layer_and_get_delta so the merge utility
works with the upgraded peft.

* allow lora_dropout on mixed attention+MoE configs under peft 0.19

PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D
einsum can't factor dropout out of lora_B(lora_A(dropout(x))).

Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
2026-04-15 09:27:03 -04:00
NanoCode012
6990478163 fix: rename model to adapter_model for fsdp sharded final model (#3585)
* fix: rename model to adapter_model for fsdp sharded final model

* fix: follow upstream transformer shard size

* fix: handle multiple model files

* fix redundant condition, tighten to safetensors, keep shard size small

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:30 -04:00
ゆり
63a58cfec1 feat: support excess_length_strategy for RL trainers (#3578) [skip ci]
* feat: support excess_length_strategy for RL trainers

Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.

- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
  within sequence_len while preserving the full prompt, then decodes
  back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len

Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:10 -04:00
madScientist10
3985ec2f67 feat: add FineGrainedFP8Config support for model quantization (#3587) [skip ci]
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with
FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning.

Made-with: Cursor
2026-04-12 20:50:37 -04:00
Joaquin Hui
a44edda6d7 Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]
* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:50:15 -04:00
Wing Lian
66c3e5a3fd better handling of dora merge on Conv layers in Qwen 3.5 (#3599)
* better handling of dora merge on Conv layers in Qwen 3.5

* address issues from code review

* stricter efficient merges for dora since we now have meta model to reference
2026-04-12 10:57:45 -04:00
Wing Lian
b8358aa5ab [gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598) 2026-04-12 10:29:55 -04:00
Joaquin Hui
e079cf16a2 qwen3_5.jinja: handle list content on system messages (#3595) [skip ci]
* qwen3_5.jinja: handle list content on system messages

The system message branch used string concatenation on
messages[0].content, which breaks when the first system message uses
the OpenAI-style list-of-parts format that multimodal datasets require.
User and assistant branches already handle both string and list content,
but the system branch did not.

Check whether content is a string and fall back to iterating over parts
when it is a list, matching the pattern used for user messages.

Fixes #3590

* Address pr for other content types

---------

Co-authored-by: Joaquin Hui Gomez <joaquinhuigomez@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 00:58:58 -04:00
Wing Lian
e2f69828d2 [fix][fsdp2] clone sharded param so original full size shard can be gc'ed (#3597) [skip ci] 2026-04-11 20:22:35 -04:00
Wing Lian
122b50bad6 pre-cache the eot token ids rather than on each iteration (#3594) [skip ci] 2026-04-11 20:05:21 -04:00
Wing Lian
e77a185e86 upgrade transformers to use v5.5.3 (#3593) 2026-04-10 17:08:14 -04:00
Wing Lian
29fa4dedbb Gemma4 fixes and profiler (#3591) 2026-04-10 16:46:17 -04:00
Wing Lian
315cdeede9 handle trainable/masked spans in content and reasoning content (#3592) 2026-04-10 14:11:10 -04:00
NanoCode012
e7a6a5b529 fix: move warning after we've set any overrides (#3589) [skip ci] 2026-04-10 13:00:47 -04:00
NanoCode012
bfb4da1d25 fix: document jinja2 file path support (#3588) [skip ci] 2026-04-10 13:00:26 -04:00
floaty3
4dfa0a59b2 Add uninstall command to cut_cross_entropy import message (#3583) [skip ci] 2026-04-10 13:00:07 -04:00
Wing Lian
4ef608dda3 fix ddp/fsdp w gemma4 (#3584)
* fix ddp/fsdp w gemma4

* address pr comments

* activation offloading fix and update agent docs for gemma4
2026-04-09 20:02:36 -07:00
NanoCode012
7daf7d96f1 fix: regex for unfrozen language tower (#3586) [skip ci]
* fix: regex for unfrozen language tower

* fix: other leftover regex
2026-04-08 08:18:11 -07:00
Wing Lian
7c56809c7f use vllm 0.19.0 for torch 2.10.0 (#3582) 2026-04-07 08:09:49 -07:00
NanoCode012
149178ddb7 chore: cleanup post release v0.16 (#3577)
* fix: remove unneeded debug log

* fix: cleanup

* feat: add dense gemma config and cleanup

* feat: add cce support

* update notes and set torch compile

* fix patch for new number of return vals

* fixes for gemma4

* fix packing bug

* use updated cce for mm

* fix: pass in kv cache func when avail for transformers 5.5

* feat: update examples with flex variant and readme

* gemma4 lora attention kernels

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-06 10:10:52 -07:00
NanoCode012
dc638e723f fix(config): add cce and liger to nemotron-h example (#3573) [skip ci] 2026-04-06 10:10:25 -07:00
Wing Lian
6f15da4cac make it easier for agents to discover docs (#3579) [skip ci]
* make it easier for agents to discover docs

* fixup pr comments
2026-04-06 10:00:55 -07:00
Maxime
900eec7988 Fix DO_NOT_TRACK not being correctly handled (#3580)
* Fix DO_NOT_TRACK not being correctly handled

* add unit tests and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-04 05:16:58 -04:00
435 changed files with 17067 additions and 3051 deletions

View File

@@ -31,7 +31,11 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
# Install axolotl + dev and test dependencies
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
pre-commit install
# test

View File

@@ -30,14 +30,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -168,14 +160,6 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -6,7 +6,7 @@ on:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
- 'pyproject.toml'
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"

View File

@@ -18,12 +18,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -180,12 +174,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

View File

@@ -3,17 +3,15 @@ name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'scripts/cutcrossentropy_install.py'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
- "tests/e2e/multigpu/**.py"
- "pyproject.toml"
- ".github/workflows/multi-gpu-e2e.yml"
- "scripts/cutcrossentropy_install.py"
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
- "src/axolotl/utils/distributed.py"
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday
# Cancel jobs on the same ref if a new one is triggered
concurrency:
@@ -33,19 +31,19 @@ jobs:
fail-fast: false
matrix:
include:
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
# axolotl_extras: fbgemm-gpu
# axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1
@@ -53,7 +51,6 @@ jobs:
pytorch: 2.10.0
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -75,7 +72,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -8,6 +8,9 @@ on:
permissions: {}
env:
UV_SYSTEM_PYTHON: "1"
jobs:
setup_release:
name: Create Release
@@ -41,11 +44,15 @@ jobs:
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install dependencies
run: |
pip3 install wheel packaging==26.0
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
uv pip install wheel packaging
uv pip install --no-build-isolation -e .
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Extract tag name
id: tag

View File

@@ -2,15 +2,18 @@ name: Tests Nightly against upstream main
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '.github/workflows/tests-nightly.yml'
- ".github/workflows/tests-nightly.yml"
permissions:
contents: read
env:
UV_SYSTEM_PYTHON: "1"
jobs:
pre-commit:
name: pre-commit
@@ -20,7 +23,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
cache: "pip" # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -61,36 +64,34 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Update requirements.txt
run: |
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Override with nightly HF packages
run: |
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
- name: Ensure axolotl CLI was installed
run: |
@@ -102,9 +103,6 @@ jobs:
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
@@ -136,7 +134,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
steps:
- name: Checkout
@@ -157,7 +154,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:

View File

@@ -6,21 +6,19 @@ on:
branches:
- "main"
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
@@ -33,6 +31,7 @@ permissions:
env:
TRANSFORMERS_IS_CI: "yes"
UV_SYSTEM_PYTHON: "1"
jobs:
pre-commit:
@@ -44,7 +43,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
cache: "pip" # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -73,7 +72,7 @@ jobs:
exclude:
- python_version: "3.14"
pytorch_version: "2.9.1"
timeout-minutes: 20
timeout-minutes: 25
steps:
- name: cleanup node
@@ -94,32 +93,25 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-cache-dir --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
- name: Ensure axolotl CLI was installed
run: |
@@ -188,38 +180,42 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
- name: Install dependencies
run: |
pip3 show torch
uv pip install packaging setuptools_scm build wheel psutil
python -m build --no-isolation --sdist
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Verify agent docs are discoverable
run: |
# Agent docs live in docs/agents/ (source of truth) and are resolved
# at runtime from the repo checkout or via `axolotl fetch docs`
axolotl agent-docs --list
axolotl agent-docs | grep -q "Fine-tuning framework"
axolotl agent-docs grpo | grep -q "GRPO"
axolotl agent-docs sft | grep -q "SFT"
python -c "from axolotl.cli.agent_docs import get_doc, list_topics; assert len(list_topics()) >= 5; assert 'GRPO' in get_doc('grpo')"
- name: Show HF cache
run: hf cache ls
@@ -281,7 +277,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -302,7 +297,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
@@ -364,7 +359,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -16,6 +16,9 @@ axolotl inference config.yaml # Interactive inference
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
axolotl fetch examples # Download example configs
axolotl agent-docs # Show agent-optimized docs (bundled with pip package)
axolotl agent-docs grpo # Topic-specific agent reference
axolotl config-schema # Dump config JSON schema
```
## Training Methods
@@ -23,7 +26,7 @@ axolotl fetch examples # Download example configs
| Method | Config Key | When to Use |
|--------|-----------|-------------|
| SFT | *(default)* | Input-output pairs, instruction tuning |
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
| KTO | `rl: kto` | Unpaired binary preference labels |
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
@@ -35,6 +38,8 @@ Agent-specific references:
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
## Config Pattern

View File

@@ -1,6 +1,7 @@
include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include VERSION
include src/axolotl/utils/chat_templates/templates/*.jinja
include AGENTS.md
recursive-include docs/agents *.md
recursive-include axolotl *.py

View File

@@ -29,6 +29,9 @@
## 🎉 Latest Updates
- 2026/04:
- New model support has been added in Axolotl for [Mistral Medium 3.5](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral-medium-3_5) and [Gemma 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma4).
- Axolotl is now [uv-first](https://github.com/axolotl-ai-cloud/axolotl/pull/3545) and has [SonicMoE fused LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3519) support.
- 2026/03:
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
@@ -86,7 +89,7 @@ Features:
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1
### Google Colab
@@ -95,11 +98,19 @@ Features:
### Installation
#### Using pip
```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# install uv if you don't already have it installed (restart shell after)
curl -LsSf https://astral.sh/uv/install.sh | sh
# change depending on system
export UV_TORCH_BACKEND=cu128
# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
@@ -110,7 +121,7 @@ axolotl fetch deepspeed_configs # OPTIONAL
Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
```
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
@@ -157,6 +168,29 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
## AI Agent Support
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.
```bash
# Show overview and available training methods
axolotl agent-docs
# Topic-specific references
axolotl agent-docs sft # supervised fine-tuning
axolotl agent-docs grpo # GRPO online RL
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
axolotl agent-docs reward_modelling # outcome and process reward models
axolotl agent-docs pretraining # continual pretraining
axolotl agent-docs --list # list all topics
# Dump config schema for programmatic use
axolotl config-schema
axolotl config-schema --field adapter
```
If you're working with the source repo, agent docs are also available at `docs/agents/` and the project overview is in `AGENTS.md`.
## 🤝 Getting Help
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support

View File

@@ -1 +1 @@
0.16.0.dev0
0.16.2.dev0

View File

@@ -134,7 +134,6 @@ quartodoc:
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu
@@ -312,6 +311,7 @@ website:
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- docs/1_58bit_finetuning.qmd
- docs/optimizations.qmd
- section: "Core Concepts"
@@ -327,7 +327,6 @@ website:
- section: "Advanced Features"
contents:
- docs/fsdp_qlora.qmd
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd

View File

@@ -22,15 +22,6 @@ WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==26.0 setuptools==78.1.1
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
@@ -40,11 +31,21 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py --uv | sh
# Override with nightly HF packages for nightly builds
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
fi
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,54 +0,0 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
RUN pip uninstall -y causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -1,7 +1,7 @@
#!/bin/bash
set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"
set -o pipefail
for i in 1 2 3; do

View File

@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {

View File

@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {

View File

@@ -24,15 +24,15 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
fi && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -24,16 +24,15 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean

View File

@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
RUN if [ "$TARGETARCH" = "amd64" ]; then \
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
LINUX_TAG="manylinux_" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
arm64) ARCH_TAG="2_34_aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -0,0 +1,70 @@
---
title: "1.58-bit Finetuning"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
1.58-bit finetuning allows you to finetune BitNet models when their prequantized weights are provided. In theory, it will be possible to fine-tune any LLM in 1.58bit format but the performance degradation will be dramatic.
Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready to use for training.
::: {.callout-note}
LoRA is not supported for BitNet models
:::
## Installation
Install the `onebitllms` package before using this feature:
```bash
uv pip install onebitllms
```
Or from source:
```bash
uv pip install git+https://github.com/tiiuae/onebitllms
```
## Supported models
For now, only `Falcon-E` series of models are supported. Make sure to use their `-prequantized` version:
```bash
tiiuae/Falcon-E-3B-Base-prequantized
tiiuae/Falcon-E-1B-Base-prequantized
```
In theory, any other model would 'work' but the performance degradation will be huge. This remains an area of exploration.
## Configuration
To enable 1.58-bit finetuning, set the following in your configuration file:
```yaml
base_model: tiiuae/Falcon-E-3B-Base-prequantized # A BitNet-compatible model
use_onebitllms: true
```
::: {.callout-note}
For BitNet models, it is recommended to use a higher learning rate than classic models (usually in the order of magnitude of 10x).
:::
## Considerations after training
Once your model has been trained with 1.58bit fine-tuning, you can convert the trained model in ternary format using the `onebitllms` CLI:
```bash
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH
```
After that, you can use supported packages such as `llama.cpp` or Apple MLX package to run the trained model.
## Example Configuration
You can find example configurations in `examples/falcon-e` which contain one configuration for SFT and one configuration for DPO.

View File

@@ -0,0 +1,198 @@
# Model Architectures — Agent Reference
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
## VLM (Vision Language Model) Quick Start
All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```
## Plugins & Optimizations
### Cut Cross Entropy (CCE)
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```
### ScatterMoE Kernels
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
## Gemma 4
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
### Required settings
```yaml
# Always needed for Gemma4:
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```
### Auto-detection
Axolotl auto-detects Gemma4 and applies:
- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
### Multi-GPU
| Strategy | Works? | Notes |
|----------|--------|-------|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
FSDP2 config:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```
### MoE (26B-A4B)
- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
```yaml
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
- ScatterMoE kernel acceleration:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
### VLM (Vision) Training
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
### Common issues
| Symptom | Cause | Fix |
|---------|-------|-----|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
### E2B/E4B dense models
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
## Gemma 3
**Models**: `google/gemma-3-*`
- `ddp_find_unused_parameters: true` needed (multimodal unused params)
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
## Qwen 3.5 MoE
**Models**: `Qwen/Qwen3.5-35B-A3B`
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
```yaml
normalize_weight_scales:
- name_pattern: 'linear_attn\.conv1d\.weight'
threshold: 1.3
```
## General MoE Notes
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin

View File

@@ -0,0 +1,181 @@
# New Model Support — Agent Reference
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
## Quick Validation Checklist
When testing a new model, run through these checks in order:
1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.52.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
## Loss Debugging
### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.52.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:
```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader
cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()
# Forward pass on preprocessed data
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
```
If direct loss is correct (~1.0) but trainer reports 34x higher, check `model_accepts_loss_kwargs` (see below).
### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
## Multimodal Models (ForConditionalGeneration)
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`
### Why this matters
| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|-----------|----------------------|--------------------------------|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |
### Required extra inputs
Multimodal models require special inputs during training even for text-only data:
| Model | Required Input | Value for Text-Only |
|-------|---------------|-------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |
Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
### Custom layer types and PEFT
Vision towers often use custom module wrappers that PEFT doesn't support:
| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |
Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.
## Sample Packing
### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
packed_sequence_mask = find_packed_sequence_indices(position_ids)
```
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
```python
# In compute_loss():
if (
self.args.sample_packing
and model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
```
Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.
### Models that DON'T need this fix
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
## Attention Backend Selection
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
## Cut Cross Entropy (CCE)
### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
## MoE Models
### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
## Where to Add Model-Specific Fixes
| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |

View File

@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
1. Paired preference data (chosen + rejected)?
- Default → `rl: dpo`
- Overfitting → `rl: ipo`
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
- VRAM-limited → `rl: orpo` (no ref model)
- Length-sensitive → `rl: simpo` (no ref model)
2. Only binary labels (good/bad)? → `rl: kto`

View File

@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
| Missing chat template error | Set `chat_template: chatml` explicitly |
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
## Profiling
To profile training and identify optimization opportunities:
```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.
To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```
The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
## File Map

View File

@@ -3,28 +3,71 @@ title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
Axolotl routes attention via a single config field:
```yaml
sdp_attention: true
attn_implementation: <backend>
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
`attn_implementation` is passed through to `transformers` verbatim (via
`model.config._attn_implementation`). Accepted values are the HF-native
backends, axolotl-registered backends, or a hub-kernel path.
## Flash Attention
## Backends
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
| `attn_implementation` | Description |
|---|---|
| `eager` | Plain PyTorch attention. No packing support. |
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
| `xformers` | xFormers memory-efficient attention. |
| `sage` | SageAttention (QK int8 / PV fp16). |
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
set the canonical name above.
### Capability flags
Axolotl derives three boolean capability flags from `attn_implementation` and
exposes them on the validated config:
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
(everything except `eager` and `sdpa`).
These are **computed** — they cannot be overridden from YAML.
## Per-backend notes
### SDPA
Default PyTorch attention. See
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
```yaml
flash_attention: true
attn_implementation: sdpa
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Flash Attention
### Flash Attention 2
Axolotl supports FA2, FA3, and FA4. The best available version is used
automatically based on your installed packages and GPU.
```yaml
attn_implementation: flash_attention_2 # or flash_attention_3
```
#### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
@@ -39,23 +82,25 @@ Alternatively, try reinstall or downgrade a version.
:::
### Flash Attention 3
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
#### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
is true and FA4 is importable.
FA4 is still a pre-release on PyPI, so `--pre` is required:
```bash
pip install flash-attn-4
pip install --pre flash-attn-4
```
Or from source:
@@ -63,7 +108,6 @@ Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
@@ -86,93 +130,113 @@ and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above.
Requirements: ROCm 6.0 and above. See
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
### Flex Attention
```yaml
flex_attention: true
# recommended
torch_compile: true
attn_implementation: flex_attention
torch_compile: true # recommended
```
::: {.callout-note}
Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
We recommend using latest stable version of PyTorch for best performance.
### SageAttention
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
Requirements: Ampere, Ada, or Hopper GPUs.
```yaml
sage_attention: true
attn_implementation: sage
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
### xFormers
```yaml
xformers_attention: true
attn_implementation: xformers
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
Recommended for Turing GPUs or below (e.g. Colab T4).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
### Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
Planned for deprecation. Prefer one of the backends above.
:::
Requirements: LLaMA model architecture
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
patched to implement shifted-sparse attention. Does not support sample packing.
```yaml
flash_attention: true
s2_attention: true
attn_implementation: s2
```
::: {.callout-tip}
### FP8
No sample packing support!
torchao low-precision attention. Loaded as SDPA and patched post-load.
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
flash-attn with FA3. KV caching must be disabled.
```yaml
attn_implementation: fp8
```
### Hub kernels
```yaml
attn_implementation: kernels-community/flash-attn3
```
Passed through to `transformers`; axolotl does not install the kernel itself.
For recognized hub paths the capability flags are set automatically; for
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
`attn_uses_flash_lib=False`).
## Migrating from legacy boolean flags
The following legacy config fields are **deprecated** and will be removed in a
future release. Each emits a `DeprecationWarning` when set and is stripped from
the validated config.
| Legacy | Canonical |
|---|---|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
| `sdp_attention: true` | `attn_implementation: sdpa` |
| `xformers_attention: true` | `attn_implementation: xformers` |
| `flex_attention: true` | `attn_implementation: flex_attention` |
| `sage_attention: true` | `attn_implementation: sage` |
| `s2_attention: true` | `attn_implementation: s2` |
| `eager_attention: true` | `attn_implementation: eager` |
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
flash_attention_2` **and** `flash_attention: true`) raises — pick one.
::: {.callout-note}
Existing example configs under `examples/` still use the legacy flags. They
continue to work with a deprecation warning; they will be migrated in a
follow-up pass.
:::

View File

@@ -108,6 +108,14 @@ datasets:
type: chat_template
```
::: {.callout-tip}
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:
```yaml
chat_template_jinja: ./path/to/my_template.jinja2
```
:::
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
:::
@@ -294,6 +302,113 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
#### Content parts with per-part training control
Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think step by step...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
}
]
}
```
The configuration is the same as standard `chat_template` — no extra fields needed:
```yaml
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
Each content part supports:
- `type`: `"text"` (required)
- `text`: the text value (also accepts `content` or `value` as the key)
- `train`: `true`/`false` (optional) — whether to train on this part
- `weight`: `0`/`1` (optional) — alternative to `train`
If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
::: {.callout-warning title="Whitespace at part boundaries"}
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:
**Split BEFORE spaces** (space goes with the next part):
```json
[
{"type": "text", "text": "Let me think...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
```
**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):
```json
[
{"type": "text", "text": "Let me think... ", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:
```json
[
{"type": "text", "text": "Thinking:\n", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
:::
::: {.callout-note}
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
:::
##### Per-part training on reasoning_content
For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
{"type": "text", "text": " Wait no, 2+2=4.", "train": true}
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": true}
]
}
]
}
```
The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.
::: {.callout-tip}
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
:::
The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
#### Reasoning split
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.

View File

@@ -76,8 +76,10 @@ datasets:
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
#### Remote Hosts
@@ -208,17 +210,18 @@ cd axolotl
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
```
>[!Tip]
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
You will now be in the container. Next, perform an editable install of Axolotl:
You will now be in the container. Next, install Axolotl with dev dependencies:
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
### Attach To Container

View File

@@ -6,23 +6,33 @@ format:
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
This section describes the different Docker images that are released by AxolotlAI at
[Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
### Switch to the `-uv` images
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
same format as their non-uv counterparts.
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
the uv image.
:::
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
#### Tags format
@@ -32,8 +42,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.8.0`
- `main-base-py3.11-cu128-2.9.1`
- `main-base-py3.12-cu128-2.10.0`
- `main-base-py3.12-cu130-2.9.1`
- `main-base-py3.12-cu130-2.10.0`
## Main
@@ -41,11 +53,10 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
#### Tags format {#sec-main-tags}
@@ -53,7 +64,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
main-latest
# nightly build
@@ -71,12 +82,13 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu128-2.8.0`
- `main-py3.11-cu128-2.9.1`
- `main-py3.12-cu128-2.10.0`
- `main-py3.12-cu130-2.9.1`
- `main-py3.12-cu130-2.10.0`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu126-2.6.0`
- `0.12.0`
- `main-20260315-py3.11-cu128-2.9.1`
- `0.16.1`
## Cloud
@@ -90,11 +102,10 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
#### Tags format

View File

@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```
@@ -304,7 +304,7 @@ lora_alpha: 32
lora_target_linear: true
bf16: auto
flex_attention: true
attn_implementation: flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required with flex_attention

View File

@@ -57,7 +57,7 @@ description: Frequently asked questions
**Q: vLLM is not working with Axolotl**
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**

View File

@@ -154,7 +154,7 @@ lr_scheduler: cosine
warmup_steps: 10
bf16: true
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

View File

@@ -15,64 +15,30 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.6.0
- PyTorch ≥2.9.1
## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
## Installation {#sec-installation}
::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
:::
### PyPI Installation (Recommended) {#sec-pypi}
### Quick Install {#sec-uv}
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### uv Installation {#sec-uv}
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
Install uv if not already installed
Install uv if not already installed:
```{.bash}
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
then create the venv and activate
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
```{.bash}
export UV_TORCH_BACKEND=cu126
uv venv --no-project --relocatable
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv
source .venv/bin/activate
```
Install PyTorch
- PyTorch 2.6.0 recommended
```{.bash}
uv pip install packaging setuptools wheel
uv pip install torch==2.6.0
uv pip install awscli pydantic
```
Install axolotl from PyPi
```{.bash}
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
uv pip install --no-build-isolation axolotl[deepspeed]
```
### Edge/Development Build {#sec-edge-build}
@@ -82,14 +48,16 @@ For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]'
```
### Docker {#sec-docker}
```{.bash}
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
```
For development with Docker:
@@ -106,12 +74,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
--ulimit memlock=-1 --ulimit stack=67108864 \
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
axolotlai/axolotl:main-latest
axolotlai/axolotl-uv:main-latest
```
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
@@ -122,7 +90,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
For providers supporting Docker:
- Use `axolotlai/axolotl-cloud:main-latest`
- Use `axolotlai/axolotl-cloud-uv:main-latest`
- Available on:
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
@@ -141,7 +109,7 @@ For providers supporting Docker:
### macOS {#sec-macos}
```{.bash}
pip3 install --no-build-isolation -e '.'
uv pip install --no-build-isolation -e '.'
```
See @sec-troubleshooting for Mac-specific issues.
@@ -152,21 +120,44 @@ See @sec-troubleshooting for Mac-specific issues.
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
:::
## Environment Managers {#sec-env-managers}
## Migrating from pip to uv {#sec-migrating}
### Conda/Pip venv {#sec-conda}
If you have an existing pip-based Axolotl installation, you can migrate to uv:
1. Install Python ≥3.11
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
hf auth login
```
```{.bash}
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
# Create a fresh venv (recommended for a clean start)
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv
source .venv/bin/activate
# Reinstall axolotl
uv pip install --no-build-isolation axolotl[deepspeed]
```
## Using pip (Alternative) {#sec-pip}
If you are unable to install uv, you can still use pip directly.
::: {.callout-important}
Please make sure to have PyTorch installed before installing Axolotl with pip.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[deepspeed]
```
For editable/development installs:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[deepspeed]'
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -8,6 +8,7 @@ format:
## Supported Models
- [Gemma-4](#sec-gemma-4) *(NEW)*
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-4 {#sec-gemma-4}
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
chat_template: gemma4
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
::: {.callout-warning}
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
:::
::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
:::
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}

View File

@@ -0,0 +1,84 @@
# Multimodal assistant-only loss masking
## Correct placement
```yaml
# Top-level: only train_on_inputs lives here.
train_on_inputs: false
datasets:
- path: data/train.jsonl
type: chat_template
roles_to_train: # per-dataset — this is what the MM scanner reads
- assistant
train_on_eos: turn # per-dataset — same
test_datasets:
- path: data/val.jsonl
type: chat_template
split: train
roles_to_train:
- assistant
train_on_eos: turn
```
## How to verify at runtime
`build_collator` logs the resolved knobs at INFO:
```text
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
```
If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
scanner — check that they are under `datasets[0]`, not at the root.
Each verified strategy additionally logs its resolved boundary token ids at
strategy init (e.g. `<|turn>model``[105, 4368]`, `<turn|>``[106]` for
Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
pad and media tokens are masked" one-shot warning instead, it is on the
fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
(below) to activate masking. The strategies currently on this path are
listed in the audit table above under `fallback + warn`.
## Config-based override: `cfg.role_boundaries`
For the "unverified" strategies above, or for custom chat templates that
don't match a built-in strategy's markers, users can declare role boundaries
directly in YAML without subclassing:
```yaml
role_boundaries:
- role: assistant
start: "<|turn>model"
end: "<turn|>"
- role: user
start: "<|turn>user"
end: "<turn|>"
# Optional keys:
# include_start: false # default False
# include_end: true # default True, respects cfg.train_on_eos
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
# end: null # span runs to end of sequence
```
Semantics:
- `start` and `end` are literal strings; axolotl encodes them at strategy
init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
resolved token-id sequences at INFO level.
- The special value `end: eos_token` is the portable way to express
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
the strategy's built-in declarations wholesale (partial overlays are
intentionally unsupported — they're hard to reason about at review time).
Leaving the field unset *or* setting it to an empty list (`[]`) both mean
"use the strategy's built-ins." Writing `role_boundaries: []` is almost
always a typo or leftover — honoring it literally would produce all-masked
labels and zero gradient, so it is treated the same as unset.
- `cfg.roles_to_train` still governs which declared roles contribute to
loss. You can declare `user` and `assistant` boundaries and set
`roles_to_train: ["assistant"]` to have the scanner correctly identify
user spans as masking boundaries without training on their content.
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
unencodable markers), not silently at loss-compute time.

View File

@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
*Note: You should only enable one attention backend.*
See [Attention](attention.qmd) for the full list of backends and the canonical values.
### LoRA Optimizations

View File

@@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
```yaml
rl: ipo
rl: dpo
dpo_loss_type: ["ipo"]
```
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
### ORPO
@@ -1145,8 +1147,7 @@ datasets:
type: ebft_strided_structured.transform
split: train[:1%]
flash_attention: false
flex_attention: true # Strided mode uses flex_attention
attn_implementation: flex_attention # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required for flex_attention

View File

@@ -20,6 +20,8 @@ examples:
title: Arcee AFM
# MistralAI
- name: mistral-medium-3_5
title: Mistral Medium 3.5
- name: ministral3/think
title: Ministral 3 Thinking
- name: ministral3/vision

View File

@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example

View File

@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
Reduces attention memory from O(n^2) to O(n):
```yaml
flash_attention: true
attn_implementation: flash_attention_2
```
### Step 6: Offload with DeepSpeed

View File

@@ -1,53 +0,0 @@
---
title: "Unsloth"
description: "Hyper-optimized QLoRA finetuning for single GPUs"
---
### Overview
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
standard industry baselines.
::: {.callout-important}
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
:::
### Installation
The following will install the correct unsloth and extras from source.
```bash
python scripts/unsloth_install.py | sh
```
### Usage
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
Our unsloth integration is currently limited to the following model architectures:
- llama
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
```yaml
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```
### Limitations
- Single GPU only; e.g. no multi-gpu support
- No deepspeed or FSDP support (requires multi-gpu)
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
- No MoE support.

View File

@@ -15,8 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Run one of the finetuning examples below.
@@ -35,7 +34,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
@@ -45,7 +44,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
uv pip uninstall causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -39,7 +39,7 @@ tf32: true
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -48,7 +48,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -50,8 +50,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -39,7 +39,7 @@ activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_steps: 100
saves_per_epoch: 1

View File

@@ -39,7 +39,7 @@ activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_steps: 100
saves_per_epoch: 1

View File

@@ -11,12 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
@@ -31,7 +30,7 @@ python scripts/cutcrossentropy_install.py | sh
# For those using our Docker image, use the below path.
export CUDA_HOME=/usr/local/cuda
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
@@ -67,7 +66,7 @@ If those didn't help, please try the below solutions:
1. Pass env for CMAKE and try install again:
```bash
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
2. Git clone the repo and manually hardcode python path:
@@ -92,7 +91,7 @@ If those didn't help, please try the below solutions:
```
```bash
pip3 install . --no-build-isolation --no-deps
uv pip install . --no-build-isolation --no-deps
```
## Optimization Guides

View File

@@ -55,7 +55,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -13,12 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -55,7 +55,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -59,8 +59,7 @@ gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
sdp_attention:
attn_implementation: flash_attention_2
flash_optimum:
gptq_groupsize:

View File

@@ -39,8 +39,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -43,8 +43,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -73,8 +73,7 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -40,8 +40,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -36,8 +36,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -37,8 +37,7 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -39,7 +39,6 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -39,7 +39,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -40,7 +40,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -47,7 +47,6 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -47,7 +47,6 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -40,7 +40,6 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -38,7 +38,6 @@ tf32: true
gradient_checkpointing:
resume_from_checkpoint:
logging_steps: 1
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -44,7 +44,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_mlp: true

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
flash_attn_cross_entropy: false
flash_attn_rms_norm: true

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -47,7 +47,6 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: false
warmup_ratio: 0.1
evals_per_epoch: 0

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -36,7 +36,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -71,8 +71,7 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -10,7 +10,7 @@ load_in_4bit: true
sequence_len: 1024
bf16: auto
tf32: false
flash_attention: true
attn_implementation: flash_attention_2
special_tokens:
bos_token: "<|startoftext|>"
eos_token: "<|endoftext|>"

View File

@@ -48,7 +48,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -36,12 +36,7 @@
"id": "msOCO4NRmRLa"
},
"outputs": [],
"source": [
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
]
"source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
},
{
"cell_type": "markdown",

View File

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -15,9 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -26,7 +26,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
@@ -51,8 +50,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
scaling_softmax: true
attn_implementation: flash_attention_2
# scaling_softmax: true # needs flex_attention
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

View File

@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
flash_attention: true
attn_implementation: flash_attention_2
gradient_accumulation_steps: 1
micro_batch_size: 1

Some files were not shown because too many files have changed in this diff Show More