Compare commits

..

30 Commits

Author SHA1 Message Date
c6da9b9e92 Update SETUP_MIAAI.md: add bare Ubuntu rebuild section (driver, packages, Ollama) 2026-05-13 21:33:02 +00:00
c7c4885369 Update SETUP_MIAAI.md: pre-training checklist, Ollama stop/start, verify script, corrected training time 2026-05-13 21:19:15 +00:00
981a13e110 Update human_chat_qlora.yml: working config for RTX 5080 (seq_len 2048, qlora, chat_template) 2026-05-13 18:59:19 +00:00
74f2263ac7 Update SETUP_MIAAI.md: bitsandbytes sm_120 patch, OOM fixes, working training config 2026-05-13 18:58:51 +00:00
8693a1f61b fix Dockerfile-base-next: cuda 12.8.2, miniforge, sm_120 2026-05-13 14:37:01 +00:00
71c6a56e7a switch to HQQ quantization to bypass bitsandbytes sm_120 issue 2026-05-13 13:55:52 +00:00
38adf5cd37 add trust_remote_code, explicit bfloat16 and bnb dtype settings 2026-05-13 13:32:46 +00:00
3f29fa017b replace Capybara with SlimOrca (compatible ShareGPT format) 2026-05-13 12:58:29 +00:00
c02a76f132 fix field_messages mapping for Capybara/OpenHermes ShareGPT format 2026-05-13 12:56:03 +00:00
b9ceebfe7e fix deprecated type:sharegpt and flash_attention config keys 2026-05-13 12:52:25 +00:00
e9a3fd483f add human-like QLoRA training config for Llama 3.1 8B 2026-05-13 12:50:35 +00:00
eadd15c960 note MAX_JOBS for flash-attn compile speed 2026-05-13 04:45:21 +00:00
396ce4a9dd add miaai environment setup guide 2026-05-13 04:16:03 +00:00
Wing Lian
b7ec06b8a1 Add optional Axolotl MoRA/ReMoRA integration (#3647) [skip ci]
* Add optional Axolotl MoRA/ReMoRA integration

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Isolate MoRA adapter behavior in plugin

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Constrain MoRA variants to supported enum values

* Keep MoRA validation out of core config

---------

Co-authored-by: Swarm <swarm@localhost>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-12 07:19:55 -04:00
Wing Lian
e2f01de0e8 Fix Axolotl ReLoRA optimizer reset scope (#3646)
* Fix Axolotl ReLoRA optimizer reset scope
* fix: make relora reset method honor relora_prune_ratio

When relora_prune_method='reset' and relora_prune_ratio is explicitly
set, the ratio was silently ignored and replaced with the hardcoded
_FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to
ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset
and 0.9 for other methods. reset_optimizer now uses the same random
pruning path for both 'random' and 'reset'.

Also consolidate three-layer default mismatch: schema default for
relora_prune_method is now 'magnitude' (single canonical source);
dataclass defaults for both fields changed to None to eliminate the
conflicting fallback layer.

Tests updated: removed the test case that verified the old broken
behavior (reset ignoring ratio), added two cases proving reset honors
the passed ratio. E2E reset fixture now uses ratio=0.5 to make it
unambiguous that the ratio is honored.

* Fix ReLoRA uint8 pruning regression

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-09 17:52:35 -04:00
thad0ctor
5352d41d32 feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries (#3625)
* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.
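The declarative boundary scanner described above can be sketched roughly as follows — a minimal illustration with made-up marker encoding; the real `ProcessingStrategy` API and its `train_on_eos` handling differ:

```python
IGNORE = -100  # standard HF ignore index: tokens with this label carry no loss

def mask_labels(tokens: list[int],
                boundaries: dict[str, tuple[list[int], list[int]]],
                roles_to_train: set[str]) -> list[int]:
    """Scan for per-role (start, end) marker sequences; keep labels only
    inside spans belonging to trainable roles, mask everything else."""
    labels = [IGNORE] * len(tokens)
    i = 0
    while i < len(tokens):
        matched = False
        for role, (start, end) in boundaries.items():
            if tokens[i:i + len(start)] == start:
                j = i + len(start)
                while j < len(tokens) and tokens[j:j + len(end)] != end:
                    j += 1
                if role in roles_to_train:
                    labels[i:j] = tokens[i:j]  # train on this role's span
                i = j + len(end)
                matched = True
                break
        if not matched:
            i += 1
    return labels

# Only the assistant span (start marker 3 ... end marker 4) keeps labels:
out = mask_labels([1, 10, 11, 2, 3, 20, 21, 4],
                  {"user": ([1], [2]), "assistant": ([3], [4])},
                  {"assistant"})
assert out == [-100, -100, -100, -100, 3, 20, 21, -100]
```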

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs+types: address CodeRabbit nitpicks on PR #7

- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mm-mask): address two CodeRabbit findings on PR #7

1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."
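The opt-in override semantics settled on here reduce to a truthiness check; a minimal sketch (function and variable names are illustrative, not the actual strategy ctor):

```python
def resolve_role_boundaries(override, builtin):
    # Non-empty override replaces built-ins wholesale; None or [] falls
    # back to built-ins (honoring [] literally would mask every label and
    # yield zero gradient, almost always a typo rather than intent).
    return override if override else builtin

BUILTIN = [("<user>", "</user>"), ("<assistant>", "</assistant>")]

assert resolve_role_boundaries(None, BUILTIN) is BUILTIN       # unset -> built-ins
assert resolve_role_boundaries([], BUILTIN) is BUILTIN         # [] -> built-ins too
assert resolve_role_boundaries([("[INST]", "[/INST]")], BUILTIN) == [("[INST]", "[/INST]")]
```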

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* doc cleanup

* fix(mm-mask): CodeRabbit findings + lint fix on PR #3625

Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
  under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.
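The `.tolist()` hoist in the first bullet, sketched with a dependency-free stand-in for a tensor row (names illustrative; the real scanner operates on `torch.Tensor` labels):

```python
class Row:
    """Stand-in for a 1-D tensor row: .tolist() materializes a Python list,
    the per-call boundary crossing the commit hoists out of the inner loop."""
    def __init__(self, data): self._data = data
    def __getitem__(self, sl): return Row(self._data[sl])
    def tolist(self): return list(self._data)

def scan_row_slow(row: Row, probes: list[list[int]]) -> int:
    # Anti-pattern: re-materialize a list slice on every probe.
    return sum(1 for p in probes if row[: len(p)].tolist() == p)

def scan_row_fast(row: Row, probes: list[list[int]]) -> int:
    # Hoisted: convert once per row, then probe plain list slices.
    row_list = row.tolist()
    return sum(1 for p in probes if row_list[: len(p)] == p)

r = Row([1, 2, 3])
assert scan_row_slow(r, [[1], [1, 2], [9]]) == scan_row_fast(r, [[1], [1, 2], [9]]) == 2
```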
2026-05-05 11:25:39 -04:00
VED
c15f6cffe2 fix: FSDP FULL_STATE_DICT oom from memory leak (#3635)
* memory clean up for fsdp full state dict

* Update src/axolotl/monkeypatch/accelerate/fsdp2.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2026-05-05 11:22:35 -04:00
Wing Lian
e4032fc90f Refactor separate attention flags with attn_implementation and capability/concerns feature flags (#3602)
* upgrade to torchao 0.17.0

* chore: lint

* refactor attention handling

* replace legacy attention boolean flags with capability properties

Replace checks with capability-based properties derived from attn_implementation

This separates three concerns that were conflated under flash_attention:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property

* compute attn capability flags in normalizer instead of properties

* make attn_implementation the single source of truth

* move attention-dependent validators to mode=after

* migrate remaining consumers to canonical attn_implementation

* expand attention tests + rewrite docs

* migrate example configs to canonical attn_implementation

* update doc snippets + reject gemma4-hybrid with non-FA2 backend

* remove dead gemma4 branch in _set_attention_config

* fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests

* drop "Phase 2" naming from attn-implementation tests

* regroup attn_implementation tests by feature concern

* clean up verbose comments and remove MD

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x

In transformers 5.x, ProcessorMixin.apply_chat_template gained its own
`return_dict` parameter (defaulting to False).  When return_dict=False
and tokenize=True the method returns out["input_ids"] directly — a 2-D
tensor — rather than the full BatchFeature dict.

The old code placed `return_dict=True` inside processor_kwargs.  In
transformers 5.x those kwargs are forwarded to the underlying processor
call self(...) where _merge_kwargs silently ignores any key not present
in MllamaProcessorKwargs (emitting a warning).  The outer return_dict
therefore stayed False, apply_chat_template returned the raw input_ids
tensor, and the subsequent `batch["input_ids"]` attempted to index a
2-D tensor with the 9-character string "input_ids", producing:

  IndexError: too many indices for tensor of dimension 2

The fix is to pass return_dict=True as a top-level keyword argument to
apply_chat_template (where it is actually consumed) and remove it from
processor_kwargs (where it was silently dropped).  No version guard is
needed: transformers is pinned to ==5.5.4 in pyproject.toml.

Adds a unit-level regression test (tests/test_mm_chat_collator.py) that
mocks the processor to return a raw tensor when apply_chat_template is
called without top-level return_dict=True, verifying the four invariants:
process_rows returns a dict, input_ids is 2-D, labels is 2-D, and
apply_chat_template receives return_dict=True as a top-level kwarg.

Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset
Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): process_rows returns dict (BatchFeature) shape

Two related changes for the multimodal chat collator under transformers 5.x:

1. Wrap apply_chat_template result in dict(...) so process_rows returns
   a plain dict rather than a BatchFeature instance. BatchFeature is a
   Mapping but not a dict; downstream code that did
     batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
   would index on a tensor when the result wasn't dict-shaped, raising
     IndexError: too many indices for tensor of dimension 2

2. Soften the regression test's contract from `dict` to `Mapping` so it
   exercises the actual semantic guarantee (key/value access) rather
   than the implementation detail (dict vs BatchFeature). Test guards
   against the original transformers 5.x breakage where apply_chat_template's
   return_dict default went from True to False.
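The `Mapping`-but-not-`dict` distinction at the heart of this fix can be demonstrated with a small stand-in (not the real `transformers.BatchFeature`):

```python
from collections.abc import Mapping

class BatchFeatureLike(Mapping):
    """Stand-in for BatchFeature: key/value access works, but it is not a dict."""
    def __init__(self, data): self._data = data
    def __getitem__(self, k): return self._data[k]
    def __iter__(self): return iter(self._data)
    def __len__(self): return len(self._data)

bf = BatchFeatureLike({"input_ids": [[1, 2]], "attention_mask": [[1, 1]]})
assert isinstance(bf, Mapping) and not isinstance(bf, dict)

batch = dict(bf)                        # wrap so downstream dict-shaped access is safe
batch["labels"] = batch["input_ids"]    # the indexing pattern from the commit message
assert isinstance(batch, dict)
```

Testing against `Mapping` rather than `dict` locks in the semantic guarantee (key access) without pinning the implementation type.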

Includes regression test under tests/test_mm_chat_collator.py.

Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against
attn-implementation-refactor; squash-merged from agent commits 4de886fd
+ dc9fcf4f.

Signed-off-by: Wing Lian <wing@axolotl.ai>

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-05 10:15:18 -04:00
Younes B
6136ae627b Fix: add bitnet config (#3636)
* add bitnet config

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-30 12:30:56 -04:00
Younes B
e662972a29 Feat: Add bitnet integration (#3634)
* add bitnet

* switch to uv

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-30 11:25:02 -04:00
NanoCode012
ebbd7fa847 feat: Add Mistral Medium 3.5 (#3633)
* fix: clarify incompat

* fix: transformers api change upstream

* fix: add pre prop

* feat: add examples

* chore: cleanup

* chore: update readme
2026-04-29 22:46:51 +07:00
Wing Lian
ac77da96da use smaller pretrained models for ci (#3620) [skip ci]
* use smaller pretrained models for ci

* more steps for loss check

* fix tests

* more train steps

* fix losses
2026-04-27 13:22:56 -04:00
NanoCode012
798c8fba89 chore: update docker docs (#3623)
2026-04-24 16:03:21 +07:00
NanoCode012
17fc747f99 fix: docker build failing (#3622)
* fix: uv leftover docs

* fix: docker build failing

* chore: doc

* fix: remove old pytorch build

* fix: stop recommend flash-attn optional, let transformers pull

* fix: remove ring flash attention from image

* fix: quotes [skip ci]

* chore: naming [skip ci]
2026-04-24 14:23:09 +07:00
Wing Lian
901f2356bc dpo collation/padding (#3601) [skip ci]
* fix dpo collation/padding

* fix DPO collator encoder-decoder pixel_values dtype and is_encoder_decoder detection

- Use float32 instead of LongTensor for _pixel_values in encoder-decoder branch
- Add missing padding_value case for _pixel_values in encoder-decoder branch
- Derive is_encoder_decoder from model config instead of hardcoding False
2026-04-23 14:49:52 -04:00
thad0ctor
1bf65c500e feat: add processor_kwargs YAML field forwarded to from_pretrained (#3612) 2026-04-23 00:26:34 -04:00
brightwind26
bcbe049c21 Feat: add support for datasets with str saved messages field (#3607)
* feat: support datasets saved in str format

* add also str for tools

* format

* fix: address comments + add unit test

* format
2026-04-23 00:25:48 -04:00
Andrew Wu
90090fa9e8 DPO support loss types (#3566)
* Support loss_type/loss_weights DPO

* Validate dpo loss type/weights only set for dpo

* Tests: Update ipo tests to use new path

* Docs: Update docs for new ipo path

* PR fixes - typo/validation

* PR nit - warning

* chore: fix warnings arg

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-04-23 00:25:28 -04:00
Wing Lian
7420fd4de6 fix async prefetch with nemogym (#3606) 2026-04-22 09:05:46 -04:00
Wing Lian
05113bc91a train on remote compute using Tinker compatible APIs (#3614)
* train on remote compute using Tinker compatible APIs

* chore: lint

* fixes with latest hatchery changes

* chore: lint
2026-04-22 01:14:41 -04:00
378 changed files with 10712 additions and 1956 deletions


@@ -31,10 +31,11 @@ PRs are **greatly welcome**!
 Please run below to setup env
 ```bash
-# Install axolotl + dev and test dependencies from lockfile
+# Install axolotl + dev and test dependencies
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 pre-commit install
 # test


@@ -30,14 +30,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -168,14 +160,6 @@
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
             platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""


@@ -18,12 +18,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.11"
-            pytorch: 2.9.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
@@ -180,12 +174,6 @@
       fail-fast: false
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.11"
-            pytorch: 2.9.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"


@@ -72,7 +72,7 @@ jobs:
         exclude:
           - python_version: "3.14"
             pytorch_version: "2.9.1"
-    timeout-minutes: 20
+    timeout-minutes: 25
     steps:
       - name: cleanup node


@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
 | Method | Config Key | When to Use |
 |--------|-----------|-------------|
 | SFT | *(default)* | Input-output pairs, instruction tuning |
-| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
+| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
 | KTO | `rl: kto` | Unpaired binary preference labels |
 | ORPO | `rl: orpo` | Single-stage alignment, no ref model |
 | GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |


@@ -29,6 +29,9 @@
 ## 🎉 Latest Updates
+- 2026/04:
+  - New model support has been added in Axolotl for [Mistral Medium 3.5](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral-medium-3_5) and [Gemma 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma4).
+  - Axolotl is now [uv-first](https://github.com/axolotl-ai-cloud/axolotl/pull/3545) and has [SonicMoE fused LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3519) support.
 - 2026/03:
   - New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
   - [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).

SETUP_MIAAI.md (new file, 273 lines)

@@ -0,0 +1,273 @@
# Axolotl Setup — miaai (RTX 5080, CUDA 13.2)
## System Info
- GPU: NVIDIA RTX 5080 (16GB VRAM, sm_120 / Blackwell)
- Driver: 580.126.09 — max CUDA 13.0 shown by nvidia-smi, but nvcc from conda is 13.2
- OS: Ubuntu 25.10 (Python 3.13 system — do NOT use system Python for ML)
- Axolotl repo: `/home/tocmo0nlord/axolotl` (branch: `activeblue/main`)
- Conda env: `axolotl` at `/opt/miniconda3/envs/axolotl`
---
## Starting from Bare Ubuntu 25.10
If rebuilding from scratch, complete these steps first before anything else.
### A. System packages
```bash
sudo apt update && sudo apt upgrade -y
sudo apt install -y \
build-essential cmake git curl wget \
python3-dev libssl-dev zlib1g-dev \
ca-certificates gnupg lsb-release
```
### B. NVIDIA driver (580.xx)
Ubuntu 25.10 is too new for NVIDIA's apt repo. Install via ubuntu-drivers:
```bash
sudo ubuntu-drivers autoinstall
sudo reboot
```
After reboot, verify:
```bash
nvidia-smi
# Must show: NVIDIA GeForce RTX 5080, Driver Version: 580.x
```
If ubuntu-drivers installs the wrong version, force the right one:
```bash
sudo apt install -y nvidia-driver-580
sudo reboot
```
### C. Install Ollama
```bash
curl -fsSL https://ollama.com/install.sh | sh
# Verify it's running
systemctl status ollama
```
### D. HuggingFace CLI
```bash
pip3 install huggingface_hub
huggingface-cli login
# Paste your HF token — required for gated models like meta-llama
```
Once steps A–D are done, continue with the One-time Setup below.
---
## Pre-Training Checklist (every session)
```bash
# 1. Stop Ollama — if it receives a request mid-training it will compete for VRAM
sudo systemctl stop ollama
# 2. Activate conda env
export PATH="/opt/miniconda3/bin:$PATH"
conda activate axolotl
# 3. Set env vars
export CUDA_HOME=$CONDA_PREFIX
export PATH=$CUDA_HOME/bin:$PATH
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# 4. Confirm GPU is clear (should show no processes before training)
nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv
# 5. Go to axolotl directory
cd /home/tocmo0nlord/axolotl
```
## Run Training
```bash
axolotl train ~/human_chat_qlora.yml
```
## After Training
```bash
# Restart Ollama
sudo systemctl start ollama
# Test the adapter interactively
axolotl inference ~/human_chat_qlora.yml \
--lora-model-dir ~/outputs/llama31-8b-humanchat \
--prompter chat
# (Optional) Merge adapter into base model for standalone deployment
axolotl merge-lora ~/human_chat_qlora.yml
```
---
## One-time Setup (fresh machine — after bare Ubuntu steps above)
### 1. Install Miniconda
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
bash miniconda.sh -b -p /opt/miniconda3
/opt/miniconda3/bin/conda init bash
source ~/.bashrc
```
### 2. Create Python 3.11 environment
```bash
conda create -n axolotl python=3.11 -y
conda activate axolotl
```
### 3. Clone axolotl repo
```bash
git clone https://git.activeblue.net/tocmo0nlord/axolotl.git /home/tocmo0nlord/axolotl
cd /home/tocmo0nlord/axolotl
git remote add upstream https://github.com/axolotl-ai-cloud/axolotl.git
git fetch upstream
git rebase upstream/main # keeps activeblue patches on top
```
### 4. Install CUDA toolkit (needed to compile flash-attn and bitsandbytes)
```bash
conda install -y -c "nvidia/label/cuda-12.8.0" cuda-toolkit
export CUDA_HOME=$CONDA_PREFIX
export PATH=$CUDA_HOME/bin:$PATH
```
> NOTE: Despite installing from the cuda-12.8.0 channel, conda resolves nvcc to **13.2.78**.
> This is fine — use cu132 everywhere to match.
### 5. Install PyTorch — use cu132 (matches nvcc from conda)
```bash
# torchaudio has no cu132 wheel — skip it, not needed for LLM training
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu132
python -c "import torch; print('CUDA:', torch.version.cuda); print('GPU:', torch.cuda.get_device_name(0))"
```
### 6. Install Axolotl
```bash
cd /home/tocmo0nlord/axolotl
pip install -e "."
```
### 7. Install flash-attn
> Compiles CUDA kernels from source — takes 15–25 min on 10 cores of an i7-14700K.
```bash
MAX_JOBS=10 pip install flash-attn --no-build-isolation
```
### 8. Compile bitsandbytes from source for sm_120 (RTX 5080 / Blackwell)
Prebuilt wheels do not include sm_120. CUDA 13.2 also dropped sm_50–53.
You must compile from source with a patched CMakeLists.txt.
```bash
# Clone bitsandbytes v0.49.1
git clone --branch v0.49.1 --depth 1 \
https://github.com/bitsandbytes-foundation/bitsandbytes.git /tmp/bnb_0491
# Patch CMakeLists.txt: insert sm_120 override before the foreach loop
# (cmake >= 3.23.0 uses its own built-in arch list which does not include sm_120)
sed -i '/ foreach(capability \${CMAKE_CUDA_ARCHITECTURES_ALL})/i\ # RTX 5080 sm_120 patch\n set(CMAKE_CUDA_ARCHITECTURES_ALL 120)' /tmp/bnb_0491/CMakeLists.txt
# Verify patch landed correctly — set() line must appear immediately before foreach
grep -n "ARCHITECTURES_ALL\|foreach" /tmp/bnb_0491/CMakeLists.txt | tail -5
# Configure — must point cmake at conda's nvcc explicitly
cmake \
-DCMAKE_CUDA_COMPILER=/opt/miniconda3/envs/axolotl/bin/nvcc \
-DCOMPUTE_BACKEND=cuda \
-S /tmp/bnb_0491 \
-B /tmp/bnb_0491/build 2>&1 | grep -E "(Capabilit|CUDA Ver|Error)"
# Must show: CUDA Capabilities Selected: 120
# Build (adjust -j to your CPU core count)
cmake --build /tmp/bnb_0491/build -j10
# Install into conda site-packages
cp -r /tmp/bnb_0491/bitsandbytes \
/opt/miniconda3/envs/axolotl/lib/python3.11/site-packages/
# Verify CUDA works
python3 -c "
import torch, bitsandbytes as bnb
x = torch.randn(64, 64, device='cuda')
l = bnb.nn.Linear8bitLt(64, 64).cuda()
print('bitsandbytes CUDA OK:', l(x).shape)
"
```
### 9. Copy training config to home
```bash
cp /home/tocmo0nlord/axolotl/human_chat_qlora.yml ~/human_chat_qlora.yml
```
### 10. Verify the full stack
```bash
python3 -c "
import torch, bitsandbytes as bnb, flash_attn, transformers
print('torch :', torch.__version__, '| CUDA:', torch.version.cuda)
print('bitsandbytes:', bnb.__version__)
print('flash_attn :', flash_attn.__version__)
print('transformers:', transformers.__version__)
print('GPU :', torch.cuda.get_device_name(0))
print('VRAM :', round(torch.cuda.get_device_properties(0).total_memory/1e9, 1), 'GB')
"
```
Expected output:
```
torch : 2.x.x | CUDA: 13.2
bitsandbytes: 0.50.0.dev0
flash_attn : 2.x.x
transformers: 5.x.x
GPU : NVIDIA GeForce RTX 5080
VRAM : 16.3 GB
```
---
## Training Config — human_chat_qlora.yml
Key settings tuned for RTX 5080 (16GB):
| Setting | Value | Notes |
|---|---|---|
| `sequence_len` | `2048` | 4096 OOMs during loss computation (logits x 128k vocab) |
| `micro_batch_size` | `1` | Effective batch = micro x grad_accum = 8 |
| `gradient_accumulation_steps` | `8` | Keeps effective batch size at 8 |
| `adapter` | `qlora` | 4-bit via bitsandbytes compiled from source |
| `attn_implementation` | `flash_attention_2` | Not the deprecated `flash_attention: true` |
| `type` (datasets) | `chat_template` | Not the deprecated `sharegpt` |
Expected training metrics (RTX 5080, ~65k samples, 2 epochs):
- VRAM: ~10–11 GB active, ~11 GB allocated
- Training duration: ~3.5 hours
- Initial eval loss: ~0.81, perplexity ~2.25
- Final loss target: ~0.55–0.60
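Two of the metrics above can be sanity-checked with quick arithmetic (sample count and loss values are the approximate figures from this guide; exact numbers vary per run):

```python
import math

# Perplexity is exp(loss): an eval loss of ~0.81 corresponds to perplexity ~2.25.
print(round(math.exp(0.81), 2))  # → 2.25

# Optimizer steps: samples * epochs / effective_batch (micro 1 x grad_accum 8).
samples, epochs, effective_batch = 65_000, 2, 8
print(samples * epochs // effective_batch)  # → 16250
```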
To push VRAM use to ~14 GB and improve throughput, set `micro_batch_size: 2` and `gradient_accumulation_steps: 4`.
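Putting the table together, the relevant fragment of `human_chat_qlora.yml` looks roughly like this — a sketch of the key settings only, not the full working config; the `base_model` and dataset `path` are assumptions based on the commit messages, so check them against your actual file:

```yaml
base_model: meta-llama/Llama-3.1-8B-Instruct  # assumed — verify against your config
adapter: qlora
load_in_4bit: true
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 8   # effective batch = 1 x 8 = 8
attn_implementation: flash_attention_2
datasets:
  - path: Open-Orca/SlimOrca     # assumed — ShareGPT-format dataset from the commit log
    type: chat_template
```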
---
## Common Pitfalls
| Problem | Cause | Fix |
|---|---|---|
| `externally-managed-environment` | System Python 3.13 blocks pip | Use conda env, never system pip |
| `No module named torch` (flash-attn) | pip builds in isolated env | Use `--no-build-isolation` |
| `CUDA_HOME not set` | CUDA toolkit not installed | `conda install cuda-toolkit` from nvidia channel |
| `CUDA version mismatch 13.2 vs 12.8` | Conda nvcc is 13.2, torch was cu128 | Reinstall torch with `--index-url .../cu132` |
| `torchaudio` not found for cu132 | No cu132 wheel exists | Skip torchaudio — not needed |
| flash-attn compile is slow | Single-threaded by default | Set `MAX_JOBS=<cpu_count>` before pip install |
| `nvcc fatal: Unsupported gpu architecture 'compute_50'` | bitsandbytes CMakeLists.txt hardcodes sm_50; CUDA 13.2 dropped it | Patch CMakeLists.txt (see step 8 above) |
| `CUDA Capabilities Selected: 50;52;...` ignores -D flag | cmake >= 3.23 built-in arch list lacks sm_120; CMakeLists.txt overrides -D | Insert `set(CMAKE_CUDA_ARCHITECTURES_ALL 120)` before foreach loop |
| `BackendUnavailable: scikit_build_core` | pip install of bnb triggers cmake rebuild | Copy .so directly to site-packages instead |
| `torch.OutOfMemoryError` during eval | logits tensor (batch x 4096 x 128k vocab) too large | Set `sequence_len: 2048`, `micro_batch_size: 1` |
| `type: sharegpt` deprecation warning | axolotl removed sharegpt type | Use `type: chat_template` with field mappings |
| `flash_attention: true` deprecation | Old config key removed | Use `attn_implementation: flash_attention_2` |
| Capybara dataset `field_messages null` | Capybara uses input/output format, not conversations | Switch to SlimOrca or OpenHermes-2.5 |
| Ollama loads model mid-training | Ollama is enabled and receives a request | `sudo systemctl stop ollama` before training |
| Training much slower than eval speed | The fast it/s shown on screen comes from the eval loop (forward pass only) | Normal — training includes the backward pass and optimizer step (~3.5h total) |
| ubuntu-drivers installs wrong NVIDIA version | Multiple driver candidates available | Force with `apt install nvidia-driver-580` |
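The eval-OOM row above is easy to verify with back-of-the-envelope math. This sketch computes only the logits tensor (batch x seq_len x vocab) for Llama 3.1's 128,256-token vocabulary; actual peak memory is higher (activations, optimizer state), but it shows why halving `sequence_len` helps so much:

```python
# Logits tensor size in GB for a given batch, sequence length, and element size.
VOCAB = 128_256  # Llama 3.1 vocabulary size

def logits_gb(batch: int, seq_len: int, bytes_per_elem: int) -> float:
    return batch * seq_len * VOCAB * bytes_per_elem / 1e9

print(round(logits_gb(1, 4096, 2), 2))  # bf16 logits at seq_len 4096 → 1.05
print(round(logits_gb(1, 4096, 4), 2))  # after fp32 upcast for the loss → 2.1
print(round(logits_gb(1, 2048, 4), 2))  # fp32 at seq_len 2048 → 1.05
```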


@@ -1 +1 @@
-0.16.0.dev0
+0.16.2.dev0


@@ -311,6 +311,7 @@ website:
 - docs/dataset_loading.qmd
 - docs/qat.qmd
 - docs/quantize.qmd
+- docs/1_58bit_finetuning.qmd
 - docs/optimizations.qmd
 - section: "Core Concepts"


@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
 RUN pip uninstall -y causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
-        BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="optimizers,ray"; \
     else \
-        BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="deepspeed,optimizers,ray"; \
     fi && \
     if [ "$AXOLOTL_EXTRAS" != "" ]; then \
         pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \


@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10 && \
     pip3 cache purge
-# Map Python version (e.g., 3.12 -> cp312)
-RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
-    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
-    TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
-    # Map architecture
-    case "$TARGETARCH" in \
-        amd64) ARCH_TAG="x86_64" ;; \
-        arm64) ARCH_TAG="aarch64" ;; \
-        *) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
-    esac && \
-    WHL_VERSION="v0.7.16" && \
-    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
-    wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
-    pip3 install --no-cache-dir "${WHL_FILE}" && \
-    rm "${WHL_FILE}"


@@ -1,16 +1,15 @@
-ARG CUDA_VERSION="12.8.1"
+ARG CUDA_VERSION="12.8.2"
-ARG CUDNN_VERSION="8"
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4
-FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
+FROM nvidia/cuda:12.8.2-devel-ubuntu22.04 AS base-builder
-ENV PATH="/root/miniconda3/bin:${PATH}"
+ENV PATH="/root/miniforge3/bin:${PATH}"
 ARG PYTHON_VERSION="3.11"
 ARG PYTORCH_VERSION="next"
 ARG CUDA="128"
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
+ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0 12.0+PTX"
 ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
@@ -18,13 +17,13 @@ ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
 RUN apt-get update \
     && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
     && wget \
-        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+        https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \
     && mkdir /root/.conda \
-    && bash Miniconda3-latest-Linux-x86_64.sh -b \
+    && bash Miniforge3-Linux-x86_64.sh -b \
-    && rm -f Miniconda3-latest-Linux-x86_64.sh \
+    && rm -f Miniforge3-Linux-x86_64.sh \
-    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
+    && /root/miniforge3/bin/conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
-ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
+ENV PATH="/root/miniforge3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace


@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
     fi
 # So we can test the Docker image


@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
 RUN uv pip uninstall causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
-        BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="optimizers,ray"; \
     else \
-        BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="deepspeed,optimizers,ray"; \
     fi && \
     if [ "$AXOLOTL_EXTRAS" != "" ]; then \
         uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \


@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
 RUN if [ "$TARGETARCH" = "amd64" ]; then \
     MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
     fi
-# Map Python version (e.g., 3.12 -> cp312)
-RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
-    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
-    TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
-    LINUX_TAG="manylinux_" && \
-    # Map architecture
-    case "$TARGETARCH" in \
-        amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
-        arm64) ARCH_TAG="2_34_aarch64" ;; \
-        *) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
-    esac && \
-    WHL_VERSION="v0.7.16" && \
-    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
-    wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
-    uv pip install --no-cache-dir "${WHL_FILE}" && \
-    rm "${WHL_FILE}"


@@ -0,0 +1,70 @@
---
title: "1.58-bit Finetuning"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
1.58-bit finetuning lets you finetune BitNet models for which prequantized weights are provided. In theory it is possible to fine-tune any LLM in 1.58-bit format, but the performance degradation will be dramatic.
Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready to use for training.
::: {.callout-note}
LoRA is not supported for BitNet models
:::
## Installation
Install the `onebitllms` package before using this feature:
```bash
uv pip install onebitllms
```
Or from source:
```bash
uv pip install git+https://github.com/tiiuae/onebitllms
```
## Supported models
For now, only the `Falcon-E` series of models is supported. Make sure to use the `-prequantized` versions:
```bash
tiiuae/Falcon-E-3B-Base-prequantized
tiiuae/Falcon-E-1B-Base-prequantized
```
In theory, any other model would 'work', but the performance degradation would be huge. This remains an area of exploration.
## Configuration
To enable 1.58-bit finetuning, set the following in your configuration file:
```yaml
base_model: tiiuae/Falcon-E-3B-Base-prequantized # A BitNet-compatible model
use_onebitllms: true
```
::: {.callout-note}
For BitNet models, it is recommended to use a higher learning rate than for classic models (typically around 10x higher).
:::
## Considerations after training
Once your model has been trained with 1.58-bit fine-tuning, you can convert it to ternary format using the `onebitllms` CLI:
```bash
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH
```
After that, you can run the trained model with supported packages such as `llama.cpp` or Apple's MLX.
## Example Configuration
You can find example configurations in `examples/falcon-e` which contain one configuration for SFT and one configuration for DPO.


@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
 | Backend | Config | head_dim limit | torch_compile | Notes |
 |---------|--------|---------------|---------------|-------|
-| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
+| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
-| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
+| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
-| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
+| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
-| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
+| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
-| eager | neither set | None | ✅ | Slowest, always works |
+| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
 **Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.


@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
 1. Paired preference data (chosen + rejected)?
    - Default → `rl: dpo`
-   - Overfitting → `rl: ipo`
+   - Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
    - VRAM-limited → `rl: orpo` (no ref model)
    - Length-sensitive → `rl: simpo` (no ref model)
 2. Only binary labels (good/bad)? → `rl: kto`


@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
 | Issue | Fix |
 |-------|-----|
 | OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
-| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
+| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
 | Missing chat template error | Set `chat_template: chatml` explicitly |
 | Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
 | Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |


@@ -3,28 +3,71 @@ title: Attention
 description: Supported attention modules in Axolotl
 ---
-## SDP Attention
-This is the default built-in attention in PyTorch.
+Axolotl routes attention via a single config field:
 ```yaml
-sdp_attention: true
+attn_implementation: <backend>
 ```
-For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+`attn_implementation` is passed through to `transformers` verbatim (via
+`model.config._attn_implementation`). Accepted values are the HF-native
+backends, axolotl-registered backends, or a hub-kernel path.
-## Flash Attention
-Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
-based on your installed packages and GPU.
+## Backends
+| `attn_implementation` | Description |
+|---|---|
+| `eager` | Plain PyTorch attention. No packing support. |
+| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
+| `flash_attention_2` | Dao-AILab Flash Attention 2. |
+| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
+| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
+| `xformers` | xFormers memory-efficient attention. |
+| `sage` | SageAttention (QK int8 / PV fp16). |
+| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
+| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
+| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
+| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
+| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
+Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
+set the canonical name above.
+### Capability flags
+Axolotl derives three boolean capability flags from `attn_implementation` and
+exposes them on the validated config:
+- `cfg.attn_supports_packing` — backend supports varlen sample packing via
+  `position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
+- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
+  monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
+- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
+  (everything except `eager` and `sdpa`).
+These are **computed** — they cannot be overridden from YAML.
+## Per-backend notes
+### SDPA
+Default PyTorch attention. See
+[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
 ```yaml
-flash_attention: true
+attn_implementation: sdpa
 ```
-For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
-### Flash Attention 2
+### Flash Attention
+Axolotl supports FA2, FA3, and FA4. The best available version is used
+automatically based on your installed packages and GPU.
+```yaml
+attn_implementation: flash_attention_2 # or flash_attention_3
+```
+#### Flash Attention 2
 Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
@@ -39,23 +82,25 @@ Alternatively, try reinstall or downgrade a version.
 :::
-### Flash Attention 3
+#### Flash Attention 3
 Requirements: Hopper only and CUDA 12.8 (recommended)
 ```bash
 git clone https://github.com/Dao-AILab/flash-attention.git
 cd flash-attention/hopper
 python setup.py install
 ```
-### Flash Attention 4
+#### Flash Attention 4
-Requirements: Hopper or Blackwell GPUs
+Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
+is true and FA4 is importable.
+FA4 is still a pre-release on PyPI, so `--pre` is required:
 ```bash
-pip install flash-attn-4
+pip install --pre flash-attn-4
 ```
 Or from source:
@@ -63,7 +108,6 @@ Or from source:
 ```bash
 git clone https://github.com/Dao-AILab/flash-attention.git
 cd flash-attention/flash_attn/cute
 pip install -e .
 # FA2's flash_attn package includes a cute/ stub that shadows FA4.
@@ -86,93 +130,113 @@ and falls back to FA2/3.
 :::
-For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
 ### AMD
-Requirements: ROCm 6.0 and above.
-See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
+Requirements: ROCm 6.0 and above. See
+[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
-## Flex Attention
-A flexible PyTorch API for attention used in combination with `torch.compile`.
+### Flex Attention
 ```yaml
-flex_attention: true
-# recommended
-torch_compile: true
+attn_implementation: flex_attention
+torch_compile: true # recommended
 ```
-::: {.callout-note}
-We recommend using latest stable version of PyTorch for best performance.
-:::
-For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
+Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
-## SageAttention
-Attention kernels with QK Int8 and PV FP16 accumulator.
+### SageAttention
+Requirements: Ampere, Ada, or Hopper GPUs.
 ```yaml
-sage_attention: true
+attn_implementation: sage
 ```
-Requirements: Ampere, Ada, or Hopper GPUs
 ```bash
 pip install sageattention==2.2.0 --no-build-isolation
 ```
 ::: {.callout-warning}
-Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
+Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
+[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
 :::
-For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
+For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
-::: {.callout-note}
-We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
-:::
-## xFormers
+### xFormers
 ```yaml
-xformers_attention: true
+attn_implementation: xformers
 ```
 ::: {.callout-tip}
-We recommend using with Turing GPUs or below (such as on Colab).
+Recommended for Turing GPUs or below (e.g. Colab T4).
 :::
-For more details: [xFormers](https://github.com/facebookresearch/xformers)
-## Shifted Sparse Attention
+### Shifted Sparse Attention
 ::: {.callout-warning}
-We plan to deprecate this! If you use this feature, we recommend switching to methods above.
+Planned for deprecation. Prefer one of the backends above.
 :::
-Requirements: LLaMA model architecture
+Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
+patched to implement shifted-sparse attention. Does not support sample packing.
 ```yaml
-flash_attention: true
-s2_attention: true
+attn_implementation: s2
 ```
-::: {.callout-tip}
-No sample packing support!
+### FP8
+torchao low-precision attention. Loaded as SDPA and patched post-load.
+Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
+flash-attn with FA3. KV caching must be disabled.
+```yaml
+attn_implementation: fp8
+```
+### Hub kernels
+```yaml
+attn_implementation: kernels-community/flash-attn3
+```
+Passed through to `transformers`; axolotl does not install the kernel itself.
+For recognized hub paths the capability flags are set automatically; for
+arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
+`attn_uses_flash_lib=False`).
+## Migrating from legacy boolean flags
+The following legacy config fields are **deprecated** and will be removed in a
+future release. Each emits a `DeprecationWarning` when set and is stripped from
+the validated config.
+| Legacy | Canonical |
+|---|---|
+| `flash_attention: true` | `attn_implementation: flash_attention_2` |
+| `sdp_attention: true` | `attn_implementation: sdpa` |
+| `xformers_attention: true` | `attn_implementation: xformers` |
+| `flex_attention: true` | `attn_implementation: flex_attention` |
+| `sage_attention: true` | `attn_implementation: sage` |
+| `s2_attention: true` | `attn_implementation: s2` |
+| `eager_attention: true` | `attn_implementation: eager` |
+Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
+flash_attention_2` **and** `flash_attention: true`) raises — pick one.
+::: {.callout-note}
+Existing example configs under `examples/` still use the legacy flags. They
+continue to work with a deprecation warning; they will be migrated in a
+follow-up pass.
 :::


@@ -77,8 +77,9 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
 ```bash
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 ```
 #### Remote Hosts
@@ -218,8 +219,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
 You will now be in the container. Next, install Axolotl with dev dependencies:
 ```bash
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 ```
 ### Attach To Container


@@ -10,13 +10,16 @@ This section describes the different Docker images that are released by AxolotlA
 [Docker Hub](https://hub.docker.com/u/axolotlai).
 ::: {.callout-important}
-For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
-:::
-::: {.callout-tip}
-Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
-a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
-(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
+### Switch to the `-uv` images
+Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
+(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
+(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
+same format as their non-uv counterparts.
+**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
+build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
+the uv image.
 :::
 ## Base
@@ -85,7 +88,7 @@ Tags examples:
 - `main-py3.12-cu130-2.10.0`
 - `main-latest`
 - `main-20260315-py3.11-cu128-2.9.1`
-- `0.12.0`
+- `0.16.1`
 ## Cloud


@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
 max_steps: 20
 learning_rate: 5.0e-6
 bf16: auto
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_checkpointing: true
 output_dir: ./outputs/ebft-quickstart
 ```
@@ -304,7 +304,7 @@ lora_alpha: 32
lora_target_linear: true lora_target_linear: true
bf16: auto bf16: auto
flex_attention: true attn_implementation: flex_attention
gradient_checkpointing: true gradient_checkpointing: true
gradient_checkpointing_kwargs: gradient_checkpointing_kwargs:
use_reentrant: true # Required with flex_attention use_reentrant: true # Required with flex_attention


@@ -57,7 +57,7 @@ description: Frequently asked questions
 **Q: vLLM is not working with Axolotl**
-> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
+> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
 **Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**


@@ -154,7 +154,7 @@ lr_scheduler: cosine
 warmup_steps: 10
 bf16: true
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_checkpointing: true
 special_tokens:


@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your environments
 - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
 - Python ≥3.11
-- PyTorch ≥2.9.0
+- PyTorch ≥2.9.1
 ## Installation {#sec-installation}

@@ -36,9 +36,9 @@ source $HOME/.local/bin/env
 Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
 ```{.bash}
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv venv --no-project --relocatable
+uv venv
 source .venv/bin/activate
-uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
+uv pip install --no-build-isolation axolotl[deepspeed]
 ```
 ### Edge/Development Build {#sec-edge-build}

@@ -49,12 +49,11 @@ For the latest features between releases:
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed
+uv venv
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]'
 ```
-`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
 ### Docker {#sec-docker}
 ```{.bash}

@@ -132,11 +131,11 @@ source $HOME/.local/bin/env
 # Create a fresh venv (recommended for a clean start)
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv venv --no-project --relocatable
+uv venv
 source .venv/bin/activate
 # Reinstall axolotl
-uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
+uv pip install --no-build-isolation axolotl[deepspeed]
 ```
 ## Using pip (Alternative) {#sec-pip}

@@ -151,13 +150,13 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
 ```{.bash}
 pip3 install -U packaging setuptools wheel ninja
-pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
+pip3 install --no-build-isolation axolotl[deepspeed]
 ```
 For editable/development installs:
 ```{.bash}
 pip3 install -U packaging setuptools wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[deepspeed]'
 ```
 ## Troubleshooting {#sec-troubleshooting}

@@ -0,0 +1,84 @@
# Multimodal assistant-only loss masking
## Correct placement
```yaml
# Top-level: only train_on_inputs lives here.
train_on_inputs: false
datasets:
- path: data/train.jsonl
type: chat_template
roles_to_train: # per-dataset — this is what the MM scanner reads
- assistant
train_on_eos: turn # per-dataset — same
test_datasets:
- path: data/val.jsonl
type: chat_template
split: train
roles_to_train:
- assistant
train_on_eos: turn
```
## How to verify at runtime
`build_collator` logs the resolved knobs at INFO:
```text
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
```
If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
scanner — check that they are under `datasets[0]`, not at the root.
Each verified strategy additionally logs its resolved boundary token ids at
strategy init (e.g. `<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]` for
Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
pad and media tokens are masked" one-shot warning instead, it is on the
fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
(below) to activate masking. The strategies currently on this path are
listed in the audit table above under `fallback + warn`.
## Config-based override: `cfg.role_boundaries`
For the "unverified" strategies above, or for custom chat templates that
don't match a built-in strategy's markers, users can declare role boundaries
directly in YAML without subclassing:
```yaml
role_boundaries:
- role: assistant
start: "<|turn>model"
end: "<turn|>"
- role: user
start: "<|turn>user"
end: "<turn|>"
# Optional keys:
# include_start: false # default False
# include_end: true # default True, respects cfg.train_on_eos
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
# end: null # span runs to end of sequence
```
Semantics:
- `start` and `end` are literal strings; axolotl encodes them at strategy
init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
resolved token-id sequences at INFO level.
- The special value `end: eos_token` is the portable way to express
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
the strategy's built-in declarations wholesale (partial overlays are
intentionally unsupported — they're hard to reason about at review time).
Leaving the field unset *or* setting it to an empty list (`[]`) both mean
"use the strategy's built-ins." Writing `role_boundaries: []` is almost
always a typo or leftover — honoring it literally would produce all-masked
labels and zero gradient, so it is treated the same as unset.
- `cfg.roles_to_train` still governs which declared roles contribute to
loss. You can declare `user` and `assistant` boundaries and set
`roles_to_train: ["assistant"]` to have the scanner correctly identify
user spans as masking boundaries without training on their content.
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
unencodable markers), not silently at loss-compute time.
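To make the span semantics above concrete, here is a minimal, self-contained sketch of how encoded start/end marker id sequences can be scanned into a label mask. All names here (`find_subseq`, `mask_labels`) are hypothetical illustration, not axolotl's actual implementation; the token ids mirror the Gemma-style example logged above.

```python
IGNORE_INDEX = -100  # label value ignored by cross-entropy loss

def find_subseq(ids, pattern, start=0):
    """Return index of first occurrence of pattern in ids at/after start, or -1."""
    n, m = len(ids), len(pattern)
    for i in range(start, n - m + 1):
        if ids[i:i + m] == pattern:
            return i
    return -1

def mask_labels(input_ids, boundaries, roles_to_train,
                include_start=False, include_end=True):
    """boundaries maps role -> (start_ids, end_ids), already encoded via
    tokenizer.encode(..., add_special_tokens=False). Tokens inside spans of
    trained roles keep their id; everything else becomes IGNORE_INDEX."""
    labels = [IGNORE_INDEX] * len(input_ids)
    for role, (start_ids, end_ids) in boundaries.items():
        if role not in roles_to_train:
            continue
        pos = 0
        while (s := find_subseq(input_ids, start_ids, pos)) != -1:
            body = s if include_start else s + len(start_ids)
            e = find_subseq(input_ids, end_ids, body)
            if e == -1:
                end = len(input_ids)  # end: null semantics -> span runs to EOS
            else:
                end = e + len(end_ids) if include_end else e
            for i in range(body, end):
                labels[i] = input_ids[i]
            if e == -1:
                break
            pos = end
    return labels
```

For example, with `boundaries = {"assistant": ([105, 4368], [106])}` and `roles_to_train={"assistant"}`, a sequence containing a user turn followed by a model turn yields `IGNORE_INDEX` everywhere except the model turn's content (and its end marker, since `include_end` defaults to true, matching `train_on_eos: turn`).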


@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pack
 Using an optimized attention implementation is critical for training speed.
-- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
+- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
-- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
+- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
-- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
+- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
-- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
+- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
-*Note: You should only enable one attention backend.*
+See [Attention](attention.qmd) for the full list of backends and the canonical values.
 ### LoRA Optimizations
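For reference, the migrated form collapses the per-backend booleans above into a single key. A minimal fragment, using the canonical values from the bullets above:

```yaml
# Pick exactly one backend; the value is one of the canonical strings above.
attn_implementation: flash_attention_2   # or: flex_attention, sdpa, xformers
```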

@@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the above
 As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
 ```yaml
-rl: ipo
+rl: dpo
+dpo_loss_type: ["ipo"]
 ```
+*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
 ### ORPO

@@ -1145,8 +1147,7 @@ datasets:
     type: ebft_strided_structured.transform
     split: train[:1%]
-flash_attention: false
-flex_attention: true # Strided mode uses flex_attention
+attn_implementation: flex_attention # Strided mode uses flex_attention
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:
   use_reentrant: true # Required for flex_attention

@@ -20,6 +20,8 @@ examples:
   title: Arcee AFM
 # MistralAI
+- name: mistral-medium-3_5
+  title: Mistral Medium 3.5
 - name: ministral3/think
   title: Ministral 3 Thinking
 - name: ministral3/vision

@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
 ## Limitations
-- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
+- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
 - May have a small performance overhead due to communication between GPUs
 ## Example

@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
 Reduces attention memory from O(n^2) to O(n):
 ```yaml
-flash_attention: true
+attn_implementation: flash_attention_2
 ```
 ### Step 6: Offload with DeepSpeed

@@ -15,7 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these releases
 Here is an example of how to install from pip:
 ```bash
 # Ensure you have a compatible version of Pytorch installed
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```
 2. Run one of the finetuning examples below.

@@ -39,7 +39,7 @@ tf32: true
 gradient_checkpointing: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 2

@@ -48,7 +48,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 2

@@ -50,8 +50,7 @@ tf32: true
 gradient_checkpointing: true
 logging_steps: 1
-flash_attention: true
-eager_attention:
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 1

@@ -39,7 +39,7 @@ activation_offloading: legacy
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_steps: 100
 saves_per_epoch: 1

@@ -39,7 +39,7 @@ activation_offloading: legacy
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_steps: 100
 saves_per_epoch: 1

@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 Here is an example of how to install from main for pip:
 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'
 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh

@@ -55,7 +55,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 1

@@ -13,11 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM models
 Here is an example of how to install from main for pip:
 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'
 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh

@@ -55,7 +55,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 1

@@ -59,8 +59,7 @@ gradient_checkpointing: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
-sdp_attention:
+attn_implementation: flash_attention_2
 flash_optimum:
 gptq_groupsize:

@@ -39,8 +39,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+attn_implementation: xformers
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -45,7 +45,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -46,7 +46,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -45,7 +45,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -46,7 +46,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -45,7 +45,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -46,7 +46,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch:

@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch:

@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch:

@@ -45,7 +45,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 1

@@ -43,8 +43,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+attn_implementation: xformers
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -73,8 +73,7 @@ early_stopping_patience: 3
 resume_from_checkpoint:
 auto_resume_from_checkpoints: true
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+attn_implementation: xformers
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -40,8 +40,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+attn_implementation: xformers
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -47,7 +47,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -36,8 +36,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+attn_implementation: xformers
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -37,8 +37,7 @@ bf16: auto
 tf32: true
 resume_from_checkpoint:
 logging_steps: 5
-xformers_attention: true
-flash_attention:
+attn_implementation: xformers
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -39,7 +39,6 @@ bf16: auto
 tf32: true
 resume_from_checkpoint:
 logging_steps: 5
-flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -39,7 +39,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -47,7 +47,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -40,7 +40,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -47,7 +47,6 @@ tf32: false
 gradient_checkpointing: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention:
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -47,7 +47,6 @@ tf32: false
 gradient_checkpointing: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention:
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -40,7 +40,6 @@ bf16: auto
 tf32: true
 resume_from_checkpoint:
 logging_steps: 5
-flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -38,7 +38,6 @@ tf32: true
 gradient_checkpointing:
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -44,7 +44,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 flash_attn_cross_entropy: false
 flash_attn_rms_norm: true
 flash_attn_fuse_mlp: true

@@ -47,7 +47,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 flash_attn_cross_entropy: false
 flash_attn_rms_norm: true

@@ -46,7 +46,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -47,7 +47,6 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: false
 warmup_ratio: 0.1
 evals_per_epoch: 0

@@ -45,7 +45,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -36,7 +36,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch:

@@ -47,7 +47,7 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 4

@@ -71,8 +71,7 @@ early_stopping_patience: 3
 resume_from_checkpoint:
 auto_resume_from_checkpoints: true
 logging_steps: 1
-xformers_attention: true
-flash_attention:
+attn_implementation: xformers
 gptq_groupsize:
 gptq_model_v1:
 warmup_ratio: 0.1

@@ -10,7 +10,7 @@ load_in_4bit: true
 sequence_len: 1024
 bf16: auto
 tf32: false
-flash_attention: true
+attn_implementation: flash_attention_2
 special_tokens:
   bos_token: "<|startoftext|>"
   eos_token: "<|endoftext|>"

@@ -48,7 +48,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch:

@@ -36,12 +36,7 @@
    "id": "msOCO4NRmRLa"
   },
   "outputs": [],
-  "source": [
-   "%%capture\n",
-   "# This step can take ~5-10 minutes to install dependencies\n",
-   "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-   "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
-  ]
+  "source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
  },
  {
   "cell_type": "markdown",

@@ -45,7 +45,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 1

@@ -45,7 +45,7 @@ tf32: true
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 1

@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 2

@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 evals_per_epoch: 2

@@ -15,8 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this release
 Here is an example of how to install from pip:
 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```
 2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

@@ -26,7 +26,6 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
-lora_r: 32
 lora_alpha: 16
 lora_dropout: 0

@@ -51,8 +50,8 @@ tf32: false
 gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
-flash_attention: true
+attn_implementation: flash_attention_2
-scaling_softmax: true
+# scaling_softmax: true # needs flex_attention
 loss_watchdog_threshold: 5.0
 loss_watchdog_patience: 3

@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
 sequence_len: 2048
 sample_packing: true
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_accumulation_steps: 1
 micro_batch_size: 1

@@ -26,7 +26,7 @@ output_dir: ./outputs/ndp-out/
 sequence_len: 8192
 sample_packing: true
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_accumulation_steps: 1
 micro_batch_size: 1 # must be 1 when using context parallel

@@ -65,8 +65,7 @@ early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
 logging_steps: 1
-xformers_attention:
-flash_attention: true
+attn_implementation: flash_attention_2
 warmup_ratio: 0.1
 weight_decay: 0.0

@@ -46,7 +46,7 @@ lora_dropout: 0.05
 lora_target_linear: true
 bf16: auto
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_checkpointing: true
 special_tokens:

@@ -66,7 +66,7 @@ lora_target_linear: true
 # --- Hardware ---
 bf16: auto
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_checkpointing: true
 special_tokens:

@@ -47,8 +47,7 @@ lora_dropout: 0.05
 lora_target_linear: true
 bf16: auto
-flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
-flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
+attn_implementation: flex_attention
 # (full-model compile conflicts with gradient checkpointing + flex_attention)
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:


@@ -46,7 +46,6 @@ lora_dropout: 0.05
 lora_target_linear: true
 bf16: auto
-flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
 gradient_checkpointing: true
 special_tokens:


@@ -48,7 +48,6 @@ lora_target_linear: true
 bf16: auto
 torch_dtype: bfloat16
-flash_attention: false
 gradient_checkpointing: true
 torch_compile: true
 gradient_checkpointing_kwargs:


@@ -41,7 +41,6 @@ warmup_steps: 10
 weight_decay: 0.01
 bf16: auto
-flash_attention: false # strided EBFT uses flex_attention at runtime
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:
   use_reentrant: false


@@ -72,7 +72,7 @@ lora_dropout: 0.0
 lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
 bf16: auto
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_checkpointing: true
 special_tokens:


@@ -63,7 +63,7 @@ lora_dropout: 0.0
 lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
 bf16: auto
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_checkpointing: true
 special_tokens:


@@ -68,7 +68,7 @@ lora_dropout: 0.0
 lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
 bf16: auto
-flash_attention: true
+attn_implementation: flash_attention_2
 gradient_checkpointing: true
 special_tokens:

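The hunks above apply the same mechanical migration file after file: `flash_attention: true` becomes `attn_implementation: flash_attention_2`, `flex_attention: true` becomes `attn_implementation: flex_attention`, and `flash_attention: false` / `xformers_attention:` are simply dropped. A minimal sketch of a helper that performs this rewrite; only the key names come from this diff, the function itself (`migrate_config`) is hypothetical and not part of the changeset:

```python
# Rewrite deprecated Axolotl attention flags to the new
# `attn_implementation` key (illustrative helper, not from this repo).
MIGRATIONS = {
    "flash_attention: true": "attn_implementation: flash_attention_2",
    "flex_attention: true": "attn_implementation: flex_attention",
}
DROPPED = ("flash_attention: false", "xformers_attention:")

def migrate_config(text: str) -> str:
    """Return the config text with deprecated attention keys rewritten."""
    out, emitted = [], False
    for line in text.splitlines():
        key = line.split("#", 1)[0].strip()  # ignore trailing comments
        if key in MIGRATIONS:
            if not emitted:  # emit the replacement key once per file
                out.append(MIGRATIONS[key])
                emitted = True
        elif key in DROPPED:
            continue  # these keys are removed outright
        else:
            out.append(line)
    return "\n".join(out) + "\n"
```

Running it over the third hunk's old lines collapses the `xformers_attention:` / `flash_attention: true` pair into the single new key, matching the diff.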

@@ -0,0 +1,93 @@
base_model: axolotl-ai-co/Falcon-E-1.2-3B-Exp-prequantized
output_dir: ./output
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: false
use_scattermoe: false
use_sonicmoe: false
use_onebitllms: true
load_in_8bit: false
load_in_4bit: false
chat_template: tokenizer_default
rl: dpo
datasets:
  - path: allenai/Dolci-Think-DPO-7B
    split: train
    type: chatml.ultra
dataset_prepared_path: ./axolotl_dataset_cache
sequence_len: 8192
trust_remote_code: false
gradient_accumulation_steps: 4 # This can run on 4 GPUs
# Very important to enable gradient accumulation with FSDP
# https://github.com/huggingface/transformers/issues/29425
accelerator_config:
  gradient_accumulation_kwargs:
    sync_each_batch: True
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 1.0e-5
# adamw hyperparams
adam_beta1: 0.9
adam_beta2: 0.95
bf16: true
tf32: false
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 15.0
loss_watchdog_patience: 3
warmup_steps: 128
evals_per_epoch: 0
save_steps: 500
save_strategy: steps
weight_decay: 0.01
shuffle_merged_datasets: true
experimental_skip_move_to_device: true
fsdp_version: 2
fsdp_config:
  offload_params: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: FULL_STATE_DICT
  reshard_after_forward: true
  activation_checkpointing: true
# Comment to disable CP
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 1
# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 1
# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)
# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)
special_tokens:
  eos_token: <|end_of_text|>
eot_tokens:
  - <|im_end|>
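The config's comments describe four parallelism dimensions (FSDP shard, DDP replicate, TP, CP); their product is the number of GPUs the launch must provide. A quick sketch of that bookkeeping, assuming the multiplicative relationship only (`required_world_size` is a hypothetical name, not Axolotl's validation code):

```python
# World size implied by the four parallelism dims in the config above
# (hypothetical helper; the multiplicative rule is the only assumption).
def required_world_size(dp_shard=1, dp_replicate=1, tp=1, cp=1):
    return dp_shard * dp_replicate * tp * cp

# With every dim at 1, as in this config, a single GPU suffices.
print(required_world_size())                              # 1
# Sharding over 4 GPUs with 2 replicas would need 8 GPUs total.
print(required_world_size(dp_shard=4, dp_replicate=2))    # 8
```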

Some files were not shown because too many files have changed in this diff.