Compare commits


20 Commits

Author SHA1 Message Date
8693a1f61b fix Dockerfile-base-next: cuda 12.8.2, miniforge, sm_120 2026-05-13 14:37:01 +00:00
71c6a56e7a switch to HQQ quantization to bypass bitsandbytes sm_120 issue 2026-05-13 13:55:52 +00:00
38adf5cd37 add trust_remote_code, explicit bfloat16 and bnb dtype settings 2026-05-13 13:32:46 +00:00
3f29fa017b replace Capybara with SlimOrca (compatible ShareGPT format) 2026-05-13 12:58:29 +00:00
c02a76f132 fix field_messages mapping for Capybara/OpenHermes ShareGPT format 2026-05-13 12:56:03 +00:00
b9ceebfe7e fix deprecated type:sharegpt and flash_attention config keys 2026-05-13 12:52:25 +00:00
e9a3fd483f add human-like QLoRA training config for Llama 3.1 8B 2026-05-13 12:50:35 +00:00
eadd15c960 note MAX_JOBS for flash-attn compile speed 2026-05-13 04:45:21 +00:00
396ce4a9dd add miaai environment setup guide 2026-05-13 04:16:03 +00:00
Wing Lian
b7ec06b8a1 Add optional Axolotl MoRA/ReMoRA integration (#3647) [skip ci]
* Add optional Axolotl MoRA/ReMoRA integration

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Isolate MoRA adapter behavior in plugin

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Constrain MoRA variants to supported enum values

* Keep MoRA validation out of core config

---------

Co-authored-by: Swarm <swarm@localhost>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-12 07:19:55 -04:00
Wing Lian
e2f01de0e8 Fix Axolotl ReLoRA optimizer reset scope (#3646)
* Fix Axolotl ReLoRA optimizer reset scope
* fix: make relora reset method honor relora_prune_ratio

When relora_prune_method='reset' and relora_prune_ratio is explicitly
set, the ratio was silently ignored and replaced with the hardcoded
_FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to
ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset
and 0.9 for other methods. reset_optimizer now uses the same random
pruning path for both 'random' and 'reset'.
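
A minimal sketch of the described default-ratio logic (the constant name and callback method are taken from this message; the helper itself is illustrative):

```python
_FULL_RESET_RATIO = 0.999  # hardcoded near-full prune, as described above

def _resolve_prune_ratio(prune_method: str, prune_ratio: float | None) -> float:
    # Illustrative mirror of the new ReLoRACallback.on_step_begin defaulting:
    # an explicit ratio is always honored; None picks a per-method default.
    if prune_ratio is not None:
        return prune_ratio
    return _FULL_RESET_RATIO if prune_method == "reset" else 0.9
```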

Also consolidate three-layer default mismatch: schema default for
relora_prune_method is now 'magnitude' (single canonical source);
dataclass defaults for both fields changed to None to eliminate the
conflicting fallback layer.

Tests updated: removed the test case that verified the old broken
behavior (reset ignoring ratio), added two cases proving reset honors
the passed ratio. E2E reset fixture now uses ratio=0.5 to make it
unambiguous that the ratio is honored.

* Fix ReLoRA uint8 pruning regression

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-09 17:52:35 -04:00
thad0ctor
5352d41d32 feat: systemic multimodal assistant-only loss masking + `cfg.role_boundaries` (#3625)
* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries

Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.

What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
  strategy declares per-role start/end markers via
  `_build_role_boundaries`; the shared scanner honors
  `train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
  Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
  (previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
  (Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
  fallback) retain legacy behavior and emit a one-shot warning. Users
  can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
  token between user-end and assistant-start via `include_end=False`
  + scanner rewind.

See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.
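
As a rough illustration only (the exact shape consumed by the scanner is an assumption; the markers are the Gemma 4 ones quoted in the new docs page), a per-role declaration looks like:

```python
# Illustrative data shape for a per-template boundary declaration; the shared
# scanner resolves these to token ids and applies train_on_inputs /
# roles_to_train / train_on_eos.
GEMMA4_ROLE_BOUNDARIES = [
    {"role": "assistant", "start": "<|turn>model", "end": "<turn|>"},
    {"role": "user", "start": "<|turn>user", "end": "<turn|>"},
]
```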

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs+types: address CodeRabbit nitpicks on PR #7

- builders/causal.py: add inline NOTE that multi-dataset configs reuse
  the first dataset's masking knobs (roles_to_train / train_on_eos) for
  all datasets — heterogeneous per-dataset overrides are not supported
  in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
  _match_prefix and _find_end with explicit types (Tensor, int,
  list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
  branch" list to 1-7 consecutive (previously skipped 3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(mm-mask): address two CodeRabbit findings on PR #7

1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
   `_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
   `SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
   users hit a pydantic ValidationError at config load. Added "none" to
   the Literal and updated the description.

2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
   ctor treated it as "replace built-ins with empty" while the collator
   plumbing treated it as "unset", and both the design doc and the
   MultiModalConfig schema help text promised wholesale replacement for
   any set value. Aligned on opt-in semantics across all four surfaces —
   a non-empty list replaces built-ins wholesale; unset or `[]` falls back
   to built-ins. Rationale: honoring `[]` literally yields all-masked
   labels and zero gradient, which is almost always a typo or leftover
   rather than a deliberate user action. Users who want to disable role
   masking should unset the field or use `train_on_inputs: true`.

   Also sharpened the fallback one-shot warning for strategies without
   built-in boundaries: names the consequence ("only pad and media tokens
   are masked, every other token contributes to loss") and points users
   at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
   of "see axolotl/processing_strategies.py for how to declare
   boundaries."

Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
  role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
  now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
  block; updated the fallback-path detection paragraph to quote the new
  warning text
- tests/test_processing_strategies.py: +2 regressions
  (test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
  test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* doc cleanup

* fix(mm-mask): CodeRabbit findings + lint fix on PR #3625

Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).

Six CodeRabbit findings addressed:

1. Scanner: non-trainable role's end marker ignored ``include_end``.
   Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
   with ``include_end=False``, intentionally re-matched as assistant-start)
   leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
   Fix: gate the non-trainable branch on ``best_match.include_end`` to
   mirror the trainable branch.

2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
   which never fires on real checkpoints (``special_tokens_map`` only
   holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
   attribute read ``getattr(tokenizer, "boi_token", None)``, matching
   what ``transformers.models.gemma3.processing_gemma3`` itself does.
   Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
   shape so the test exercises the production code path.

3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
   GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
   through to the base fallback. Both processors ship identical
   media-token markers, so register both under the shared
   ``Glm4vProcessingStrategy`` with independent try/except import blocks.
   Updated class docstring. +2 dispatcher regressions.

4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
   Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
   with unk-id guard; fall back to 262144 only if the string isn't in
   vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.

5. ``build_collator`` was called twice per ``build()`` (eval + train
   passes), producing two identical ``MM collator: ...`` INFO banners on
   startup. Gate the log on ``is_eval=False`` so only the training pass
   emits it.

6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
   always returned ``None``; the dispatcher already handles missing
   ``mistral_common`` via lazy import + ``try/except``). Added
   ``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
   — a focused scanner-level lock-in for finding #1, independent of any
   specific VLM strategy.

Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings

- Scanner perf: convert labels[i] to a Python list once per row so
  _match_prefix / _find_end operate on list slices instead of
  re-materializing Tensor slices via .tolist() on every probe. Cuts
  O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
  under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
  bare-minimum "why" notes matching the repo's existing style.

68/68 tests, 8/8 pre-commit hooks still pass.
2026-05-05 11:25:39 -04:00
VED
c15f6cffe2 fix: FSDP FULL_STATE_DICT oom from memory leak (#3635)
* memory clean up for fsdp full state dict

* Update src/axolotl/monkeypatch/accelerate/fsdp2.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2026-05-05 11:22:35 -04:00
Wing Lian
e4032fc90f Refactor separate attention flags with attn_implementation and capability/concerns feature flags (#3602)
* upgrade to torchao 0.17.0

* chore: lint

* refactor attention handling

* replace legacy attention boolean flags with capability properties

Replace checks with capability-based properties derived from attn_implementation

This separates three concerns that were conflated under flash_attention:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property

* compute attn capability flags in normalizer instead of properties

* make attn_implementation the single source of truth

* move attention-dependent validators to mode=after

* migrate remaining consumers to canonical attn_implementation

* expand attention tests + rewrite docs

* migrate example configs to canonical attn_implementation

* update doc snippets + reject gemma4-hybrid with non-FA2 backend

* remove dead gemma4 branch in _set_attention_config

* fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests

* drop "Phase 2" naming from attn-implementation tests

* regroup attn_implementation tests by feature concern

* clean up verbose comments and remove MD

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x

In transformers 5.x, ProcessorMixin.apply_chat_template gained its own
`return_dict` parameter (defaulting to False).  When return_dict=False
and tokenize=True the method returns out["input_ids"] directly — a 2-D
tensor — rather than the full BatchFeature dict.

The old code placed `return_dict=True` inside processor_kwargs.  In
transformers 5.x those kwargs are forwarded to the underlying processor
call self(...) where _merge_kwargs silently ignores any key not present
in MllamaProcessorKwargs (emitting a warning).  The outer return_dict
therefore stayed False, apply_chat_template returned the raw input_ids
tensor, and the subsequent `batch["input_ids"]` attempted to index a
2-D tensor with the 9-character string "input_ids", producing:

  IndexError: too many indices for tensor of dimension 2

The fix is to pass return_dict=True as a top-level keyword argument to
apply_chat_template (where it is actually consumed) and remove it from
processor_kwargs (where it was silently dropped).  No version guard is
needed: transformers is pinned to ==5.5.4 in pyproject.toml.
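
A simplified sketch of the call-site change (the collator context is paraphrased; `apply_chat_template`, `tokenize`, and `return_dict` are the real transformers parameters):

```python
def process_rows(processor, messages, processor_kwargs):
    # return_dict=True must be a top-level kwarg: apply_chat_template consumes it there,
    # while anything left inside processor_kwargs is forwarded to self(...) and silently dropped.
    batch = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        **processor_kwargs,
    )
    return batch  # a BatchFeature/dict, so batch["input_ids"] is key access, not tensor indexing
```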

Adds a unit-level regression test (tests/test_mm_chat_collator.py) that
mocks the processor to return a raw tensor when apply_chat_template is
called without top-level return_dict=True, verifying the four invariants:
process_rows returns a dict, input_ids is 2-D, labels is 2-D, and
apply_chat_template receives return_dict=True as a top-level kwarg.

Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset
Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): process_rows returns dict (BatchFeature) shape

Two related changes for the multimodal chat collator under transformers 5.x:

1. Wrap apply_chat_template result in dict(...) so process_rows returns
   a plain dict rather than a BatchFeature instance. BatchFeature is a
   Mapping but not a dict; downstream code that did
     batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
   would index on a tensor when the result wasn't dict-shaped, raising
     IndexError: too many indices for tensor of dimension 2

2. Soften the regression test's contract from `dict` to `Mapping` so it
   exercises the actual semantic guarantee (key/value access) rather
   than the implementation detail (dict vs BatchFeature). Test guards
   against the original transformers 5.x breakage where apply_chat_template's
   return_dict default went from True to False.

Includes regression test under tests/test_mm_chat_collator.py.

Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against
attn-implementation-refactor; squash-merged from agent commits 4de886fd
+ dc9fcf4f.

Signed-off-by: Wing Lian <wing@axolotl.ai>

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
2026-05-05 10:15:18 -04:00
Younes B
6136ae627b Fix: add bitnet config (#3636)
* add bitnet config

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-30 12:30:56 -04:00
Younes B
e662972a29 Feat: Add bitnet integration (#3634)
* add bitnet

* switch to uv

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-30 11:25:02 -04:00
NanoCode012
ebbd7fa847 feat: Add Mistral Medium 3.5 (#3633)
* fix: clarify incompat

* fix: transformers api change upstream

* fix: add pre prop

* feat: add examples

* chore: cleanup

* chore: update readme
2026-04-29 22:46:51 +07:00
Wing Lian
ac77da96da use smaller pretrained models for ci (#3620) [skip ci]
* use smaller pretrained models for ci

* more steps for loss check

* fix tests

* more train steps

* fix losses
2026-04-27 13:22:56 -04:00
NanoCode012
798c8fba89 chore: update docker docs (#3623)
2026-04-24 16:03:21 +07:00
NanoCode012
17fc747f99 fix: docker build failing (#3622)
* fix: uv leftover docs

* fix: docker build failing

* chore: doc

* fix: remove old pytorch build

* fix: stop recommend flash-attn optional, let transformers pull

* fix: remove ring flash attention from image

* fix: quotes [skip ci]

* chore: naming [skip ci]
2026-04-24 14:23:09 +07:00
100 changed files with 4910 additions and 981 deletions

View File

@@ -31,10 +31,11 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
# Install axolotl + dev and test dependencies from lockfile
# Install axolotl + dev and test dependencies
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
pre-commit install
# test

View File

@@ -30,14 +30,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -168,14 +160,6 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -18,12 +18,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -180,12 +174,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

View File

@@ -72,7 +72,7 @@ jobs:
exclude:
- python_version: "3.14"
pytorch_version: "2.9.1"
timeout-minutes: 20
timeout-minutes: 25
steps:
- name: cleanup node

View File

@@ -1,142 +0,0 @@
# `attn-implementation-refactor` branch review
Review target: `attn-implementation-refactor` (5 commits ahead of main, merge base `69904781`).
Scope: 16 files, +682 / −96.
## 1. What the branch is trying to do
Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains.
Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (`s2 + flash → s2`), and then computes the three capability flags from frozen sets in `enums.py`.
Adjacent changes: `xformers` and `sage` now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New `fp8` backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`.
## 2. Target design
**`cfg.attn_implementation` is the single source of truth on the validated config.**
- Its type is `str | None`. Accepted values are **canonical names only** — no short-form aliases:
- HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.)
- Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`.
- Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. Known-kernel allowlist in `enums.py` classifies the common ones into the capability tables.
Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name).
- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed.
- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the **only** input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in JSON schema. Mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`).
- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are **`@computed_field` on the model**, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string.
- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get **conservative capability defaults** (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist.
- Downstream consumers read *only* `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`.
This is strictly what the branch is already *trying* to do — the gaps below are places it hasn't landed that goal yet.
## 3. Gaps and holes
### A. Legacy flags are still parallel state, not input-only
1. The normalizer *sets* the legacy flags on `data` (`impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G).
2. Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time.
3. Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash + flash_attention: true`) emits no deprecation warning — only legacy-only path warns.
4. Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class.
### B. Validators that still only check the legacy flag
5. **`check_ebft_activation_offloading`** (`validation.py:1607-1619`) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check.
6. **`check_sample_packing_without_attention`** (`validation.py:188-203`) early-returns when `attn_implementation` is set but never validates the chosen backend actually supports packing. `attn_implementation: eager + sample_packing: true` silently passes; the old legacy-flag check warned.
### C. Non-enum strings fall through the capability tables
7. **HF-native `"flash_attention_2"`** is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`.
8. **Hub kernel strings** (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently enters multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels).
9. **Conflict trap for hub-kernel + legacy flag** (`config.py:1414-1419`): `attn_implementation: kernels-community/flash-attn3 + flash_attention: true` always raises, because `impl_to_flag.get(custom) is None` and the loop treats `flag != None` as conflict. Common combo in existing YAMLs breaks hard on upgrade.
### D. Silent behaviour change for xformers
10. Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for `xformers + sample_packing`. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission.
### E. Capability flags are writable Pydantic fields, not computed
11. `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer.
### F. Validator ordering (not covered by tests)
12. `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true + flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way.
### G. Remaining legacy reads in `src/`
13. `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only.
14. `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern.
15. Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them.
### H. fp8 operational risk
16. The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability/compute-capability guards exist; `fp8` should match.
### I. Test coverage gaps
17. `test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does **not**:
- Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested).
- Verify capability flags can't be overridden from YAML (item 11).
- Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6).
- Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5).
- Cover hub-kernel + legacy flag combo (item 9).
- Cover `flash_attention_2` canonicalization (item 7).
## 4. Fix plan
Four phases, each commit-sized. Phases 1–2 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs.
### Phase 1 — Lock down the data model
1. Drop the `AttnImplementation` enum. `attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) **or** a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name.
2. Rewrite `normalize_attn_implementation` so its only job is mapping **legacy booleans → canonical `attn_implementation`** (for BC). Mapping is fixed:
- `flash_attention → flash_attention_2`
- `sdp_attention → sdpa`
- `flex_attention → flex_attention`
- `xformers_attention → xformers`
- `sage_attention → sage`
- `s2_attention → s2`
- `eager_attention → eager`
Priority for legacy combos stays as in the current branch (`s2 > sage > xformers > flex > flash > sdp > eager`). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence).
**Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method.** Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like `s2 + flash` before the normalizer can map them. One validator, one source of conflict detection.
**Fix the gemma4-hybrid path**: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1).
3. Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels):
- Packing-capable: `{flash_attention_2, flash_attention_3, flex_attention, xformers, sage, kernels-community/flash-attn3, kernels-community/sage-attention}`.
- Flash-lib (axolotl's monkeypatch targets): `{flash_attention_2, flash_attention_3, s2, kernels-community/flash-attn3}`.
- No-dtype-cast: `{eager, sdpa}`.
Unknown strings: conservative defaults (`packing=False, flash_lib=False, dtype_cast=True`).
4. Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF.
5. `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation.
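A condensed sketch of steps 2, 3, and 5, assuming Pydantic v2 (field set, table contents, and class name are abbreviated; only the shape of the normalizer and the computed capability flags is the point):

```python
"""Sketch only: abbreviated fields and capability tables, Pydantic v2 API."""
import warnings
from pydantic import BaseModel, Field, computed_field, model_validator

LEGACY_TO_CANONICAL = {
    "s2_attention": "s2", "sage_attention": "sage", "xformers_attention": "xformers",
    "flex_attention": "flex_attention", "flash_attention": "flash_attention_2",
    "sdp_attention": "sdpa", "eager_attention": "eager",
}  # insertion order doubles as the legacy-combo priority (s2 > sage > ... > eager)
PACKING_CAPABLE = {"flash_attention_2", "flash_attention_3", "flex_attention", "xformers", "sage"}
FLASH_LIB = {"flash_attention_2", "flash_attention_3", "s2"}

class AttnConfigSketch(BaseModel):
    attn_implementation: str | None = None
    flash_attention: bool | None = Field(default=None, deprecated=True)  # input-only alias

    @model_validator(mode="before")
    @classmethod
    def normalize_attn_implementation(cls, data):
        if not isinstance(data, dict):
            return data
        legacy = [flag for flag in LEGACY_TO_CANONICAL if data.get(flag)]
        if legacy and data.get("attn_implementation"):
            raise ValueError("set attn_implementation or a legacy flag, not both")
        for flag in legacy:
            warnings.warn(f"{flag} is deprecated; use attn_implementation", DeprecationWarning)
        if legacy:
            data["attn_implementation"] = LEGACY_TO_CANONICAL[legacy[0]]  # highest priority wins
        for flag in LEGACY_TO_CANONICAL:
            data.pop(flag, None)  # legacy flags never reach the validated cfg
        return data

    @computed_field  # derived, not settable from YAML
    @property
    def attn_supports_packing(self) -> bool:
        return self.attn_implementation in PACKING_CAPABLE

    @computed_field
    @property
    def attn_uses_flash_lib(self) -> bool:
        return self.attn_implementation in FLASH_LIB
```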
### Phase 2 — Fix the validators
6. `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing`. Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend.
7. `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`.
8. Sweep the dual-check idioms (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` check — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs.
9. fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False`.
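Step 9 could be roughly this small (the `cfg` attribute names follow this document; the `torchao.__version__` read and the threshold values are assumptions based on the requirements quoted in item 16):

```python
import warnings
from packaging import version
import torch

def check_fp8_attention_preflight(cfg) -> None:
    """Illustrative fp8 preflight: fail early instead of inside torchao."""
    if cfg.attn_implementation != "fp8":
        return
    major, _ = torch.cuda.get_device_capability()
    if major < 9:
        raise ValueError("fp8 attention requires an SM90+ GPU (Hopper or newer)")
    if version.parse(torch.__version__) < version.parse("2.11"):
        raise ValueError("fp8 attention requires PyTorch >= 2.11")
    import torchao
    if version.parse(torchao.__version__) < version.parse("0.17.0"):
        raise ValueError("fp8 attention requires torchao >= 0.17.0")
    if cfg.use_cache is not False:
        warnings.warn("fp8 attention expects use_cache: false during training")
```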
### Phase 3 — Migrate remaining consumers
10. `lm_eval/cli.py:120` → `flash_attention=cfg.attn_uses_flash_lib`.
11. `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
12. `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
13. `tests/e2e/multigpu/test_llama.py:524-526` → `cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
14. **Xformers decision** (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm/stablelm_epoch+packing/mistral3/llava. Two choices:
- Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keeps patches live).
- Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.
Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
15. Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting.
### Phase 4 — Tests + docs
16. Rewrite `test_attn_implementation.py` to build full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
17. Add one test per gap closed above: `attn_implementation: eager + sample_packing`; `attn_implementation: flex_attention + activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from validated `cfg`; fp8 preflight failures.
18. Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
19. `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (normalizer maps them with deprecation), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
## 5. Ordering & risk
- Phase 1 is the keystone: it's the largest diff but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (the xformers decision in phase 3, item 14, is the explicit decision point).
- Phase 2 is independent of phase 1 and can land first as a quick safety net.
- The xformers decision in phase 3 (item 14) is the only judgment call — flag for review before choosing.
- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertable.

View File

@@ -29,6 +29,9 @@
## 🎉 Latest Updates
- 2026/04:
- New model support has been added in Axolotl for [Mistral Medium 3.5](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral-medium-3_5) and [Gemma 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma4).
- Axolotl is now [uv-first](https://github.com/axolotl-ai-cloud/axolotl/pull/3545) and has [SonicMoE fused LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3519) support.
- 2026/03:
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).

SETUP_MIAAI.md (new file)
View File

@@ -0,0 +1,83 @@
# Axolotl Setup — miaai (RTX 5080, CUDA 13.2)
## System Info
- GPU: NVIDIA RTX 5080 (16GB VRAM)
- Driver: 580.126.09 — max CUDA 13.0 (nvcc from conda resolves to 13.2)
- OS: Ubuntu (Python 3.13 system — do NOT use system Python for ML)
- Axolotl branch: `activeblue/main`
## One-time Setup
### 1. Install Miniconda
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
bash miniconda.sh -b -p /opt/miniconda3
/opt/miniconda3/bin/conda init bash
source ~/.bashrc
```
### 2. Create Python 3.11 environment
```bash
conda create -n axolotl python=3.11 -y
conda activate axolotl
```
### 3. Clone and sync repo with upstream
```bash
git clone https://git.activeblue.net/tocmo0nlord/axolotl.git
cd axolotl
git remote add upstream https://github.com/axolotl-ai-cloud/axolotl.git
git fetch upstream
git rebase upstream/main # keeps activeblue patches on top
git push origin activeblue/main --force-with-lease
```
### 4. Install CUDA toolkit (needed to compile flash-attn)
```bash
conda install -y -c "nvidia/label/cuda-12.8.0" cuda-toolkit
export CUDA_HOME=$CONDA_PREFIX
export PATH=$CUDA_HOME/bin:$PATH
```
### 5. Install PyTorch — use cu132 (matches nvcc from conda)
> NOTE: torchaudio has no cu132 wheel — skip it, not needed for LLM training
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu132
python -c "import torch; print('CUDA:', torch.version.cuda); print('GPU:', torch.cuda.get_device_name(0))"
```
### 6. Install Axolotl
```bash
pip install -e "."
```
> **flash-attn compiles CUDA kernels from source — takes 15–25 min on 10 cores of an i7-14700K.**
> Always set `MAX_JOBS` to the number of available CPU cores to parallelize and speed up compilation:
```bash
MAX_JOBS=10 pip install flash-attn --no-build-isolation
```
## Every Session (after first-time setup)
```bash
export PATH="/opt/miniconda3/bin:$PATH"
conda activate axolotl
export CUDA_HOME=$CONDA_PREFIX
export PATH=$CUDA_HOME/bin:$PATH
cd /home/tocmo0nlord/axolotl
```
## Run Training
```bash
axolotl train human_chat_qlora.yml
```
## Common Pitfalls Encountered
| Problem | Cause | Fix |
|---|---|---|
| `externally-managed-environment` | System Python 3.13 blocks pip | Use conda env, never system pip |
| `No module named torch` (flash-attn) | pip builds in isolated env | Use `--no-build-isolation` |
| `CUDA_HOME not set` | CUDA toolkit not installed | `conda install cuda-toolkit` from nvidia channel |
| `CUDA version mismatch 13.2 vs 12.8` | Conda nvcc is 13.2, torch was cu128 | Reinstall torch with `--index-url .../cu132` |
| `torchaudio` not found for cu132 | No cu132 wheel exists | Skip torchaudio — not needed |
| `src refspec main does not match` | Fork default branch is `activeblue/main` | `git push origin activeblue/main` |
| flash-attn compile is slow | Single-threaded by default | Set `MAX_JOBS=<cpu_count>` before pip install |

View File

@@ -1 +1 @@
0.16.0.dev0
0.16.2.dev0

View File

@@ -311,6 +311,7 @@ website:
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- docs/1_58bit_finetuning.qmd
- docs/optimizations.qmd
- section: "Core Concepts"

View File

@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -1,16 +1,15 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG CUDA_VERSION="12.8.2"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
FROM nvidia/cuda:12.8.2-devel-ubuntu22.04 AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/root/miniforge3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0 12.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
@@ -18,13 +17,13 @@ ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
&& bash Miniforge3-Linux-x86_64.sh -b \
&& rm -f Miniforge3-Linux-x86_64.sh \
&& /root/miniforge3/bin/conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/root/miniforge3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
RUN if [ "$TARGETARCH" = "amd64" ]; then \
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
LINUX_TAG="manylinux_" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
arm64) ARCH_TAG="2_34_aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -0,0 +1,70 @@
---
title: "1.58-bit Finetuning"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
1.58-bit finetuning lets you finetune BitNet models whose prequantized weights are available. In theory any LLM can be fine-tuned in 1.58-bit format, but the performance degradation will be dramatic.
Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready to use for training.
::: {.callout-note}
LoRA is not supported for BitNet models
:::
## Installation
Install the `onebitllms` package before using this feature:
```bash
uv pip install onebitllms
```
Or from source:
```bash
uv pip install git+https://github.com/tiiuae/onebitllms
```
## Supported models
For now, only the `Falcon-E` series of models is supported. Make sure to use their `-prequantized` versions:
```bash
tiiuae/Falcon-E-3B-Base-prequantized
tiiuae/Falcon-E-1B-Base-prequantized
```
In theory, any other model would 'work' but the performance degradation will be huge. This remains an area of exploration.
## Configuration
To enable 1.58-bit finetuning, set the following in your configuration file:
```yaml
base_model: tiiuae/Falcon-E-3B-Base-prequantized # A BitNet-compatible model
use_onebitllms: true
```
::: {.callout-note}
For BitNet models, it is recommended to use a higher learning rate than for standard models (typically on the order of 10x higher).
:::
## Considerations after training
Once your model has been trained with 1.58-bit fine-tuning, you can convert it to ternary format using the `onebitllms` CLI:
```bash
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH
```
After that, you can use supported packages such as `llama.cpp` or Apple's MLX to run the trained model.
## Example Configuration
You can find example configurations in `examples/falcon-e` which contain one configuration for SFT and one configuration for DPO.
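For orientation only — the values below are placeholders, not the shipped `examples/falcon-e` configs — a minimal SFT config might look like:

```yaml
base_model: tiiuae/Falcon-E-1B-Base-prequantized   # BitNet-compatible, prequantized
use_onebitllms: true

datasets:
  - path: tatsu-lab/alpaca                          # placeholder dataset
    type: alpaca

sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 8
num_epochs: 1

optimizer: adamw_torch
learning_rate: 2e-4    # roughly 10x a typical full-precision finetune, per the note above
bf16: true

output_dir: ./outputs/falcon-e-1b-sft
```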

View File

@@ -97,8 +97,10 @@ python setup.py install
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
is true and FA4 is importable.
FA4 is still a pre-release on PyPI, so `--pre` is required:
```bash
pip install flash-attn-4
pip install --pre flash-attn-4
```
Or from source:

View File

@@ -77,8 +77,9 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
```bash
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
#### Remote Hosts
@@ -218,8 +219,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
You will now be in the container. Next, install Axolotl with dev dependencies:
```bash
uv sync --extra flash-attn --extra deepspeed --group dev --group test
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
### Attach To Container

View File

@@ -10,13 +10,16 @@ This section describes the different Docker images that are released by AxolotlA
[Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
:::
### Switch to the `-uv` images
::: {.callout-tip}
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
same format as their non-uv counterparts.
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
the uv image.
:::
## Base
@@ -85,7 +88,7 @@ Tags examples:
- `main-py3.12-cu130-2.10.0`
- `main-latest`
- `main-20260315-py3.11-cu128-2.9.1`
- `0.12.0`
- `0.16.1`
## Cloud

View File

@@ -57,7 +57,7 @@ description: Frequently asked questions
**Q: vLLM is not working with Axolotl**
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.9.0
- PyTorch ≥2.9.1
## Installation {#sec-installation}
@@ -36,9 +36,9 @@ source $HOME/.local/bin/env
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
```{.bash}
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
uv venv
source .venv/bin/activate
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
uv pip install --no-build-isolation axolotl[deepspeed]
```
### Edge/Development Build {#sec-edge-build}
@@ -49,12 +49,11 @@ For the latest features between releases:
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed
uv venv
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]'
```
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
### Docker {#sec-docker}
```{.bash}
@@ -132,11 +131,11 @@ source $HOME/.local/bin/env
# Create a fresh venv (recommended for a clean start)
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
uv venv
source .venv/bin/activate
# Reinstall axolotl
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
uv pip install --no-build-isolation axolotl[deepspeed]
```
## Using pip (Alternative) {#sec-pip}
@@ -151,13 +150,13 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
pip3 install --no-build-isolation axolotl[deepspeed]
```
For editable/development installs:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
pip3 install --no-build-isolation -e '.[deepspeed]'
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -0,0 +1,84 @@
# Multimodal assistant-only loss masking
## Correct placement
```yaml
# Top-level: only train_on_inputs lives here.
train_on_inputs: false
datasets:
- path: data/train.jsonl
type: chat_template
roles_to_train: # per-dataset — this is what the MM scanner reads
- assistant
train_on_eos: turn # per-dataset — same
test_datasets:
- path: data/val.jsonl
type: chat_template
split: train
roles_to_train:
- assistant
train_on_eos: turn
```
## How to verify at runtime
`build_collator` logs the resolved knobs at INFO:
```text
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
```
If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
scanner — check that they are under `datasets[0]`, not at the root.
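A quick way to surface that line on a real run (a minimal sketch; the config path and log file name are placeholders):
```bash
# the collator is built during trainer setup, so the line appears early in the run
axolotl train my_mm_config.yml 2>&1 | tee train.log
grep "MM collator" train.log
```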
Each verified strategy additionally logs its resolved boundary token ids at
strategy init (e.g. `<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]` for
Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
pad and media tokens are masked" one-shot warning instead, it is on the
fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
(below) to activate masking. The strategies currently on this path are
listed in the audit table above under `fallback + warn`.
## Config-based override: `cfg.role_boundaries`
For the "unverified" strategies above, or for custom chat templates that
don't match a built-in strategy's markers, users can declare role boundaries
directly in YAML without subclassing:
```yaml
role_boundaries:
- role: assistant
start: "<|turn>model"
end: "<turn|>"
- role: user
start: "<|turn>user"
end: "<turn|>"
# Optional keys:
# include_start: false # default False
# include_end: true # default True, respects cfg.train_on_eos
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
# end: null # span runs to end of sequence
```
Semantics:
- `start` and `end` are literal strings; axolotl encodes them at strategy init via `tokenizer.encode(..., add_special_tokens=False)` and logs the resolved token-id sequences at INFO level (see the sketch after this list).
- The special value `end: eos_token` is the portable way to express
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
the strategy's built-in declarations wholesale (partial overlays are
intentionally unsupported — they're hard to reason about at review time).
Leaving the field unset *or* setting it to an empty list (`[]`) both mean
"use the strategy's built-ins." Writing `role_boundaries: []` is almost
always a typo or leftover — honoring it literally would produce all-masked
labels and zero gradient, so it is treated the same as unset.
- `cfg.roles_to_train` still governs which declared roles contribute to
loss. You can declare `user` and `assistant` boundaries and set
`roles_to_train: ["assistant"]` to have the scanner correctly identify
user spans as masking boundaries without training on their content.
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
unencodable markers), not silently at loss-compute time.
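As a concrete illustration of the first bullet, here is a minimal sketch of that encoding step. The tokenizer is only a stand-in so the snippet runs ungated; with a chat model whose template actually uses these markers, the printed ids would match the strategy's INFO log.
```python
# Hedged sketch: how a declared role boundary marker becomes a token-id
# sequence at strategy init. "gpt2" is a placeholder tokenizer; the markers
# are the ones from the YAML example above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
for marker in ("<|turn>model", "<turn|>"):
    ids = tokenizer.encode(marker, add_special_tokens=False)
    print(f"{marker!r} -> {ids}")  # the strategy logs the same mapping at INFO
```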

View File

@@ -20,6 +20,8 @@ examples:
title: Arcee AFM
# MistralAI
- name: mistral-medium-3_5
title: Mistral Medium 3.5
- name: ministral3/think
title: Ministral 3 Thinking
- name: ministral3/vision

View File

@@ -15,7 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Run one of the finetuning examples below.

View File

@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,11 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -36,12 +36,7 @@
"id": "msOCO4NRmRLa"
},
"outputs": [],
"source": [
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
"source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
},
{
"cell_type": "markdown",

View File

@@ -15,8 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -26,7 +26,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
@@ -52,7 +51,7 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
scaling_softmax: true
# scaling_softmax: true # needs flex_attention
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

View File

@@ -0,0 +1,93 @@
base_model: axolotl-ai-co/Falcon-E-1.2-3B-Exp-prequantized
output_dir: ./output
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: false
use_scattermoe: false
use_sonicmoe: false
use_onebitllms: true
load_in_8bit: false
load_in_4bit: false
chat_template: tokenizer_default
rl: dpo
datasets:
- path: allenai/Dolci-Think-DPO-7B
split: train
type: chatml.ultra
dataset_prepared_path: ./axolotl_dataset_cache
sequence_len: 8192
trust_remote_code: false
gradient_accumulation_steps: 4 # This can run on 4 GPUs
# Very important to enable gradient accumulation with FSDP
# https://github.com/huggingface/transformers/issues/29425
accelerator_config:
gradient_accumulation_kwargs:
sync_each_batch: True
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 1.0e-5
# adamw hyperparams
adam_beta1: 0.9
adam_beta2: 0.95
bf16: true
tf32: false
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 15.0
loss_watchdog_patience: 3
warmup_steps: 128
evals_per_epoch: 0
save_steps: 500
save_strategy: steps
weight_decay: 0.01
shuffle_merged_datasets: true
experimental_skip_move_to_device: true
fsdp_version: 2
fsdp_config:
offload_params: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
activation_checkpointing: true
# Comment to disable CP
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 1
# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 1
# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)
# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)
special_tokens:
eos_token: <|end_of_text|>
eot_tokens:
- <|im_end|>

View File

@@ -0,0 +1,100 @@
base_model: tiiuae/Falcon-E-3B-Base-prequantized
output_dir: ./output
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: false
use_scattermoe: false
use_sonicmoe: false
use_onebitllms: true
load_in_8bit: false
load_in_4bit: false
chat_template: tokenizer_default
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: ./axolotl_dataset_cache
sequence_len: 32768
trust_remote_code: false
gradient_accumulation_steps: 4 # This can run on 4 GPUs
# Very important to enable gradient accumulation with FSDP
# https://github.com/huggingface/transformers/issues/29425
accelerator_config:
gradient_accumulation_kwargs:
sync_each_batch: True
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5.0e-4
# adamw hyperparams
adam_beta1: 0.9
adam_beta2: 0.95
bf16: true
tf32: false
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 15.0
loss_watchdog_patience: 3
warmup_steps: 128
evals_per_epoch: 0
save_steps: 500
save_strategy: steps
weight_decay: 0.01
sample_packing: true
pad_to_sequence_len: true
shuffle_merged_datasets: true
experimental_skip_move_to_device: true
fsdp_version: 2
fsdp_config:
offload_params: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
activation_checkpointing: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
# Comment to disable CP
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 1
# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 1
# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)
# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)
special_tokens:
eos_token: <|end_of_text|>
eot_tokens:
- <|im_end|>

View File

@@ -9,8 +9,8 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:

View File

@@ -13,8 +13,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Choose one of the configs below for training the 20B model (for the 120B model, see [below](#training-120b))

View File

@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -9,11 +9,11 @@ Tencent released a family of opensource models called HunYuan with varying param
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,8 +13,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -59,7 +59,7 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
scaling_softmax: true
# scaling_softmax: true # needs flex_attention
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -0,0 +1,78 @@
# Finetune Mistral Medium 3.5 with Axolotl
[Mistral Medium 3.5](https://huggingface.co/mistralai/Mistral-Medium-3.5-128B) is a 128B parameter dense multimodal model from MistralAI that unifies instruct, reasoning, and agentic capabilities into a single model.
It shares the `mistral3` architecture (dense, YaRN RoPE, 256k context) with Ministral 3 and supports the same `reasoning_effort` toggle as Mistral Small 4.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. (Text config only) Install [Flash Attention 4](https://docs.axolotl.ai/docs/attention.html#flash-attention-4) on Hopper/Blackwell.
4. Run one of the example configs:
```bash
# text-only
axolotl train examples/mistral-medium-3_5/qlora-text.yml # ~83.1 GiB
# text + vision
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
axolotl train examples/mistral-medium-3_5/qlora-vision.yml # ~80.3 GiB
```
Note: vision training does not currently work with Flash Attention 4.
## Reasoning Effort
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps
Pass it via `chat_template_kwargs` under your dataset config:
```yaml
datasets:
- path: your/dataset
type: chat_template
chat_template_kwargs:
reasoning_effort: high
```
## Thinking Support
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
```yaml
datasets:
- path: your/thinking-dataset
type: chat_template
message_property_mappings:
role: role
content: content
thinking: thinking
chat_template_kwargs:
reasoning_effort: high
```
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
## Tips
- For smaller experiments on the same architecture, see [`examples/ministral3`](../ministral3/README.md) (Ministral 3, 3B).
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Related Resources
- [Mistral Medium 3.5 Blog](https://mistral.ai/news/vibe-remote-agents-mistral-medium-3-5)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,56 @@
base_model: axolotl-ai-co/Mistral-Medium-3.5-128B-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir:
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
# prevents targeting vision layers
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
max_steps: 10
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -0,0 +1,61 @@
base_model: axolotl-ai-co/Mistral-Medium-3.5-128B-BF16
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir:
adapter: qlora
lora_model_dir:
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
max_steps: 10
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl>=0.16.1'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,7 +13,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Install an extra dependency:

View File

@@ -11,8 +11,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Please install the dependencies below.

human_chat_qlora.yml (new file, 92 lines)
View File

@@ -0,0 +1,92 @@
# Llama 3.1 8B — Human-like LoRA fine-tune (HQQ quantization)
#
# Goal: natural, warm conversation; never corrects user errors; direct responses
# Hardware: single RTX 5080 (16 GB VRAM)
# Method: LoRA on HQQ 4-bit quantized base (bypasses bitsandbytes — RTX 5080 compatible)
#
# Prerequisites:
# pip install -e '.[flash-attn]' (inside your axolotl repo)
# huggingface-cli login (meta-llama is a gated model)
#
# Run:
# axolotl train human_chat_qlora.yml
# axolotl merge-lora human_chat_qlora.yml # (optional) merge adapter into base
base_model: meta-llama/Meta-Llama-3.1-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
# HQQ quantization — no bitsandbytes required, works on RTX 5080 (sm_120)
quant_method: hqq
strict: false
trust_remote_code: true
torch_dtype: bfloat16
# --- System prompt baked into every conversation ---
# This is the primary lever for "no error correcting, more human-like"
chat_template: llama3
default_system_message: >-
You are a direct, warm, and genuinely helpful assistant.
Respond to the user's intent naturally — never comment on typos, grammar,
or phrasing issues in their message. Just understand what they mean and give
a clear, useful, conversational answer as if talking to a knowledgeable friend.
# --- Datasets ---
# Both use ShareGPT format: conversations field, from/value keys
datasets:
- path: Open-Orca/SlimOrca
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
split: "train[:3%]"
- path: teknium/OpenHermes-2.5
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
split: "train[:5%]"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/llama31-8b-humanchat
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
# --- LoRA adapter (on top of HQQ quantized base) ---
adapter: lora
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
# --- Training hyperparameters ---
# Effective batch = micro_batch_size x gradient_accumulation = 2 x 4 = 8
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 2
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.05
weight_decay: 0.1
train_on_inputs: false
group_by_length: false
bf16: auto
tf32: false
# --- Memory & speed ---
gradient_checkpointing: true
attn_implementation: flash_attention_2
# --- Logging & checkpointing ---
logging_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
special_tokens:
pad_token: "<|eot_id|>"

View File

@@ -12,7 +12,7 @@ requires-python = ">=3.10"
dependencies = [
# Core ML stack
"torch>=2.6.0",
"torch>=2.9.1",
"packaging==26.0",
"huggingface_hub>=1.1.7",
"peft>=0.19.1,<0.20.0",
@@ -79,7 +79,7 @@ dependencies = [
# Platform-specific (Linux only)
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
"triton>=3.4.0 ; sys_platform != 'darwin'",
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
"xformers>=0.0.33.post2 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",

View File

@@ -286,10 +286,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.relora and self.cfg.jagged_restart_steps:
if self.cfg.relora_prune_ratio:
if self.cfg.relora_prune_ratio is not None:
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.relora_prune_method:
training_arguments_kwargs["relora_prune_method"] = (
self.cfg.relora_prune_method
)
if self.cfg.jagged_restart_steps:
training_arguments_kwargs["jagged_restart_steps"] = (
@@ -515,12 +519,53 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
# Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.
# NOTE: Multi-dataset configs use the first dataset's masking knobs for all datasets;
# heterogeneous per-dataset overrides are not supported in the MM path today.
ds_entries = self.cfg.datasets or []
ds_cfg = ds_entries[0] if ds_entries else None
def _ds_get(cfg_obj, key):
# Handle DictDefault / dict / pydantic uniformly:
# dict-style .get first, then attribute access.
if cfg_obj is None:
return None
if hasattr(cfg_obj, "get"):
try:
return cfg_obj.get(key)
except (AttributeError, KeyError, TypeError):
pass
return getattr(cfg_obj, key, None)
roles_to_train = _ds_get(ds_cfg, "roles_to_train")
train_on_eos = _ds_get(ds_cfg, "train_on_eos")
# cfg.role_boundaries replaces the strategy's built-in markers.
role_boundaries_override = None
if self.cfg.role_boundaries:
role_boundaries_override = list(self.cfg.role_boundaries)
# build() calls build_collator twice (eval + train); log once.
if not is_eval:
LOG.info(
"MM collator: train_on_inputs=%s roles_to_train=%s "
"train_on_eos=%s role_boundaries_override=%s",
bool(self.cfg.train_on_inputs),
roles_to_train,
train_on_eos,
"set" if role_boundaries_override else "none",
)
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
train_on_inputs=bool(self.cfg.train_on_inputs),
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
role_boundaries_override=role_boundaries_override,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import gc
import json
import math
import os
@@ -800,7 +801,14 @@ class AxolotlTrainer(
with open(tokens_state_path, "w", encoding="utf-8") as f:
json.dump(tokens_state, f)
return super()._save_checkpoint(model, trial, **kwargs)
result = super()._save_checkpoint(model, trial, **kwargs)
# Reclaim VRAM held by the FSDP full-state-dict gather.
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return result
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
def _save(self, output_dir: Optional[str] = None, state_dict=None):

View File

@@ -83,13 +83,18 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "The number of processes to use for data processing"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
default=None,
metadata={
"help": (
"prune ratio for optimizer state pruning; "
"defaults to 0.999 for reset method, 0.9 for others"
)
},
)
relora_prune_method: Optional[str] = field(
default=None,
metadata={"help": "optimizer state pruning method: magnitude | random | reset"},
)
jagged_restart_steps: Optional[int] = field(
default=None,

View File

@@ -23,9 +23,10 @@ from __future__ import annotations
import collections
import importlib
import traceback
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
from peft import PeftConfig, PeftMixedModel, PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -41,6 +42,15 @@ if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
@dataclass(frozen=True)
class AdapterCapabilities:
"""Capabilities for an adapter contributed by a plugin."""
name: str
lora_like: bool = False
relora: bool = False
class BasePlugin:
"""Base class for all plugins. Defines the interface for plugin methods.
@@ -91,6 +101,26 @@ class BasePlugin:
Returns a dataclass model for the plugin's training arguments.
"""
def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
"""Returns adapter capabilities contributed by the plugin."""
return []
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
"""Returns extra PEFT LoraConfig kwargs for plugin LoRA-like adapters."""
return {}
def load_adapter(
self,
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> (
tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
| None
):
"""Optionally load a plugin adapter instead of the generic loader."""
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -414,6 +444,58 @@ class PluginManager:
training_args.append(training_args_from_plugin)
return training_args
def adapter_capabilities(self) -> dict[str, AdapterCapabilities]:
"""Returns adapter capabilities by adapter name."""
capabilities = {}
for plugin in self.plugins.values():
for adapter_capability in plugin.get_adapter_capabilities():
capabilities[adapter_capability.name] = adapter_capability
return capabilities
def get_adapter_capability(self, adapter: str) -> AdapterCapabilities | None:
"""Returns capabilities for a registered plugin adapter."""
return self.adapter_capabilities().get(adapter)
def supports_adapter(self, adapter: str) -> bool:
"""Returns whether a plugin has registered the adapter name."""
return adapter in self.adapter_capabilities()
def adapter_supports_relora(self, adapter: str) -> bool:
"""Returns whether a plugin adapter supports ReLoRA restart semantics."""
capability = self.get_adapter_capability(adapter)
return bool(capability and capability.relora)
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
"""Returns extra LoraConfig kwargs from plugins for the configured adapter."""
lora_config_kwargs = {}
for plugin in self.plugins.values():
plugin_kwargs = plugin.get_lora_config_kwargs(cfg)
if plugin_kwargs:
lora_config_kwargs.update(plugin_kwargs)
return lora_config_kwargs
def load_adapter(
self,
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> (
tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
| None
):
"""Returns the first plugin adapter loader result, if any."""
for plugin in self.plugins.values():
loaded = plugin.load_adapter(
model,
cfg,
inference=inference,
config_only=config_only,
)
if loaded is not None:
return loaded
return None
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:

View File

@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0
torch_compile: True # torch>=2.6.0, recommended to reduce vram
torch_compile: True # recommended to reduce vram
datasets:
- path: ...

View File

@@ -0,0 +1,6 @@
"""MoRA / ReMoRA integration for Axolotl."""
from .args import MoraArgs, MoraConfig, MoraType
from .plugin import MoraPlugin
__all__ = ["MoraArgs", "MoraConfig", "MoraPlugin", "MoraType"]

View File

@@ -0,0 +1,66 @@
"""Config args for MoRA / ReMoRA."""
from __future__ import annotations
from enum import Enum
from pydantic import BaseModel, Field, model_validator
class MoraType(str, Enum):
"""MoRA variants supported by the reference implementation."""
SHARING = "sharing"
ROPE = "rope"
@property
def peft_value(self) -> int:
return {
MoraType.SHARING: 1,
MoraType.ROPE: 6,
}[self]
class MoraConfig(BaseModel):
"""Nested MoRA configuration available under the `mora` key."""
use_mora: bool = Field(
default=True,
description=(
"Enable MoRA adapter construction. Requires a PEFT build with MoRA "
"support (for example, the MoRA fork)."
),
)
mora_type: MoraType = Field(
default=MoraType.ROPE,
description=(
"MoRA variant selector. Supported values are `sharing` for type 1 "
"and `rope` for type 6. Numeric values 1 and 6 are accepted for "
"backwards compatibility."
),
)
@model_validator(mode="before")
@classmethod
def normalize_mora_type(cls, data):
if not isinstance(data, dict) or "mora_type" not in data:
return data
data = data.copy()
mora_type = data["mora_type"]
if mora_type == 1:
data["mora_type"] = MoraType.SHARING
elif mora_type == 6:
data["mora_type"] = MoraType.ROPE
return data
class MoraArgs(BaseModel):
"""Plugin entry that exposes the nested `mora` block to the core config."""
mora: MoraConfig = Field(
default_factory=MoraConfig,
description=(
"MoRA / ReMoRA training configuration. Register the "
"`axolotl.integrations.mora.MoraPlugin` plugin to enable this block."
),
)

View File

@@ -0,0 +1,97 @@
"""MoRA / ReMoRA plugin for Axolotl."""
import inspect
from peft import LoraConfig, PeftModel
from transformers import PreTrainedModel
from axolotl.integrations.base import AdapterCapabilities, BasePlugin
from axolotl.integrations.mora.args import MoraType
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _peft_supports_mora() -> bool:
try:
params = inspect.signature(LoraConfig).parameters
except (TypeError, ValueError):
return False
return "use_mora" in params and "mora_type" in params
def _mora_type_peft_value(mora_type: MoraType | str | int) -> int:
if isinstance(mora_type, MoraType):
return mora_type.peft_value
if mora_type == 1 or mora_type == MoraType.SHARING.value:
return MoraType.SHARING.peft_value
if mora_type == 6 or mora_type == MoraType.ROPE.value:
return MoraType.ROPE.peft_value
raise ValueError("mora_type must be one of `sharing`, `rope`, 1, or 6")
def _mora_type_label(mora_type: MoraType | str | int) -> str:
if isinstance(mora_type, MoraType):
return mora_type.value
if mora_type == 1:
return MoraType.SHARING.value
if mora_type == 6:
return MoraType.ROPE.value
return str(mora_type)
class MoraPlugin(BasePlugin):
"""Plugin that exposes MoRA-specific config and validates runtime support."""
def get_input_args(self) -> str:
return "axolotl.integrations.mora.MoraArgs"
def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
return [AdapterCapabilities(name="mora", lora_like=True, relora=True)]
def _validate_mora_config(self, cfg: DictDefault):
mora_cfg = getattr(cfg, "mora", None)
if mora_cfg is None:
raise ValueError("adapter: mora requires a nested mora configuration block")
if not getattr(mora_cfg, "use_mora", False):
raise ValueError("mora.use_mora must be true when adapter: mora is set")
if cfg.load_in_4bit or cfg.load_in_8bit:
raise ValueError(
"adapter: mora currently requires a full-precision base model. "
"Use adapter: lora or qlora for quantized training."
)
if cfg.gptq:
raise ValueError(
"adapter: mora is not compatible with GPTQ quantized base models."
)
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
if cfg.adapter != "mora":
return {}
self._validate_mora_config(cfg)
if not _peft_supports_mora():
raise ImportError(
"adapter: mora requires a PEFT build with MoRA support "
"(LoraConfig(use_mora=..., mora_type=...)). "
"Install the MoRA fork or another PEFT distribution that exposes "
"those fields."
)
mora_cfg = cfg.mora
return {
"use_mora": mora_cfg.use_mora,
"mora_type": _mora_type_peft_value(mora_cfg.mora_type),
}
def pre_model_load(self, cfg: DictDefault):
if cfg.adapter != "mora":
return
LOG.info("MoRA plugin enabled for adapter: mora")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
if cfg.adapter == "mora" and getattr(cfg, "mora", None):
LOG.debug(
"Loaded MoRA model with mora_type=%s, relora=%s",
_mora_type_label(cfg.mora.mora_type),
cfg.relora,
)
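For reference, a minimal config sketch that would exercise this plugin, inferred from the validation and capability registration above (the keys are assumptions drawn from this diff, not tested documentation):
```yaml
# Hypothetical MoRA usage sketch, inferred from MoraPlugin._validate_mora_config.
plugins:
  - axolotl.integrations.mora.MoraPlugin

base_model: meta-llama/Meta-Llama-3.1-8B-Instruct  # full precision; no load_in_4bit/8bit
adapter: mora              # routed to the LoRA-like loader via AdapterCapabilities
mora:
  use_mora: true
  mora_type: rope          # or `sharing`; numeric 1 / 6 accepted for backwards compat

# Standard LoRA knobs still apply because the adapter is registered as lora_like.
lora_r: 64
lora_alpha: 32
lora_target_linear: true
```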

View File

@@ -19,12 +19,14 @@ from peft import (
)
from transformers import PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
def setup_quantized_meta_for_peft(model: torch.nn.Module):
@@ -124,6 +126,76 @@ def _patch_peft_clippable_linear():
LoraModel._axolotl_clippable_patched = True
def _get_peft_task_type(model: PreTrainedModel) -> TaskType:
model_cls = type(model).__name__
if "SequenceClassification" in model_cls:
return TaskType.SEQ_CLS
if "TokenClassification" in model_cls:
return TaskType.TOKEN_CLS
return TaskType.CAUSAL_LM
def _build_lora_config_kwargs(cfg: DictDefault) -> dict[str, Any]:
lora_config_kwargs: dict[str, Any] = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
if cfg.peft_ensure_weight_tying is not None:
lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
return lora_config_kwargs
def _build_peft_lora_config(
model: PreTrainedModel,
cfg: DictDefault,
) -> PeftConfig:
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
lora_config_kwargs = _build_lora_config_kwargs(cfg)
lora_config_kwargs.update(PLUGIN_MANAGER.get_lora_config_kwargs(cfg))
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
target_parameters=lora_target_parameters,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=_get_peft_task_type(model),
**lora_config_kwargs,
)
return lora_config
def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
"""Check whether PEFT will auto-populate target_parameters for this model.
@@ -226,62 +298,7 @@ def load_lora(
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
_patch_peft_clippable_linear()
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
if cfg.peft_ensure_weight_tying is not None:
lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
# Determine the correct PEFT task type
model_cls = type(model).__name__
if "SequenceClassification" in model_cls:
task_type = TaskType.SEQ_CLS
elif "TokenClassification" in model_cls:
task_type = TaskType.TOKEN_CLS
else:
task_type = TaskType.CAUSAL_LM
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
target_parameters=lora_target_parameters,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=task_type,
**lora_config_kwargs,
)
lora_config = _build_peft_lora_config(model, cfg)
if config_only:
return None, lora_config
@@ -315,7 +332,7 @@ def load_lora(
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
LOG.debug("Loading pretrained PEFT adapter")
if cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}
@@ -364,30 +381,60 @@ def load_adapter(
cfg: DictDefault,
adapter: str | None,
inference: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
peft_model, lora_config = load_lora(model, cfg, inference=inference)
peft_model, lora_config = load_lora(
model, cfg, inference=inference, config_only=config_only
)
return peft_model, lora_config
if adapter == "llama-adapter":
if config_only:
_, lora_config = load_llama_adapter(model, cfg, config_only=True)
return None, lora_config
peft_model, lora_config = load_llama_adapter(model, cfg)
return peft_model, lora_config
raise NotImplementedError(f"{adapter} PEFT adapter not available")
plugin_loaded = PLUGIN_MANAGER.load_adapter(
model,
cfg,
inference=inference,
config_only=config_only,
)
if plugin_loaded is not None:
return plugin_loaded
adapter_capability = PLUGIN_MANAGER.get_adapter_capability(adapter)
if adapter_capability and adapter_capability.lora_like:
peft_model, lora_config = load_lora(
model, cfg, inference=inference, config_only=config_only
)
return peft_model, lora_config
registered = sorted(PLUGIN_MANAGER.adapter_capabilities())
registered_msg = ", ".join(registered) if registered else "none"
raise NotImplementedError(
f"Adapter '{adapter}' is not built in and was not registered by a plugin "
f"with loader support. Registered plugin adapters: {registered_msg}"
)
def load_llama_adapter(
model: PreTrainedModel, cfg: DictDefault
) -> tuple[PeftModel | PeftMixedModel, PeftConfig]:
model: PreTrainedModel, cfg: DictDefault, config_only: bool = False
) -> tuple[PeftModel | PeftMixedModel | None, PeftConfig]:
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
adapter_len=cfg.peft_adapter.len, # prompt length (K)
task_type="CAUSAL_LM",
)
if config_only:
return None, peft_config
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - llama_adapter")
peft_model = PeftModel.from_pretrained(

View File

@@ -39,7 +39,7 @@ from transformers.integrations.deepspeed import (
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.loaders.adapter import load_adapter, load_lora
from axolotl.loaders.adapter import load_adapter
from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.patch_manager import PatchManager
from axolotl.loaders.utils import (
@@ -386,8 +386,12 @@ class ModelLoader:
and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
and not self.cfg.merge_lora
):
_, lora_config = load_lora(
self.model, self.cfg, inference=False, config_only=True
_, lora_config = load_adapter(
self.model,
self.cfg,
self.cfg.adapter,
inference=False,
config_only=True,
)
else:
self.model, lora_config = load_adapter(
@@ -628,13 +632,7 @@ class ModelLoader:
)
def _set_attention_config(self):
# s2 and fp8 need a different HF backend at load time than their
# canonical name: s2 patches FA2 internals, so load under FA2; fp8
# replaces F.scaled_dot_product_attention post-load, so load under sdpa.
# Every other canonical name (and hub-kernel paths) is passed through
# verbatim — xformers/sage/flash_attention_* are registered under their
# own names in ALL_ATTENTION_FUNCTIONS before model load. gemma4_hybrid
# is already pinned to flash_attention_2 by normalize_attn_implementation.
# s2 patches FA2 internals (load as FA2); fp8 replaces sdpa post-load (load as sdpa).
_LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"}
if self.cfg.attn_implementation:
hf_impl = _LOAD_TIME_OVERRIDE.get(
@@ -826,6 +824,17 @@ class ModelLoader:
else:
self.model = self._load_model_from_pretrained(model_loader_class)
if self.cfg.use_onebitllms:
try:
from onebitllms import replace_linear_with_bitnet_linear
except ImportError as exc:
raise ImportError(
"The 'onebitllms' package is required for use_onebitllms. "
"Install it with: `uv pip install onebitllms`"
) from exc
self.model = replace_linear_with_bitnet_linear(self.model)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True

View File

@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio
import copy
import functools
import gc
import os
import sys
@@ -161,6 +162,7 @@ def get_state_dict(self, model, unwrap=True):
state_dict = {}
sharded_state_dict = model.state_dict()
is_rank_zero = torch.distributed.get_rank() == 0
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
@@ -168,9 +170,20 @@ def get_state_dict(self, model, unwrap=True):
if isinstance(param, DTensor):
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
if is_rank_zero:
state_dict[param_name] = param.cpu()
# Drop the GPU-resident gathered tensor before the next iteration
# allocates the next one; otherwise the caching allocator holds
# both reservations and we accumulate ~model-size of VRAM.
del param
torch.distributed.barrier()
# Release the sharded view and force the allocator to give back the
# gather buffers.
del sharded_state_dict
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import (
FullStateDictConfig,

View File

@@ -6,9 +6,8 @@ import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Literal, Union
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
@@ -28,9 +27,15 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
import bitsandbytes as bnb
except ImportError: # pragma: no cover - optional dependency for 8-bit merge paths
bnb = None
@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
"""Zero the lowest ``prune_ratio`` fraction of values by absolute magnitude, in place."""
tensor_magnitude = torch.abs(tensor)
threshold = torch.quantile(
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
@@ -40,15 +45,43 @@ def magnitude_pruning_(tensor, prune_ratio):
tensor.mul_(mask.to(dtype=tensor.dtype))
@torch.no_grad()
def random_pruning_(tensor, prune_ratio):
"""Zero a random ``prune_ratio`` fraction of values, in place."""
mask = (
torch.rand(tensor.shape, dtype=torch.float32, device=tensor.device)
> prune_ratio
)
tensor.mul_(mask.to(dtype=tensor.dtype))
# 0.999 mirrors the reference implementation. True zeroing breaks
# ZeroRedundancyOptimizer.consolidate_state_dict; see Guitaricet/relora's
# peft_pretraining/training_utils.py for the original note on this.
_FULL_RESET_RATIO = 0.999
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
reset_params: List[torch.nn.Parameter],
optimizer_state_keys: List[str],
optimizer_magnitude_pruning: float = 0.9,
prune_method: Literal["magnitude", "random", "reset"] = "magnitude",
prune_ratio: float = 0.9,
):
# pylint:disable=unused-argument
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
"""Prune optimizer state for ``reset_params`` only."""
if prune_method == "magnitude":
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
elif prune_method in ("random", "reset"):
# "reset" is random pruning at a near-full ratio; the caller is responsible
# for supplying the appropriate prune_ratio (see ReLoRACallback.on_step_begin).
pruning_fn = partial(random_pruning_, prune_ratio=prune_ratio)
else:
raise ValueError(
f"Unknown prune_method {prune_method!r}; expected one of "
"'magnitude', 'random', 'reset'"
)
n_zeros = 0
n_total = 0
@@ -56,22 +89,22 @@ def reset_optimizer(
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state
for group in optimizer.param_groups:
for param in group["params"]:
state = optimizer_state[param]
for key, value in state.items():
if key not in optimizer_state_keys:
for param in reset_params:
state = optimizer_state.get(param, {})
if not state:
continue
for key in optimizer_state_keys:
value = state.get(key)
if value is None or not torch.is_tensor(value):
continue
try:
pruning_fn(value)
n_total += value.numel()
n_zeros += torch.sum(value == 0).item()
except RuntimeError as exc:
if "quantile() input tensor is too large" in str(exc):
continue
if torch.is_tensor(value):
try:
pruning_fn(value)
n_total += value.numel()
n_zeros += torch.sum(value == 0).item()
except RuntimeError as exc:
if "quantile() input tensor is too large" in str(exc):
pass
else:
raise exc
raise
_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -82,11 +115,12 @@ class ReLoRACallback(TrainerCallback):
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.jagged_restart_steps
self.jagged_restart_steps = cfg.jagged_restart_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
self.resume_from_checkpoint = cfg.resume_from_checkpoint
self.prune_method = cfg.relora_prune_method or "magnitude"
if not os.path.exists(self.last_full_model):
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
@@ -128,7 +162,7 @@ class ReLoRACallback(TrainerCallback):
):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
if state.global_step > 0 and state.global_step % self.jagged_restart_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
@@ -144,7 +178,7 @@ class ReLoRACallback(TrainerCallback):
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
lora_params = [
n
p
for n, p in model.named_parameters()
if p.requires_grad and "lora_" in n
]
@@ -166,11 +200,19 @@ class ReLoRACallback(TrainerCallback):
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
# When relora_prune_ratio is not set, use _FULL_RESET_RATIO for
# "reset" (paper-style near-full reset) and 0.9 for other methods.
prune_ratio = args.relora_prune_ratio
if prune_ratio is None:
prune_ratio = (
_FULL_RESET_RATIO if self.prune_method == "reset" else 0.9
)
reset_optimizer(
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
optimizer_magnitude_pruning=args.relora_prune_ratio,
prune_method=self.prune_method,
prune_ratio=prune_ratio,
)
if self.quantized:
@@ -191,8 +233,8 @@ class ReLoRACallback(TrainerCallback):
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
)
if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
state.global_step >= self.jagged_restart_steps
and state.global_step % self.jagged_restart_steps != 0
):
if self.quantized:
if is_main_process() and self.last_full_model != checkpoint_folder:
@@ -320,6 +362,8 @@ def update_weights(
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
if bnb is None:
raise ImportError("bitsandbytes is required to merge 8-bit LoRA weights")
target.weight.data = (
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
)

File diff suppressed because it is too large

View File

@@ -285,7 +285,9 @@ def save_trained_model(
)
# Handle ReLoRA early return case
if cfg.relora:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
if hasattr(model, "merge_and_unload") and not (
cfg.load_in_4bit or cfg.load_in_8bit
):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`

View File

@@ -43,14 +43,16 @@ class MultiModalChatDataCollator(DataCollatorMixin):
# Initialize batch
messages = [ex["messages"] for ex in examples]
batch = self.processing_strategy.processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
batch = dict(
self.processing_strategy.processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
)
# Process the labels

View File

@@ -1044,7 +1044,7 @@ class AxolotlInputConfig(
torch_compile: Literal["auto"] | bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
"description": "Whether to use torch.compile and which backend to use."
},
)
torch_compile_backend: str | None = Field(
@@ -1397,16 +1397,7 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def normalize_attn_implementation(cls, data):
"""Map legacy boolean attention flags to the canonical `attn_implementation`.
`attn_implementation` is the single source of truth on the validated
config. Legacy booleans (`flash_attention: true`, …) are input-only
aliases; this validator warns, maps them to their canonical value, and
strips them from `data` so they cannot be read downstream.
Raises if a canonical `attn_implementation` is set alongside any legacy
boolean — users must pick one.
"""
"""Map legacy boolean attention flags to canonical attn_implementation, warn, then strip."""
if not isinstance(data, dict):
return data

View File

@@ -166,10 +166,10 @@ class SFTDataset(BaseModel):
"description": "Roles to train on. The tokens from these roles will be considered for the loss."
},
)
train_on_eos: Literal["all", "turn", "last"] | None = Field(
train_on_eos: Literal["all", "turn", "last", "none"] | None = Field(
default=None,
json_schema_extra={
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation, none: never train on EOS tokens"
},
)
roles: dict[str, list[str]] | None = Field(

View File

@@ -97,12 +97,7 @@ class CustomSupportedOptimizers(str, Enum):
flash_lion = "flash_lion"
# Canonical values accepted for `attn_implementation`. These are passed to HF
# verbatim via `model.config._attn_implementation`. HF-native backends use HF's
# own names (`flash_attention_2`, `flex_attention`, ...); axolotl-owned backends
# (`xformers`, `sage`, `s2`, `fp8`) register into `ALL_ATTENTION_FUNCTIONS` under
# these exact names. Hub-kernel paths (e.g. `kernels-community/flash-attn3`) are
# not in this set — they pass through the validator via the "/" check.
# Accepted canonical names; hub-kernel paths (containing "/") bypass this set.
CANONICAL_ATTN_IMPLS = frozenset(
{
"eager",
@@ -117,10 +112,7 @@ CANONICAL_ATTN_IMPLS = frozenset(
}
)
# Legacy boolean attention flags → canonical `attn_implementation`. Kept for
# backwards compatibility; the normalizer warns and strips these from the
# validated config. Priority order (first match wins) matches the old priority:
# specific backends beat the generic flash/sdp/eager fallbacks.
# Legacy boolean flags → canonical attn_implementation. Priority: specific before generic.
LEGACY_ATTN_FLAG_TO_IMPL = {
"xformers_attention": "xformers",
"s2_attention": "s2",
@@ -131,9 +123,7 @@ LEGACY_ATTN_FLAG_TO_IMPL = {
"eager_attention": "eager",
}
# Short-form aliases that were accepted by the in-progress branch but are
# rejected going forward. Mapped to canonical names only to produce a helpful
# error message pointing users at the right value.
# Short-form aliases rejected at validation; mapped to canonical names for error messages.
SHORT_FORM_ALIAS_TO_CANONICAL = {
"flash": "flash_attention_2",
"flex": "flex_attention",
@@ -148,18 +138,19 @@ ATTN_IMPLS_SUPPORTING_PACKING = frozenset(
"flex_attention",
"xformers",
"sage",
"kernels-community/flash-attn2",
"kernels-community/flash-attn3",
"kernels-community/sage-attention",
}
)
# Backends that require the flash_attn library (Dao-AILab/flash-attention) for
# axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, ring-FA, ...).
# Backends that require the flash_attn library for axolotl's own monkeypatches.
ATTN_IMPLS_USING_FLASH_LIB = frozenset(
{
"flash_attention_2",
"flash_attention_3",
"s2",
"kernels-community/flash-attn2",
"kernels-community/flash-attn3",
}
)

View File

@@ -103,6 +103,12 @@ class ModelInputConfig(BaseModel):
default=None,
json_schema_extra={"description": "kwargs for model quantization config"},
)
use_onebitllms: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use `onebitllms` for 1.58bit training (only for bitnet models)."
},
)
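A short, hedged sketch of how this flag combines with the adapter setting; the model id is a placeholder, and the constraint in the comment comes from the `check_onebitllms_lora` validator that appears later in this diff.

# Placeholder model id; `use_onebitllms` is only meant for bitnet models and
# is rejected together with adapter='lora'/'qlora' by check_onebitllms_lora.
cfg_fragment = {
    "base_model": "example-org/bitnet-1.58b-model",  # hypothetical name
    "use_onebitllms": True,
    "adapter": None,  # setting 'lora' or 'qlora' here would fail validation
}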
@field_validator("trust_remote_code")
@classmethod

View File

@@ -6,6 +6,57 @@ from PIL.Image import Resampling
from pydantic import BaseModel, Field, field_validator
class RoleBoundarySpec(BaseModel):
"""One ``cfg.role_boundaries`` row; see docs/multimodal_assistant_mask.md."""
role: str = Field(
json_schema_extra={
"description": (
"Role name as it appears in cfg.roles_to_train (e.g. "
"'assistant', 'user', 'system', 'tool', 'ipython')."
)
},
)
start: str = Field(
json_schema_extra={
"description": (
"Literal string that marks the start of this role's span in "
"the rendered chat template. Tokenized via "
"``tokenizer.encode(..., add_special_tokens=False)`` at "
"strategy init."
)
},
)
end: str | None = Field(
default=None,
json_schema_extra={
"description": (
"Literal string that marks the end of this role's span. "
"Set to ``eos_token`` to terminate at the tokenizer's EOS. "
"Leave unset / null to terminate at end-of-sequence."
)
},
)
include_start: bool = Field(
default=False,
json_schema_extra={
"description": (
"Whether the start marker tokens contribute to loss on "
"trainable turns. Default False."
)
},
)
include_end: bool = Field(
default=True,
json_schema_extra={
"description": (
"Whether the end marker tokens contribute to loss on "
"trainable turns (honoring cfg.train_on_eos). Default True."
)
},
)
class MultiModalConfig(BaseModel):
"""Multi-modal configuration subset"""
@@ -26,6 +77,17 @@ class MultiModalConfig(BaseModel):
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
},
)
role_boundaries: list[RoleBoundarySpec] | None = Field(
default=None,
json_schema_extra={
"description": (
"Opt-in override for the MM mask scanner's per-role boundary "
"markers. Non-empty list replaces built-ins wholesale; unset "
"or empty falls back to built-ins. See "
"docs/multimodal_assistant_mask.md."
)
},
)
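To make the new override concrete, here is a hedged sketch of a `role_boundaries` block. The Llama-3-style header strings are purely illustrative and must match whatever the model's chat template actually renders; a non-empty list replaces the built-in boundaries wholesale.

# Illustrative only: marker strings must match the rendered chat template.
cfg_fragment = {
    "roles_to_train": ["assistant"],
    "role_boundaries": [
        {
            "role": "assistant",
            "start": "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "end": "eos_token",      # terminate the span at the tokenizer's EOS
            "include_start": False,  # header tokens stay masked out of the loss
            "include_end": True,     # EOS handling still honors train_on_eos
        },
        {
            "role": "user",
            "start": "<|start_header_id|>user<|end_header_id|>\n\n",
            "end": "eos_token",
        },
    ],
}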
@field_validator("image_resize_algorithm", mode="before")
@classmethod

View File

@@ -38,10 +38,10 @@ class LoraConfig(BaseModel):
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
)
adapter: Literal["lora", "qlora", "llama-adapter"] | None = Field(
adapter: str | None = Field(
default=None,
json_schema_extra={
"description": "If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model"
"description": "If you want to use a built-in or plugin adapter, or leave blank to train all parameters in original model"
},
)
lora_model_dir: str | None = Field(
@@ -174,6 +174,16 @@ class LoraConfig(BaseModel):
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
adapter = data.get("adapter")
if adapter and adapter not in ("lora", "qlora", "llama-adapter"):
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
if not plugin_manager.supports_adapter(adapter):
raise ValueError(
f"Adapter '{adapter}' is not built in and was not registered by "
"a plugin. Add the plugin that provides this adapter to `plugins:`."
)
return data
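For context, a hypothetical config fragment that would pass this check; the plugin module path and the 'mora' adapter name are placeholders rather than real packages. The only point is that a non-built-in adapter has to be registered by an entry under `plugins:`.

# Hypothetical plugin path and adapter name, shown only to illustrate the rule.
cfg_fragment = {
    "base_model": "axolotl-ai-co/tiny-llama-50m",
    "adapter": "mora",  # not one of lora / qlora / llama-adapter
    "plugins": [
        "axolotl.integrations.mora.MoRAPlugin",  # assumed plugin registering 'mora'
    ],
}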
@model_validator(mode="after")
@@ -240,8 +250,28 @@ class ReLoRAConfig(BaseModel):
)
relora_prune_ratio: float | None = Field(
default=None,
ge=0.0,
le=1.0,
json_schema_extra={
"description": "threshold for optimizer magnitude when pruning"
"description": (
"Fraction of optimizer state values to zero on each ReLoRA restart. "
"When relora_prune_method='reset' and this is omitted, defaults to "
"0.999 (paper-style near-full reset). For other methods, defaults to 0.9."
)
},
)
relora_prune_method: Literal["magnitude", "random", "reset"] | None = Field(
default="magnitude",
json_schema_extra={
"description": (
"Optimizer state pruning method on each ReLoRA restart. "
"'magnitude' (default) keeps top-k by absolute value; "
"'random' keeps a random subset at relora_prune_ratio; "
"'reset' uses near-full random pruning (default ratio 0.999, "
"honoring relora_prune_ratio when explicitly set). "
"Paper-style recipe: relora_prune_method='reset' with no "
"relora_prune_ratio, equivalent to 'random' with ratio=0.999."
)
},
)
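Two hedged config fragments summarizing the behavior spelled out in the descriptions above; `jagged_restart_steps` is the restart cadence used by the e2e tests later in this diff.

# Paper-style recipe: omit relora_prune_ratio and 'reset' defaults to 0.999.
paper_style = {
    "relora": True,
    "jagged_restart_steps": 50,
    "relora_prune_method": "reset",
}

# Explicit ratio: 'reset' now honors the value instead of silently ignoring it.
explicit_ratio = {
    "relora": True,
    "jagged_restart_steps": 50,
    "relora_prune_method": "reset",
    "relora_prune_ratio": 0.5,  # zero half of the optimizer state at each restart
}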
relora_cpu_offload: bool | None = Field(

View File

@@ -183,9 +183,6 @@ class DatasetValidationMixin:
class AttentionValidationMixin:
"""Validation methods related to attention mechanisms."""
# `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
# is now the single entry point for attention-input mapping and conflict detection.
@model_validator(mode="after")
def check_sample_packing_without_attention(self):
if self.sample_packing and not self.attn_supports_packing:
@@ -629,6 +626,12 @@ class LoRAValidationMixin:
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self
@model_validator(mode="after")
def check_onebitllms_lora(self):
if self.use_onebitllms and self.adapter in ["lora", "qlora"]:
raise ValueError("LoRA/QLoRA is not supported with use_onebitllms")
return self
@model_validator(mode="before")
@classmethod
def warn_qlora_zero3_w_use_reentrant(cls, data):
@@ -1475,8 +1478,19 @@ class ComplexValidationMixin:
if self.relora:
if not self.jagged_restart_steps:
raise ValueError("jagged_restart_steps must be set to use ReLoRA")
if self.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
adapter_supports_relora = self.adapter in ("lora", "qlora")
if self.adapter and not adapter_supports_relora:
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
adapter_supports_relora = plugin_manager.adapter_supports_relora(
self.adapter
)
if not adapter_supports_relora:
raise ValueError(
"cfg.adapter must support ReLoRA to use ReLoRA restart semantics"
)
if self.fsdp or self.fsdp_config:
raise ValueError("fsdp not supported with ReLoRA")

View File

@@ -119,15 +119,49 @@ def download_smollm2_135m_gptq_model():
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
def download_qwen3_half_billion_model():
# download the model (still used as the KD teacher in tests/e2e/integrations/test_kd.py)
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_qwen3_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
def download_tiny_llama_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-llama-50m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_mistral_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-mistral-25m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_mixtral_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-mixtral-30m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_phi_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-phi-64m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_falcon_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-falcon-42m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_qwen2_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-qwen2-129m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_qwen3_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-qwen3-129m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tiny_gemma2_model():
snapshot_download_w_retry("axolotl-ai-co/tiny-gemma2-137m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
@@ -620,7 +654,15 @@ def fixture_min_base_cfg():
)
def test_load_fixtures(
download_smollm2_135m_model,
download_qwen_2_5_half_billion_model,
download_qwen3_half_billion_model,
download_tiny_llama_model,
download_tiny_mistral_model,
download_tiny_mixtral_model,
download_tiny_phi_model,
download_tiny_falcon_model,
download_tiny_qwen2_model,
download_tiny_qwen3_model,
download_tiny_gemma2_model,
download_tatsu_lab_alpaca_dataset,
download_mhenrichsen_alpaca_2k_dataset,
download_mhenrichsen_alpaca_2k_w_revision_dataset,

View File

@@ -10,7 +10,10 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
from tests.e2e.utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
)
@pytest.fixture()
@@ -35,13 +38,16 @@ def min_cfg(temp_dir):
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"learning_rate": 5e-4,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"max_steps": 10,
"max_steps": 40,
"warmup_steps": 5,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
@@ -64,11 +70,18 @@ class TestCutCrossEntropyIntegration:
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=2.2,
max_final=2.0,
)
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
@@ -87,13 +100,15 @@ class TestCutCrossEntropyIntegration:
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"max_steps": 10,
"max_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -108,6 +123,13 @@ class TestCutCrossEntropyIntegration:
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@pytest.mark.parametrize(
"attention_type",
@@ -136,3 +158,10 @@ class TestCutCrossEntropyIntegration:
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=2.2,
max_final=2.0,
)

View File

@@ -24,7 +24,7 @@ from axolotl.monkeypatch.lora_kernels import (
)
from axolotl.utils.dict import DictDefault
MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"
DEVICE = "cuda"
DTYPE = torch.bfloat16

View File

@@ -1,23 +1,22 @@
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
import os
from pathlib import Path
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
"""Verify that training completed successfully artifacts, no-NaN, loss
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
"""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
@@ -30,19 +29,13 @@ def verify_training_success(temp_dir):
"No checkpoint files found - training may have failed"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)
class TestDistMuon:
@@ -52,7 +45,7 @@ class TestDistMuon:
def test_fft_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -63,11 +56,12 @@ class TestDistMuon:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.02,
"learning_rate": 2e-3,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
@@ -82,6 +76,9 @@ class TestDistMuon:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -109,7 +106,7 @@ class TestDistMuon:
def test_lora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -122,14 +119,15 @@ class TestDistMuon:
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_dropout": 0.0,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.02,
"learning_rate": 2e-3,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
@@ -144,6 +142,9 @@ class TestDistMuon:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)

View File

@@ -1,24 +1,23 @@
"""Test module for FSDP1 multi-GPU functionality."""
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir
from tests.e2e.utils import check_tensorboard_loss_decreased
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
"""Verify that training completed successfully artifacts, no-NaN, loss
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
"""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
@@ -31,19 +30,13 @@ def verify_training_success(temp_dir):
"No checkpoint files found - training may have failed"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)
class TestFSDP1:
@@ -56,7 +49,7 @@ class TestFSDP1:
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -67,11 +60,12 @@ class TestFSDP1:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -87,6 +81,9 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -126,7 +123,7 @@ class TestFSDP1:
def test_lora_sft(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -140,14 +137,15 @@ class TestFSDP1:
"load_in_4bit": adapter_config["load_in_4bit"],
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_dropout": 0.0,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-3,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -163,6 +161,9 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -190,7 +191,7 @@ class TestFSDP1:
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
@@ -203,11 +204,11 @@ class TestFSDP1:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -223,6 +224,9 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
}
)
@@ -262,7 +266,7 @@ class TestFSDP1:
def test_dpo_lora(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"load_in_4bit": adapter_config["load_in_4bit"],
"rl": "dpo",
"chat_template": "chatml",
@@ -281,11 +285,11 @@ class TestFSDP1:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-3,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -301,6 +305,9 @@ class TestFSDP1:
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": "auto",
"tf32": True,
}

View File

@@ -1,24 +1,23 @@
"""Test module for FSDP2 multi-GPU functionality."""
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
"""Verify that training completed successfully artifacts, no-NaN, loss
stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
"""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
@@ -31,19 +30,13 @@ def verify_training_success(temp_dir):
"No checkpoint files found - training may have failed"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)
class TestFSDP2:
@@ -57,7 +50,7 @@ class TestFSDP2:
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -68,11 +61,12 @@ class TestFSDP2:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -86,6 +80,9 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -114,7 +111,7 @@ class TestFSDP2:
def test_lora_sft(self, temp_dir, peft_use_dora):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -128,14 +125,15 @@ class TestFSDP2:
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_dropout": 0.0,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-3,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -149,6 +147,9 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
# explicitly disable LORA kernels, as they may be auto-enabled
"lora_mlp_kernel": False,
@@ -180,7 +181,7 @@ class TestFSDP2:
def test_lora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -195,11 +196,12 @@ class TestFSDP2:
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-3,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -213,6 +215,9 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
@@ -243,7 +248,7 @@ class TestFSDP2:
def test_qlora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -257,14 +262,15 @@ class TestFSDP2:
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_dropout": 0.0,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-3,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -278,6 +284,9 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
}
)
@@ -305,7 +314,7 @@ class TestFSDP2:
def test_qlora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -321,11 +330,12 @@ class TestFSDP2:
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"max_steps": 80,
"warmup_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-3,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -339,6 +349,9 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
@@ -370,7 +383,7 @@ class TestFSDP2:
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
@@ -383,11 +396,11 @@ class TestFSDP2:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -401,6 +414,9 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
}
)
@@ -428,7 +444,7 @@ class TestFSDP2:
def test_dpo_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"rl": "dpo",
"chat_template": "chatml",
@@ -445,11 +461,11 @@ class TestFSDP2:
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"max_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-3,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
@@ -463,6 +479,9 @@ class TestFSDP2:
"reshard_after_forward": True,
},
"use_tensorboard": True,
"seed": 42,
"sample_packing": True,
"pad_to_sequence_len": True,
}
)

View File

@@ -40,7 +40,7 @@ def _run_training(temp_dir, cfg):
def _base_lora_fsdp2_config(temp_dir, **overrides):
"""Base config for LoRA + FSDP2 + kernel tests."""
cfg = {
"base_model": "Qwen/Qwen3-0.6B",
"base_model": "axolotl-ai-co/tiny-qwen3-129m",
"sequence_len": 512,
"val_set_size": 0.0,
"datasets": [

View File

@@ -8,7 +8,7 @@ from accelerate.test_utils import execute_subprocess_async, get_torch_dist_uniqu
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0
class TestTensorParallel:
@@ -21,7 +21,7 @@ class TestTensorParallel:
def test_fft_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
@@ -63,6 +63,6 @@ class TestTensorParallel:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
check_tensorboard_loss_decreased(
temp_dir + "/runs", max_initial=5.0, max_final=4.7
)

View File

@@ -32,12 +32,12 @@ from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"name": "axolotl-ai-co/tiny-mistral-25m",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"name": "axolotl-ai-co/tiny-qwen2-129m",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
@@ -47,7 +47,7 @@ MODEL_CONFIGS = [
"dtype": torch.float32,
},
{
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
"name": "axolotl-ai-co/tiny-gemma2-137m",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
@@ -159,7 +159,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Gemma2ForCausalLM",
"axolotl-ai-co/tiny-gemma2-137m",
dtype=torch.float16,
device_map="cuda:0",
)

View File

@@ -4,14 +4,16 @@ E2E tests for falcon
import unittest
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
class TestFalconPatched(unittest.TestCase):
@@ -19,13 +21,12 @@ class TestFalconPatched(unittest.TestCase):
Test case for Falcon models
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_qlora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
@@ -47,17 +48,20 @@ class TestFalconPatched(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -66,14 +70,20 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.05,
@@ -88,17 +98,20 @@ class TestFalconPatched(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -107,3 +120,10 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)

View File

@@ -9,7 +9,12 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
require_torch_2_6_0,
with_temp_dir,
)
class TestMistral(unittest.TestCase):
@@ -22,7 +27,7 @@ class TestMistral(unittest.TestCase):
def test_lora_packing(self, temp_dir):
cfg = DictDefault(
{
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -45,17 +50,20 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -64,12 +72,19 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.5,
max_final=4.3,
)
@with_temp_dir
def test_ft_packing(self, temp_dir):
cfg = DictDefault(
{
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -86,17 +101,20 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -105,3 +123,10 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.5,
max_final=4.3,
)

View File

@@ -9,7 +9,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
class TestMixtral(unittest.TestCase):
@@ -21,8 +25,7 @@ class TestMixtral(unittest.TestCase):
def test_qlora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -30,7 +33,7 @@ class TestMixtral(unittest.TestCase):
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_dropout": 0.0,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {},
@@ -41,17 +44,21 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 3e-3,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 80,
"eval_steps": 80,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -60,13 +67,19 @@ class TestMixtral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=6.0,
max_final=4.7,
)
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -79,17 +92,21 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"learning_rate": 5e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 80,
"eval_steps": 80,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -98,3 +115,10 @@ class TestMixtral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)

View File

@@ -22,8 +22,7 @@ class TestModelPatches(unittest.TestCase):
def test_mixtral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -57,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
def test_mistral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -9,7 +9,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
class TestPhiMultipack(unittest.TestCase):
@@ -21,7 +25,7 @@ class TestPhiMultipack(unittest.TestCase):
def test_ft_packed(self, temp_dir):
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model": "axolotl-ai-co/tiny-phi-64m",
"model_type": "PhiForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
@@ -43,17 +47,20 @@ class TestPhiMultipack(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"max_steps": 50,
"logging_steps": 1,
"eval_steps": 50,
"save_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -63,12 +70,19 @@ class TestPhiMultipack(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)
@with_temp_dir
def test_qlora_packed(self, temp_dir):
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model": "axolotl-ai-co/tiny-phi-64m",
"model_type": "PhiForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
@@ -94,17 +108,20 @@ class TestPhiMultipack(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"max_steps": 50,
"logging_steps": 1,
"eval_steps": 50,
"save_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -114,3 +131,10 @@ class TestPhiMultipack(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=6.0,
max_final=4.7,
)

View File

@@ -18,7 +18,7 @@ from transformers import AutoModelForCausalLM
# Import the actual trainer methods we want to test
from axolotl.core.trainers.grpo.async_trainer import AsyncGRPOTrainer
MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"
def _fix_patched_attention(model):

View File

@@ -56,7 +56,72 @@ class TestReLoraLlama(unittest.TestCase):
],
"warmup_steps": 10,
"num_epochs": 2,
"max_steps": 105, # at least 2x relora_steps
"max_steps": 105, # at least 2x restart cadence
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
"Relora model checkpoint not found"
)
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
)
@with_temp_dir
def test_relora_reset_method(self, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora": True,
"jagged_restart_steps": 50,
"jagged_restart_warmup_steps": 10,
"jagged_restart_anneal_steps": 10,
"relora_prune_ratio": 0.5, # explicitly honored by reset (not ignored)
"relora_prune_method": "reset",
"relora_cpu_offload": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"warmup_steps": 10,
"num_epochs": 2,
"max_steps": 105,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,

View File

@@ -4,14 +4,16 @@ E2E tests for falcon
import unittest
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
class TestFalcon(unittest.TestCase):
@@ -19,13 +21,12 @@ class TestFalcon(unittest.TestCase):
Test case for falcon
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -49,17 +50,21 @@ class TestFalcon(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -69,14 +74,20 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -104,17 +115,21 @@ class TestFalcon(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -124,14 +139,20 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "illuin/tiny-random-FalconForCausalLM",
"flash_attention": True,
"base_model": "axolotl-ai-co/tiny-falcon-42m",
"flash_attention": False,
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
@@ -145,17 +166,23 @@ class TestFalcon(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"sample_packing": True,
"pad_to_sequence_len": True,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 5e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 80,
"eval_steps": 80,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -165,3 +192,10 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=5.0,
max_final=4.7,
)

View File

@@ -11,7 +11,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
class TestMistral(unittest.TestCase):
@@ -23,7 +27,7 @@ class TestMistral(unittest.TestCase):
def test_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -45,16 +49,18 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -64,12 +70,19 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=4.5,
max_final=4.3,
)
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"base_model": "axolotl-ai-co/tiny-mistral-25m",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,
@@ -85,16 +98,18 @@ class TestMistral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"save_first_step": False,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
@@ -108,3 +123,10 @@ class TestMistral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=4.5,
max_final=4.3,
)

View File

@@ -12,7 +12,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
class TestMixtral(unittest.TestCase):
@@ -24,8 +28,7 @@ class TestMixtral(unittest.TestCase):
def test_qlora_w_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": True,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -51,16 +54,18 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -74,13 +79,19 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_qlora_wo_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": False,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -106,16 +117,18 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -129,13 +142,19 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_16bit_lora_w_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": True,
"sequence_len": 1024,
"adapter": "lora",
@@ -160,16 +179,18 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"save_first_step": False,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
@@ -187,13 +208,19 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_16bit_lora_wo_fa2(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": False,
"sequence_len": 1024,
"adapter": "lora",
@@ -218,16 +245,18 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"save_first_step": False,
"use_tensorboard": True,
}
)
@@ -245,13 +274,19 @@ class TestMixtral(unittest.TestCase):
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"base_model": "axolotl-ai-co/tiny-mixtral-30m",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,
@@ -263,16 +298,18 @@ class TestMixtral(unittest.TestCase):
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"save_first_step": False,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
@@ -286,3 +323,10 @@ class TestMixtral(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)

View File

@@ -13,6 +13,7 @@ from axolotl.utils.dict import DictDefault
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
@@ -243,20 +244,18 @@ class TestCustomOptimizers(unittest.TestCase):
def test_came_pytorch(self, temp_dir):
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "axolotl-ai-co/tiny-llama-50m",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_dropout": 0.0,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -265,16 +264,22 @@ class TestCustomOptimizers(unittest.TestCase):
},
],
"num_epochs": 1,
"sample_packing": True,
"pad_to_sequence_len": True,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 1e-4,
"optimizer": "came_pytorch",
"adam_beta3": 0.9999,
"adam_epsilon2": 1e-16,
"max_steps": 5,
"max_steps": 80,
"warmup_steps": 5,
"logging_steps": 1,
"lr_scheduler": "cosine",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
@@ -284,6 +289,13 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=10,
final_window=10,
max_initial=4.0,
max_final=3.0,
)
@require_torch_2_7_0

View File

@@ -9,7 +9,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import (
check_model_output_exists,
check_tensorboard_loss_decreased,
with_temp_dir,
)
class TestPhi(unittest.TestCase):
@@ -21,7 +25,7 @@ class TestPhi(unittest.TestCase):
def test_phi_ft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model": "axolotl-ai-co/tiny-phi-64m",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
@@ -41,18 +45,22 @@ class TestPhi(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"learning_rate": 2e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"max_steps": 10,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -61,12 +69,19 @@ class TestPhi(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)
@with_temp_dir
def test_phi_qlora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model": "axolotl-ai-co/tiny-phi-64m",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
@@ -90,18 +105,22 @@ class TestPhi(unittest.TestCase):
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"learning_rate": 2e-4,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"max_steps": 10,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 50,
"warmup_steps": 5,
"logging_steps": 1,
"save_steps": 50,
"eval_steps": 50,
"bf16": "auto",
"save_first_step": False,
"use_tensorboard": True,
"seed": 42,
}
)
cfg = validate_config(cfg)
@@ -110,3 +129,10 @@ class TestPhi(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
temp_dir + "/runs",
initial_window=5,
final_window=5,
max_initial=5.0,
max_final=4.7,
)

View File

@@ -18,7 +18,7 @@ class TestPreprocess:
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"base_model": "axolotl-ai-co/tiny-qwen2-129m",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [

View File

@@ -45,7 +45,7 @@ def _get_fake_quant_config_dtype(config):
@pytest.fixture()
def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2-0.5B",
"axolotl-ai-co/tiny-qwen2-129m",
device_map="auto",
dtype=torch.bfloat16,
)

View File

@@ -17,7 +17,7 @@ class TestE2eQwen:
Test cases for qwen models
"""
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
@pytest.mark.parametrize("base_model", ["axolotl-ai-co/tiny-qwen2-129m"])
def test_dpo(self, base_model, temp_dir):
cfg = DictDefault(
{

View File

@@ -199,6 +199,106 @@ def check_tensorboard(
assert df.value.values[-1] > 1e-5, "Expected loss to be greater than zero"
def check_tensorboard_loss_decreased(
temp_run_dir: str,
tag: str | None = None,
initial_window: int = 1,
final_window: int = 1,
min_delta: float | None = None,
max_initial: float | None = None,
max_final: float | None = None,
max_loss_ratio: float = 0.95,
) -> None:
"""Check that training actually learned — loss went down and stayed in
a sensible range.
Used with the tiny ``axolotl-ai-co/tiny-*`` CI models, where pretraining
was brief enough that final loss won't clear the absolute thresholds used
for 135M+ models — but the training run should still show measurable learning.
``train/train_loss`` is only logged once (end-of-training aggregate). The
per-step tag is ``train/loss`` for SFT/LM trainers and may vary across
trainers (e.g. DPO). When ``tag`` is None we try common per-step tags in
order and use the first with enough samples.
Two kinds of regression we guard against:
1. **Loss blew up.** A silent bug (e.g. broken label masking) can start
training at an absurdly high loss. ``max_initial`` / ``max_final``
assert the measured means stay at-or-below bounds measured from a
known-good run. Both are optional but strongly encouraged — loss
going *down* from a bad starting scale still looks like "learning."
2. **Loss didn't go down enough.** ``max_loss_ratio`` (default 0.95)
requires ``final <= initial * ratio``. A default below 1.0 means the
final window mean must sit at least 5% below the initial window mean
— real learning, not noise that happened to land below start. Only
raise this ratio (toward 1.0) for configs where a smaller drop is expected *and*
documented (e.g. DPO with near-trivial pairs); in that case you are
intentionally weakening the test.
``min_delta`` is optional; when set, additionally requires
``final + min_delta <= initial`` — use for configs with enough signal
to demand a specific minimum absolute drop.
"""
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars
if tag is None:
candidates = ["train/loss", "train/train_loss"]
else:
candidates = [tag]
required = initial_window + final_window
chosen_tag, values = None, None
for candidate in candidates:
sub = df[df.tag == candidate]
if len(sub) >= required:
chosen_tag = candidate
values = sub.value.values
break
available = sorted({t for t in df.tag.unique() if "loss" in t.lower()})
assert values is not None, (
f"None of the tags {candidates} had ≥{required} logged steps. "
f"Loss tags present: {available}"
)
initial = float(values[:initial_window].mean())
final = float(values[-final_window:].mean())
print(
f"[check_tensorboard_loss_decreased] tag={chosen_tag} n={len(values)} "
f"initial_mean{initial_window}={initial:.4f} final_mean{final_window}={final:.4f}"
)
assert final > 1e-5, "Expected loss to be greater than zero"
assert final <= initial * max_loss_ratio, (
f"Loss did not decrease for {chosen_tag}: "
f"initial(mean of first {initial_window})={initial:.4f}, "
f"final(mean of last {final_window})={final:.4f}, "
f"ratio={final / initial:.4f} (max allowed {max_loss_ratio}). "
f"Expected final <= initial — training did not learn."
)
if min_delta is not None:
assert final + min_delta <= initial, (
f"Expected loss to decrease by at least {min_delta} for {chosen_tag}: "
f"initial={initial:.4f}, final={final:.4f}, delta={initial - final:.4f}"
)
if max_initial is not None:
assert initial <= max_initial, (
f"Initial loss {initial:.4f} is above the expected max {max_initial}. "
f"Absolute scale is wrong — probably a silent regression "
f"(e.g. bad label masking) that bumped the starting point."
)
if max_final is not None:
assert final <= max_final, (
f"Final loss {final:.4f} is above the expected max {max_final}. "
f"Absolute scale is wrong — probably a silent regression "
f"(e.g. bad label masking) that bumped the endpoint."
)
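A usage sketch, mirroring the call sites shown in the test diffs above; the window sizes and thresholds below are illustrative and should be measured from a known-good run for each config:

train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard_loss_decreased(
    temp_dir + "/runs",
    initial_window=5,   # mean over the first 5 logged per-step losses
    final_window=5,     # mean over the last 5 logged per-step losses
    max_initial=5.0,    # absolute starting scale from a known-good run
    max_final=4.7,      # absolute endpoint scale from the same run
)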
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
"""
helper function to check if a model output file exists after training

View File

@@ -0,0 +1,160 @@
"""Integration tests for the MoRA / ReMoRA adapter path."""
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
import torch
from axolotl.integrations.base import PluginManager
from axolotl.integrations.mora import plugin as mora_plugin
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.dict import DictDefault
class TestMoraAdapterLoading:
"""MoRA adapter selection and config wiring."""
def test_load_adapter_uses_plugin_lora_like_registration(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {"use_mora": True, "mora_type": "rope"},
}
)
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
calls = []
def fake_load_lora(*args, **kwargs):
calls.append((args, kwargs))
return args[0], "adapter-config"
monkeypatch.setattr(adapter_module, "load_lora", fake_load_lora)
_, config = load_adapter(model, cfg, "mora")
assert config == "adapter-config"
assert calls[0][1]["config_only"] is False
def test_mora_plugin_raises_when_peft_missing_support(self):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {"use_mora": True, "mora_type": "rope"},
}
)
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
with pytest.raises(ImportError, match="MoRA support"):
load_adapter(model, cfg, "mora", config_only=True)
def test_mora_plugin_rejects_quantized_base_model(self):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"load_in_4bit": True,
"mora": {"use_mora": True, "mora_type": "rope"},
}
)
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
with pytest.raises(ValueError, match="full-precision base model"):
load_adapter(model, cfg, "mora", config_only=True)
def test_mora_plugin_builds_mora_config_when_supported(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {
"use_mora": True,
"mora_type": "rope",
},
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
}
)
captured = {}
class FakeLoraConfig:
def __init__(self, **kwargs):
captured.update(kwargs)
self.__dict__.update(kwargs)
fake_model = SimpleNamespace(print_trainable_parameters=Mock())
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
monkeypatch.setattr(
adapter_module, "get_peft_model", Mock(return_value=fake_model)
)
_, config = load_adapter(model, cfg, "mora", config_only=True)
assert captured["use_mora"] is True
assert captured["mora_type"] == 6
assert captured["task_type"].name == "CAUSAL_LM"
assert config is not None
assert config.use_mora is True
assert config.mora_type == 6
def test_mora_plugin_uses_lora_model_dir_resume_path(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {"use_mora": True, "mora_type": "rope"},
"lora_model_dir": "adapter-checkpoint",
"lora_on_cpu": False,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
}
)
class FakeLoraConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class FakePeftModel:
def print_trainable_parameters(self):
pass
def named_parameters(self):
return []
from_pretrained = Mock(return_value=FakePeftModel())
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
monkeypatch.setattr(
adapter_module.PeftModel,
"from_pretrained",
from_pretrained,
)
peft_model, config = load_adapter(model, cfg, "mora")
assert isinstance(peft_model, FakePeftModel)
assert config.use_mora is True
from_pretrained.assert_called_once()
assert from_pretrained.call_args.args[:2] == (model, "adapter-checkpoint")
assert from_pretrained.call_args.kwargs["is_trainable"] is True

View File

@@ -0,0 +1,73 @@
"""Core adapter plugin registry tests."""
from unittest.mock import Mock
import pytest
import torch
from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
class FakeAdapterPlugin(BasePlugin):
def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
return [AdapterCapabilities(name="fake-adapter", lora_like=True, relora=True)]
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
if cfg.adapter != "fake-adapter":
return {}
return {"fake_kwarg": "from-plugin"}
class TestAdapterPluginRegistry:
def test_lora_like_plugin_adapter_contributes_peft_kwargs(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "fake-adapter",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
}
)
PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
captured = {}
class FakeLoraConfig:
def __init__(self, **kwargs):
captured.update(kwargs)
self.__dict__.update(kwargs)
monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
monkeypatch.setattr(adapter_module, "get_peft_model", Mock())
_, config = load_adapter(model, cfg, "fake-adapter", config_only=True)
assert config is not None
assert captured["fake_kwarg"] == "from-plugin"
assert captured["task_type"].name == "CAUSAL_LM"
def test_unknown_adapter_error_mentions_plugin_registry(self):
model = torch.nn.Linear(4, 4)
cfg = DictDefault({"adapter": "missing-adapter"})
with pytest.raises(NotImplementedError, match="registered by a plugin"):
load_adapter(model, cfg, "missing-adapter")
def test_relora_accepts_plugin_adapter_capability(self, min_base_cfg):
PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
cfg = min_base_cfg | DictDefault(
{
"adapter": "fake-adapter",
"relora": True,
"jagged_restart_steps": 100,
}
)
validated = validate_config(cfg)
assert validated.adapter == "fake-adapter"
assert validated.relora is True

View File

@@ -0,0 +1,186 @@
"""Unit tests for axolotl.monkeypatch.relora.reset_optimizer."""
import math
import pytest
import torch
import torch.nn as nn
from axolotl.monkeypatch.relora import (
magnitude_pruning_,
random_pruning_,
reset_optimizer,
)
ADAM_KEYS = ["exp_avg", "exp_avg_sq"]
def _build_optimizer_with_state(seed: int = 0):
"""Build a tiny optimizer over LoRA-shaped + non-LoRA params with populated state."""
torch.manual_seed(seed)
lora_a = nn.Parameter(torch.randn(8, 32))
lora_b = nn.Parameter(torch.randn(32, 8))
extra = nn.Parameter(torch.randn(64, 32))
optimizer = torch.optim.AdamW([lora_a, lora_b, extra], lr=1e-3)
for _ in range(2):
loss = (
(lora_a * torch.randn_like(lora_a)).sum()
+ (lora_b * torch.randn_like(lora_b)).sum()
+ (extra * torch.randn_like(extra)).sum()
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return optimizer, lora_a, lora_b, extra
def test_reset_optimizer_only_touches_reset_params():
"""State for params NOT in reset_params must be byte-identical after reset."""
optimizer, lora_a, lora_b, extra = _build_optimizer_with_state()
extra_avg_before = optimizer.state[extra]["exp_avg"].clone()
extra_avg_sq_before = optimizer.state[extra]["exp_avg_sq"].clone()
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method="magnitude",
prune_ratio=0.9,
)
assert torch.equal(optimizer.state[extra]["exp_avg"], extra_avg_before)
assert torch.equal(optimizer.state[extra]["exp_avg_sq"], extra_avg_sq_before)
def test_reset_optimizer_actually_prunes_lora_state():
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method="magnitude",
prune_ratio=0.9,
)
for param in (lora_a, lora_b):
for key in ADAM_KEYS:
zero_frac = (optimizer.state[param][key] == 0).float().mean().item()
assert zero_frac >= 0.85
@pytest.mark.parametrize(
"method,ratio,expected_zero_frac",
[
("magnitude", 0.9, 0.9),
("magnitude", 0.99, 0.99),
("random", 0.9, 0.9),
("random", 0.5, 0.5),
# reset uses random pruning; relora_prune_ratio must be honored, not ignored.
("reset", 0.9, 0.9),
("reset", 0.5, 0.5),
],
)
def test_prune_methods(method, ratio, expected_zero_frac):
"""Each method zeros approximately the expected fraction."""
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state(seed=42)
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method=method,
prune_ratio=ratio,
)
total = 0
zeros = 0
for param in (lora_a, lora_b):
for key in ADAM_KEYS:
tensor = optimizer.state[param][key]
total += tensor.numel()
zeros += (tensor == 0).sum().item()
actual = zeros / total
tolerance = 0.02 if method == "magnitude" else 0.05
assert math.isclose(actual, expected_zero_frac, abs_tol=tolerance)
def test_reset_optimizer_skips_keys_not_in_state_keys():
"""Keys present in optimizer state but not in optimizer_state_keys are untouched."""
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
exp_avg_sq_before = optimizer.state[lora_a]["exp_avg_sq"].clone()
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=["exp_avg"],
prune_method="magnitude",
prune_ratio=0.9,
)
assert torch.equal(optimizer.state[lora_a]["exp_avg_sq"], exp_avg_sq_before)
def test_reset_optimizer_handles_param_with_empty_state():
"""Params with no optimizer state are skipped silently."""
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
orphan = nn.Parameter(torch.randn(4, 4))
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b, orphan],
optimizer_state_keys=ADAM_KEYS,
prune_method="magnitude",
prune_ratio=0.9,
)
assert orphan not in optimizer.state or not optimizer.state[orphan]
def test_unknown_prune_method_raises():
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
with pytest.raises(ValueError, match="Unknown prune_method"):
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method="bogus", # type: ignore[arg-type]
prune_ratio=0.9,
)
def test_pruning_helpers_are_inplace():
"""magnitude_pruning_ and random_pruning_ must mutate via tensor.mul_."""
tensor = torch.randn(64)
ptr_before = tensor.data_ptr()
magnitude_pruning_(tensor, 0.5)
assert tensor.data_ptr() == ptr_before
tensor = torch.randn(64)
ptr_before = tensor.data_ptr()
random_pruning_(tensor, 0.5)
assert tensor.data_ptr() == ptr_before
def test_pruning_helpers_support_uint8_tensors():
"""Both pruning helpers must work on uint8 optimizer state tensors."""
tensor = torch.arange(1, 129, dtype=torch.uint8)
magnitude_pruning_(tensor, 0.9)
assert tensor.dtype == torch.uint8
magnitude_zero_frac = (tensor == 0).float().mean().item()
assert 0.85 <= magnitude_zero_frac <= 0.95
tensor = torch.arange(1, 129, dtype=torch.uint8)
with torch.random.fork_rng(devices=[]):
torch.manual_seed(1234)
random_pruning_(tensor, 0.9)
assert tensor.dtype == torch.uint8
random_zero_frac = (tensor == 0).float().mean().item()
assert 0.85 <= random_zero_frac <= 0.95
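The helpers under test live in axolotl.monkeypatch.relora. As a rough conceptual sketch (not the axolotl implementation) of the contract these tests pin down, assuming only what the assertions above require — in-place pruning of the named state keys for reset_params only, with "reset" sharing the random path and honoring the ratio:

import torch

def sketch_reset_optimizer(optimizer, reset_params, optimizer_state_keys,
                           prune_method="magnitude", prune_ratio=0.9):
    # Sketch only: prune optimizer state in-place for reset_params; leave
    # every other param's state untouched.
    for param in reset_params:
        state = optimizer.state.get(param, {})
        for key in optimizer_state_keys:
            if key not in state:
                continue  # params with no populated state are skipped silently
            tensor = state[key]
            work = tensor.float()  # uint8 state must survive the round trip
            if prune_method == "magnitude":
                k = max(1, int(prune_ratio * work.numel()))
                threshold = work.abs().flatten().kthvalue(k).values
                mask = (work.abs() > threshold).to(tensor.dtype)
            elif prune_method in ("random", "reset"):
                # "reset" rides the random path and honors prune_ratio
                mask = (torch.rand_like(work) > prune_ratio).to(tensor.dtype)
            else:
                raise ValueError(f"Unknown prune_method: {prune_method}")
            tensor.mul_(mask)  # in-place: same storage, same dtype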

View File

@@ -1,11 +1,5 @@
"""Tests for attn_implementation: input normalization, canonical-value
acceptance, capability flags, backend registration, and downstream validators.
Test classes are organized by feature concern, not by the layer of the schema
where the behavior is implemented (classmethod normalizer vs. field validator
vs. full `validate_config` pipeline). Each class covers a single contract end
to end, dropping into the lower layer only where it gives faster or sharper
coverage of an isolated branch.
"""Tests for attn_implementation: normalization, canonical-value acceptance,
capability flags, backend registration, and downstream validators.
"""
import logging
@@ -53,12 +47,7 @@ def _xformers_available():
class TestCapabilityTables:
"""Backend capability classification.
Asserts both the static frozensets in `enums.py` and the `computed_field`
properties on a validated config read consistently from those tables, and
that user YAML cannot override the computed flags.
"""
"""Backend capability classification via frozensets and computed_field properties."""
@pytest.mark.parametrize(
"impl",
@@ -359,8 +348,6 @@ class TestGemma4HybridMode:
assert result["attn_implementation"] == "flash_attention_2"
def test_non_fa2_raises(self):
"""The hybrid path requires FA2 under the hood — any other backend is
a configuration error."""
with pytest.raises(
ValueError, match="requires attn_implementation=flash_attention_2"
):
@@ -370,11 +357,7 @@ class TestGemma4HybridMode:
class TestSamplePackingValidation:
"""`sample_packing` requires a varlen-capable backend.
Non-varlen backends (eager, sdpa) warn about cross-sample contamination;
s2 raises outright because shifted-sparse attention has no varlen path.
"""
"""`sample_packing` warns for non-varlen backends; s2 raises outright."""
def test_eager_warns(self, min_base_cfg, caplog):
cfg = min_base_cfg | DictDefault(

View File

@@ -0,0 +1,163 @@
"""
Regression tests for MultiModalChatDataCollator shape contracts.
Guard against the transformers 5.x breakage where apply_chat_template's
own `return_dict` parameter (default False) caused it to return the raw
input_ids tensor instead of the full BatchFeature dict, leading to
IndexError: too many indices for tensor of dimension 2
when downstream code did batch["input_ids"] on the resulting tensor.
"""
from unittest.mock import MagicMock, patch
import pytest
import torch
from transformers import BatchFeature
@pytest.fixture(name="mock_processor")
def fixture_mock_processor():
"""
A mock processor whose apply_chat_template returns a BatchFeature
when called with return_dict=True (the correct call convention),
or a raw input_ids tensor when called without return_dict=True
(the broken call convention that the bug introduced).
"""
processor = MagicMock()
processor.tokenizer = MagicMock()
processor.tokenizer.pad_token_id = 0
processor.image_token = "<|image|>"
processor.tokenizer.convert_tokens_to_ids = MagicMock(return_value=128256)
batch_size, seq_len = 2, 16
input_ids = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
batch_feature = BatchFeature(
data={
"input_ids": input_ids,
"attention_mask": attention_mask,
}
)
def _apply_chat_template(*args, **kwargs):
if kwargs.get("return_dict", False):
return batch_feature
# Simulate transformers 5.x default behaviour: returns out["input_ids"]
return input_ids
processor.apply_chat_template = MagicMock(side_effect=_apply_chat_template)
processor.chat_template = None
return processor
@pytest.fixture(name="mock_processing_strategy")
def fixture_mock_processing_strategy(mock_processor):
from axolotl.processing_strategies import ProcessingStrategy
strategy = ProcessingStrategy(processor=mock_processor)
return strategy
class TestMultiModalChatDataCollatorShapeContract:
"""
Verify that MultiModalChatDataCollator.process_rows returns a dict with
2-D input_ids and labels, not a raw tensor. This is the shape contract
that process_labels depends on.
"""
def _make_collator(self, mock_processing_strategy):
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
tokenizer = mock_processing_strategy.processor.tokenizer
return MultiModalChatDataCollator(
tokenizer=tokenizer,
processing_strategy=mock_processing_strategy,
)
def _make_examples(self):
return [
{
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
}
]
def test_process_rows_returns_dict(self, mock_processing_strategy):
"""batch must be a dict, not a raw tensor."""
collator = self._make_collator(mock_processing_strategy)
examples = self._make_examples()
with patch.object(
mock_processing_strategy,
"__call__",
return_value=examples,
):
batch = collator.process_rows(examples)
assert isinstance(batch, dict), (
"process_rows must return a dict (BatchFeature), not a raw tensor. "
"If it returns a tensor, apply_chat_template was called without "
"return_dict=True at the top level."
)
def test_process_rows_input_ids_shape(self, mock_processing_strategy):
"""batch['input_ids'] must be a 2-D tensor (batch, seq_len)."""
collator = self._make_collator(mock_processing_strategy)
examples = self._make_examples()
with patch.object(
mock_processing_strategy,
"__call__",
return_value=examples,
):
batch = collator.process_rows(examples)
assert "input_ids" in batch
assert isinstance(batch["input_ids"], torch.Tensor)
assert batch["input_ids"].ndim == 2, (
f"input_ids must be 2-D (batch, seq_len), got shape {batch['input_ids'].shape}"
)
def test_process_rows_labels_shape(self, mock_processing_strategy):
"""batch['labels'] must be a 2-D tensor matching input_ids shape."""
collator = self._make_collator(mock_processing_strategy)
examples = self._make_examples()
with patch.object(
mock_processing_strategy,
"__call__",
return_value=examples,
):
batch = collator.process_rows(examples)
assert "labels" in batch
assert isinstance(batch["labels"], torch.Tensor)
assert batch["labels"].ndim == 2
assert batch["labels"].shape == batch["input_ids"].shape
def test_apply_chat_template_called_with_return_dict_true(
self, mock_processing_strategy
):
"""apply_chat_template must be called with return_dict=True as a keyword arg."""
collator = self._make_collator(mock_processing_strategy)
examples = self._make_examples()
with patch.object(
mock_processing_strategy,
"__call__",
return_value=examples,
):
collator.process_rows(examples)
call_kwargs = (
mock_processing_strategy.processor.apply_chat_template.call_args.kwargs
)
assert call_kwargs.get("return_dict") is True, (
"apply_chat_template must be called with return_dict=True as a top-level "
"keyword argument (not inside processor_kwargs). In transformers 5.x, "
"apply_chat_template has its own return_dict param (default False) that "
"controls whether it returns the full BatchFeature or just input_ids."
)
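For orientation, a minimal sketch of the call convention these tests enforce, reusing the fixture's mock processor; `messages` here is any chat-formatted placeholder and the shapes come from the fixture above:

messages = [{"role": "user", "content": "Hello"}]  # placeholder chat example

good = mock_processor.apply_chat_template(messages, return_dict=True)
good["input_ids"].shape   # torch.Size([2, 16]) in the fixture: keyed access works

bad = mock_processor.apply_chat_template(messages)
# transformers 5.x default (return_dict=False): `bad` is the bare input_ids
# tensor, so bad["input_ids"] is the downstream failure the docstring describes.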

View File

@@ -1,16 +1,8 @@
"""Guard the attn_implementation source-of-truth invariant.
"""Enforce attn_implementation as the single source of truth.
`cfg.attn_implementation` is the single source of truth for the attention
backend on the validated config. Legacy boolean flags (`flash_attention`,
`sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`,
`s2_attention`, `eager_attention`) are input-only deprecated aliases — they
are stripped from `data` by `normalize_attn_implementation` and must never be
read downstream.
This test greps `src/` and fails if it finds a `cfg.<legacy>_attention` read.
If you're here because this test failed, migrate the read site to
`cfg.attn_implementation` or one of the `attn_supports_packing /
attn_uses_flash_lib / attn_needs_dtype_cast` computed capability flags.
Fails if src/ contains a cfg.<legacy>_attention read. Migrate offending sites
to cfg.attn_implementation or the attn_supports_packing/attn_uses_flash_lib/
attn_needs_dtype_cast computed flags.
"""
from __future__ import annotations
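A hedged migration sketch for the read sites this test flags; `enable_packing` is a hypothetical call site, and the field and flag names come from the docstring above:

# before: legacy alias read (now rejected by this test)
if cfg.flash_attention:
    enable_packing()

# after: canonical field, or a computed capability flag when that is all you need
if cfg.attn_implementation == "flash_attention_2":
    enable_packing()
if cfg.attn_supports_packing:
    enable_packing()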

File diff suppressed because it is too large

View File

@@ -0,0 +1,100 @@
"""Validation tests for the MoRA / ReMoRA integration."""
import pytest
from axolotl.integrations.mora import MoraType
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
class TestMoraValidation:
"""MoRA-specific config validation."""
def test_mora_block_round_trips(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"mora": {
"use_mora": True,
"mora_type": "rope",
},
}
)
prepare_plugins(cfg)
validated = validate_config(cfg)
assert validated.adapter == "mora"
assert validated.mora.use_mora is True
assert validated.mora.mora_type == MoraType.ROPE
def test_mora_type_accepts_legacy_supported_numbers(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"mora": {
"use_mora": True,
"mora_type": 1,
},
}
)
prepare_plugins(cfg)
validated = validate_config(cfg)
assert validated.mora.mora_type == MoraType.SHARING
def test_mora_rejects_unsupported_variant_numbers(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"mora": {
"use_mora": True,
"mora_type": 2,
},
}
)
prepare_plugins(cfg)
with pytest.raises(ValueError, match="mora_type"):
validate_config(cfg)
def test_remora_uses_core_relora_fields(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"relora": True,
"jagged_restart_steps": 2000,
"mora": {
"use_mora": True,
"mora_type": "rope",
},
}
)
prepare_plugins(cfg)
validated = validate_config(cfg)
assert validated.relora is True
assert validated.jagged_restart_steps == 2000
def test_remora_still_requires_core_restart_steps(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"relora": True,
"mora": {
"use_mora": True,
"mora_type": "rope",
},
}
)
prepare_plugins(cfg)
with pytest.raises(ValueError, match="jagged_restart_steps"):
validate_config(cfg)