* Fix Axolotl ReLoRA optimizer reset scope
* fix: make relora reset method honor relora_prune_ratio
When relora_prune_method='reset' and relora_prune_ratio is explicitly
set, the ratio was silently ignored and replaced with the hardcoded
_FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to
ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset
and 0.9 for other methods. reset_optimizer now uses the same random
pruning path for both 'random' and 'reset'.
Also consolidate three-layer default mismatch: schema default for
relora_prune_method is now 'magnitude' (single canonical source);
dataclass defaults for both fields changed to None to eliminate the
conflicting fallback layer.
Tests updated: removed the test case that verified the old broken
behavior (reset ignoring ratio), added two cases proving reset honors
the passed ratio. E2E reset fixture now uses ratio=0.5 to make it
unambiguous that the ratio is honored.
* Fix ReLoRA uint8 pruning regression
---------
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries
Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.
What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
strategy declares per-role start/end markers via
`_build_role_boundaries`; the shared scanner honors
`train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
(previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
(Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
fallback) retain legacy behavior and emit a one-shot warning. Users
can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
token between user-end and assistant-start via `include_end=False`
+ scanner rewind.
See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* feat: systemic multimodal assistant-only loss masking + cfg.role_boundaries
Fixes silent ignoring of `cfg.train_on_inputs` / `cfg.roles_to_train` /
`cfg.train_on_eos` in the multimodal training path. Before this branch,
only Gemma 3n honored these knobs; every other VLM trained on the full
sequence regardless of config. Also adds `cfg.role_boundaries` YAML
override so users can declare per-role markers without subclassing.
What changed
------------
- `ProcessingStrategy` gains a declarative boundary scanner. Each
strategy declares per-role start/end markers via
`_build_role_boundaries`; the shared scanner honors
`train_on_inputs` / `roles_to_train` / `train_on_eos` (incl. "last").
- New per-template strategies: Gemma 4, Llama 3.2 Vision, Llama 4,
Pixtral, Mistral V7 Tekken.
- Refactored: Gemma 3 (previously no role masking), Gemma 3n
(previously ad-hoc scanner, now shared).
- Strategies whose boundary tokens couldn't be verified offline
(Voxtral, SmolVLM2, Mistral3, InternVL, GLM4V, llava/lfm2vl
fallback) retain legacy behavior and emit a one-shot warning. Users
can enable masking on them via `cfg.role_boundaries`.
- Pixtral / Mistral V7 Tekken correctly handle the shared `[/INST]`
token between user-end and assistant-start via `include_end=False`
+ scanner rewind.
See `docs/multimodal_assistant_mask.md` for the full audit table,
root-cause analysis, and design rationale.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* docs+types: address CodeRabbit nitpicks on PR #7
- builders/causal.py: add inline NOTE that multi-dataset configs reuse
the first dataset's masking knobs (roles_to_train / train_on_eos) for
all datasets — heterogeneous per-dataset overrides are not supported
in the MM path today.
- processing_strategies.py: annotate inner scanner helpers
_match_prefix and _find_end with explicit types (Tensor, int,
list[int] → bool / tuple[int, bool]) for readability.
- docs/multimodal_assistant_mask.md: renumber the "Commits on this
branch" list to 1-7 consecutive (previously skipped 3).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* fix(mm-mask): address two CodeRabbit findings on PR #7
1. Schema rejected `train_on_eos: "none"` despite the scanner honoring it.
`_VALID_TRAIN_ON_EOS` accepts "none" and the design doc lists it, but
`SFTDataset.train_on_eos` was `Literal["all", "turn", "last"]`, so YAML
users hit a pydantic ValidationError at config load. Added "none" to
the Literal and updated the description.
2. `cfg.role_boundaries: []` had split-personality semantics: the strategy
ctor treated it as "replace built-ins with empty" while the collator
plumbing treated it as "unset", and both the design doc and the
MultiModalConfig schema help text promised wholesale replacement for
any set value. Aligned on opt-in semantics across all four surfaces —
a non-empty list replaces built-ins wholesale; unset or `[]` falls back
to built-ins. Rationale: honoring `[]` literally yields all-masked
labels and zero gradient, which is almost always a typo or leftover
rather than a deliberate user action. Users who want to disable role
masking should unset the field or use `train_on_inputs: true`.
Also sharpened the fallback one-shot warning for strategies without
built-in boundaries: names the consequence ("only pad and media tokens
are masked, every other token contributes to loss") and points users
at `cfg.role_boundaries` + docs/multimodal_assistant_mask.md instead
of "see axolotl/processing_strategies.py for how to declare
boundaries."
Files:
- src/axolotl/utils/schemas/datasets.py: Literal adds "none"
- src/axolotl/processing_strategies.py: ctor truthiness check on
role_boundaries_override; sharpened fallback warning
- src/axolotl/utils/schemas/multimodal.py: role_boundaries description
now calls out opt-in + empty-list fallback semantics
- docs/multimodal_assistant_mask.md: same clarification in the Semantics
block; updated the fallback-path detection paragraph to quote the new
warning text
- tests/test_processing_strategies.py: +2 regressions
(test_sft_dataset_schema_accepts_all_supported_train_on_eos_values,
test_empty_role_boundaries_override_falls_back_to_builtin); 63/63 pass
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* doc cleanup
* fix(mm-mask): CodeRabbit findings + lint fix on PR #3625
Pre-commit failure: trailing newline missing on
docs/multimodal_assistant_mask.md (end-of-file-fixer hook).
Six CodeRabbit findings addressed:
1. Scanner: non-trainable role's end marker ignored ``include_end``.
Under ``train_on_eos="all"``, the shared ``[/INST]`` token (user-end
with ``include_end=False``, intentionally re-matched as assistant-start)
leaked into loss via the user branch on Pixtral / Mistral V7 Tekken.
Fix: gate the non-trainable branch on ``best_match.include_end`` to
mirror the trainable branch.
2. Gemma3 ``boi_token`` lookup used ``tokenizer.special_tokens_map.get("boi_token")``,
which never fires on real checkpoints (``special_tokens_map`` only
holds HF's standard slots — bos/eos/pad/unk/...). Swap to direct
attribute read ``getattr(tokenizer, "boi_token", None)``, matching
what ``transformers.models.gemma3.processing_gemma3`` itself does.
Updated the ``_gemma_tokenizer`` test fixture to mirror real-model
shape so the test exercises the production code path.
3. GLM dispatcher only registered ``Glm46VProcessor`` (GLM-4.6V /
GLM-4.7V). Real ``Glm4vProcessor`` (GLM-4V / GLM-4.1V) users fell
through to the base fallback. Both processors ship identical
media-token markers, so register both under the shared
``Glm4vProcessingStrategy`` with independent try/except import blocks.
Updated class docstring. +2 dispatcher regressions.
4. Gemma3 ``process_labels`` hardcoded 262144 for the soft image token.
Resolve dynamically via ``tokenizer.convert_tokens_to_ids("<image_soft_token>")``
with unk-id guard; fall back to 262144 only if the string isn't in
vocab. Mirrors ``Gemma4ProcessingStrategy.process_labels`` pattern.
5. ``build_collator`` was called twice per ``build()`` (eval + train
passes), producing two identical ``MM collator: ...`` INFO banners on
startup. Gate the log on ``is_eval=False`` so only the training pass
emits it.
6. Removed unused ``_mistral_common_stub`` pytest fixture (13 refs → 0,
always returned ``None``; the dispatcher already handles missing
``mistral_common`` via lazy import + ``try/except``). Added
``test_scanner_train_on_eos_all_with_non_trainable_include_end_false``
— a focused scanner-level lock-in for finding #1, independent of any
specific VLM strategy.
Test count: 63 → 68 passing. Local ``pre-commit run --all-files`` green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* chore(mm-mask): hoist .tolist() out of scanner; shorten comments/docstrings
- Scanner perf: convert labels[i] to a Python list once per row so
_match_prefix / _find_end operate on list slices instead of
re-materializing Tensor slices via .tolist() on every probe. Cuts
O(n*boundaries) CPython↔C boundary crossings per batch.
- Markdown lint (MD001, MD040): promote two h3 section headings to h2
under the h1; add `text` language to the verify-at-runtime fenced block.
- Shorten verbose comments/docstrings added in recent commits to
bare-minimum "why" notes matching the repo's existing style.
68/68 tests, 8/8 pre-commit hooks still pass.
* memory clean up for fsdp full state dict
* Update src/axolotl/monkeypatch/accelerate/fsdp2.py
Co-authored-by: Wing Lian <wing.lian@gmail.com>
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* upgrade to torchao 0.17.0
* chore: lint
* refactor attention handling
* replace legacy attention boolean flags with capability properties
Replace checks with capability-based properties derived from attn_implementation
This separates three concerns that were conflated under flash_attention:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property
* compute attn capability flags in normalizer instead of properties
* make attn_implementation the single source of truth
* move attention-dependent validators to mode=after
* migrate remaining consumers to canonical attn_implementation
* expand attention tests + rewrite docs
* migrate example configs to canonical attn_implementation
* update doc snippets + reject gemma4-hybrid with non-FA2 backend
* remove dead gemma4 branch in _set_attention_config
* fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests
* drop "Phase 2" naming from attn-implementation tests
* regroup attn_implementation tests by feature concern
* clean up verbose comments and remove MD
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
* fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x
In transformers 5.x, ProcessorMixin.apply_chat_template gained its own
`return_dict` parameter (defaulting to False). When return_dict=False
and tokenize=True the method returns out["input_ids"] directly — a 2-D
tensor — rather than the full BatchFeature dict.
The old code placed `return_dict=True` inside processor_kwargs. In
transformers 5.x those kwargs are forwarded to the underlying processor
call self(...) where _merge_kwargs silently ignores any key not present
in MllamaProcessorKwargs (emitting a warning). The outer return_dict
therefore stayed False, apply_chat_template returned the raw input_ids
tensor, and the subsequent `batch["input_ids"]` attempted to index a
2-D tensor with the 9-character string "input_ids", producing:
IndexError: too many indices for tensor of dimension 2
The fix is to pass return_dict=True as a top-level keyword argument to
apply_chat_template (where it is actually consumed) and remove it from
processor_kwargs (where it was silently dropped). No version guard is
needed: transformers is pinned to ==5.5.4 in pyproject.toml.
Adds a unit-level regression test (tests/test_mm_chat_collator.py) that
mocks the processor to return a raw tensor when apply_chat_template is
called without top-level return_dict=True, verifying the four invariants:
process_rows returns a dict, input_ids is 2-D, labels is 2-D, and
apply_chat_template receives return_dict=True as a top-level kwarg.
Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset
Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
* fix(collator): process_rows returns dict (BatchFeature) shape
Two related changes for the multimodal chat collator under transformers 5.x:
1. Wrap apply_chat_template result in dict(...) so process_rows returns
a plain dict rather than a BatchFeature instance. BatchFeature is a
Mapping but not a dict; downstream code that did
batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
would index on a tensor when the result wasn't dict-shaped, raising
IndexError: too many indices for tensor of dimension 2
2. Soften the regression test's contract from `dict` to `Mapping` so it
exercises the actual semantic guarantee (key/value access) rather
than the implementation detail (dict vs BatchFeature). Test guards
against the original transformers 5.x breakage where apply_chat_template's
return_dict default went from True to False.
Includes regression test under tests/test_mm_chat_collator.py.
Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against
attn-implementation-refactor; squash-merged from agent commits 4de886fd
+ dc9fcf4f.
Signed-off-by: Wing Lian <wing@axolotl.ai>
---------
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
* fix dpo collation/padding
* fix DPO collator encoder-decoder pixel_values dtype and is_encoder_decoder detection
- Use float32 instead of LongTensor for _pixel_values in encoder-decoder branch
- Add missing padding_value case for _pixel_values in encoder-decoder branch
- Derive is_encoder_decoder from model config instead of hardcoding False
* Support loss_type/loss_weights DPO
* Validate dpo loss type/weights only set for dpo
* Tests: Update ipo tests to use new path
* Docs: Update docs for new ipo path
* PR fixes - typo/validation
* PR nit - warning
* chore: fix warnings arg
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing
Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.
HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.
Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.
Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.
Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
* threading.local() store with _get/_set_shared_kv_states helpers
* _patch_decoder_layer_call(): monkeypatches
Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
and stash it in TLS before delegating to GradientCheckpointingLayer.
The partial formed downstream no longer references the dict.
* fused_forward reads TLS first, falls back to kwarg for callers that
bypass the patched __call__ (e.g. direct attention invocation).
* wired into patch_gemma4_fused_attn; idempotent via a sentinel.
TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.
Introduced by PR #3598 (commit b8358aa5).
* Coderabbit comment: gemma4: clear TLS unconditionally in decoder-layer patched __call__
Overwrite the thread-local shared_kv_states store on every invocation
(including with None) instead of only when the kwarg is present.
The previous conditional write left stale dicts in TLS on any path that
reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
kwarg — e.g. generation, eval hooks, or future HF refactors that make
the kwarg optional. fused_forward would then silently consume a prior
step's K/V dict instead of falling back to its own kwarg path.
Unconditional write makes the invariant in the surrounding comment
("TLS is overwritten on each new step's first decoder-layer call, so
the previous step's dict is released promptly") actually hold.
No behavior change for the training happy path, which always passes
the kwarg. Addresses CodeRabbit review on PR #3611
* fix: swap threading.local() for module-level store so autograd worker threads see shared_kv_states during backward recompute
Previous commits fixed memory leak on 31B but caused type error with MOE Gemma4 variants - this fixes that:
PR 3611's TLS variant only works when recompute runs on the same thread
that set TLS during forward. PyTorch's C++ autograd engine
(_engine_run_backward) spawns per-device worker threads to dispatch
backward, and HF-Trainer gradient_checkpointing (stock
torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
unpack_hook -> recompute_fn on those worker threads. TLS set on the main
thread during forward is invisible there, so _get_shared_kv_states()
returns None and the consumer-layer lookup crashes with
"'NoneType' object is not subscriptable" at
fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).
A plain module-level dict is visible to all threads in the process.
Lifecycle is identical: the slot is overwritten each forward, releasing
the previous step's dict and allowing its K/V tensors to be GC'd, so
the original VRAM-leak fix still holds under FSDP2 AC too.
* scope gemma4 shared_kv_states side channel to checkpointed training
Update PR #3611 with gate for checkpointed training to avoid regressions across async flows.
Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments
* add gemma4 cross-thread visibility test for shared_kv_states store
Additional regression test for MoE gemma4 variants - asserts the module-level store is readable from threads other than the one that set it in response to previously observed 'NoneType' error
* fix logger
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: move to uv first
* fix: update doc to uv first
* fix: merge dev/tests into uv pyproject
* fix: update docker docs to match current config
* fix: migrate examples to readme
* fix: add llmcompressor to conflict
* feat: rec uv sync with lockfile for dev/ci
* fix: update docker docs to clarify how to use uv images
* chore: docs
* fix: use system python, no venv
* fix: set backend cpu
* fix: only set for installing pytorch step
* fix: remove unsloth kernel and installs
* fix: remove U in tests
* fix: set backend in deps too
* chore: test
* chore: comments
* fix: attempt to lock torch
* fix: workaround torch cuda and not upgraded
* fix: forgot to push
* fix: missed source
* fix: nightly upstream loralinear config
* fix: nightly phi3 long rope not work
* fix: forgot commit
* fix: test phi3 template change
* fix: no more requirements
* fix: carry over changes from new requirements to pyproject
* chore: remove lockfile per discussion
* fix: set match-runtime
* fix: remove unneeded hf hub buildtime
* fix: duplicate cache delete on nightly
* fix: torchvision being overridden
* fix: migrate to uv images
* fix: leftover from merge
* fix: simplify base readme
* fix: update assertion message to be clearer
* chore: docs
* fix: change fallback for cicd script
* fix: match against main exactly
* fix: peft 0.19.1 change
* fix: e2e test
* fix: ci
* fix: e2e test
* bump transformers to 5.5.4 and trl to latest 1.1.0
* more upgrades
* update peft too
* adapt lora_merge to peft 0.19 layer config API
PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv
layer constructors and moved use_rslora, use_dora, fan_in_fan_out,
lora_dropout, and lora_bias into that config. Build the config
per branch in _build_peft_layer_and_get_delta so the merge utility
works with the upgraded peft.
* allow lora_dropout on mixed attention+MoE configs under peft 0.19
PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D
einsum can't factor dropout out of lora_B(lora_A(dropout(x))).
Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
* fix: rename model to adapter_model for fsdp sharded final model
* fix: follow upstream transformer shard size
* fix: handle multiple model files
* fix redundant condition, tighten to safetensors, keep shard size small
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: support excess_length_strategy for RL trainers
Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.
- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
within sequence_len while preserving the full prompt, then decodes
back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len
Closes#3547
* improve RL truncation strategy robustness and performance
---------
Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with
FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning.
Made-with: Cursor
* Skip redundant evaluation when resuming from checkpoint
* add condition check for adding callback
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* better handling of dora merge on Conv layers in Qwen 3.5
* address issues from code review
* stricter efficient merges for dora since we now have meta model to reference
* qwen3_5.jinja: handle list content on system messages
The system message branch used string concatenation on
messages[0].content, which breaks when the first system message uses
the OpenAI-style list-of-parts format that multimodal datasets require.
User and assistant branches already handle both string and list content,
but the system branch did not.
Check whether content is a string and fall back to iterating over parts
when it is a list, matching the pattern used for user messages.
Fixes#3590
* Address pr for other content types
---------
Co-authored-by: Joaquin Hui Gomez <joaquinhuigomez@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix: remove unneeded debug log
* fix: cleanup
* feat: add dense gemma config and cleanup
* feat: add cce support
* update notes and set torch compile
* fix patch for new number of return vals
* fixes for gemma4
* fix packing bug
* use updated cce for mm
* fix: pass in kv cache func when avail for transformers 5.5
* feat: update examples with flex variant and readme
* gemma4 lora attention kernels
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* upgrade to torchao 0.17.0
* upgrade mistral-common too
* chore: lint
* patch fix for torchao low bit optimizers
* fix up
* propagate dtype
* fix test for ao change
* address PR comments
* feat: add sonicmoe fused lora support
* fix: forgot to add file
* feat: add test
* feat: add lora support for other routes
* fix: add int8 lora support
* fix: add qwen35_moe interleave support
* fix: qwen3_5_moe loss
* chore: lint
* address some pr comments
* fix test imports
* add support matrix for moe kernels [skip ci]
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* docs: comprehensive documentation improvements for humans and agents
New human docs:
- grpo.qmd: GRPO deep dive (async, rewards, IS correction, scaling)
- ebft.qmd: EBFT guide (structured/strided modes, feature extraction)
- choosing_method.qmd: decision tree for SFT vs LoRA vs DPO vs GRPO
- vllm_serving.qmd: vLLM setup for GRPO (server/colocate, LoRA sync)
- training_stability.qmd: monitoring, NaN debugging, OOM, healthy metrics
New agent docs:
- AGENTS_SFT.md: agent reference for supervised fine-tuning
- AGENTS_DPO.md: agent reference for preference learning (DPO/KTO/ORPO)
Updated existing docs:
- rlhf.qmd: cross-references to new GRPO/EBFT/choosing-method guides
- getting-started.qmd: reorganized Next Steps with links to new guides
- debugging.qmd: link to training stability guide
- _quarto.yml: added new pages to sidebar navigation
Removed:
- bak.agents.md: stale backup that confused agents
* docs: trim duplicated generic config from AGENTS_DPO.md
Remove boilerplate training params (optimizer, gradient_checkpointing,
flash_attention, etc.) from each method template. These are not
preference-learning-specific and are already covered in AGENTS_SFT.md.
Config templates now show only method-specific fields with a reference
to AGENTS_SFT.md for the rest.
* docs: deduplicate across new doc pages
- grpo.qmd: collapse vLLM setup section to brief config + link to
vllm_serving.qmd; collapse IS correction to essentials + link;
replace full monitoring tables with summary + link to
training_stability.qmd
- vllm_serving.qmd: remove duplicated async/IS config reference tables
(already in grpo.qmd config reference); replace full example config
with link to grpo.qmd quick start
- ebft.qmd: trim generic training params in quick start config
* fix: train scripts
* feat: split files into cleaner parts
* fix: cleanup pretraining docs
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e]
- Add 'tool' to default role_map_inv in dpo/chat_template.py default() and
argilla_chat() so datasets with tool-call messages no longer raise
KeyError: 'tool' (closes#3217)
- Fix generate_dataset_hash_from_config to use canonical tokenizer config +
overrides content instead of tokenizer.name_or_path when added_tokens_overrides
is set, preventing cache busting when only output_dir changes (closes#3303)
- Add three Pydantic config validators to AxolotlConfigWCapabilities:
* save_strategy: 'best' requires metric_for_best_model
* streaming=True is incompatible with val_set_size > 0
* lora_target_modules list entries must be valid Python regex patterns
- Tests for all three changes
* review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash
* chore: lint
* move the validators out of the w/ capabilities schema
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* qat patch
* tests fixes
* fixup per PR code review
* use state dict hooks to handle dequant for saving safetensors from transformers
* use transformers torch ao quantizer hooks to save mx quantized model
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* Add precompute_ref_log_probs to config schema
* chore: add description for config
* Add test for precompute_ref_log_probs and move to training args
* useing precompute logprobs as the default slows down CI as it has to precompute
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* allow bf16 flag but warn
Reason: when doing e.g. LoRA merges with CUDA_VISIBLE_DEVICES=, this will unnecessarily crash, even though the LoRA merge operation would have finished successfully. This seems to warrant changing it to a warning instead, as the code will most likely crash later if bf16 is unavailable and training begins anyway.
* don't use deprecated LOG.warn
* update tests to reflect validation change
* bug-fix: only apply patches when CUDA is available
This will otherwise crash when performing operations with CUDA_VISIBLE_DEVICES=, such as LoRA merging on CPU.
This patch only patches the Qwen 3.5 model, since that's the only one I've tested. This patch should most likely check torch.cuda for all other models as well. One limitation here is that I'm assuming the user runs CUDA, but that assumption is not restricted to this patch so it is probably fine.
* include patch_qwen3_next_modeling_packing, patch_qwen3_5_moe_modeling_packing, and patch_qwen3_5_vlm_flash_attention in cuda guard
* Deperecate dpo_norm_loss
* Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179
* Remove deprecated rpo_alpha
* Remove dead_code tokenize_row
* Add _tokenize override to prevent double bos token on Llama DPO
* Fix DPO loss type now list not string
* Linting fix
* PR fixes
* update _tokenize override for DPO for multimodal
* support flattening/packing for GRPO
* more flattening
* fix tests
* improve dead vllm handling
* refactor out process handling for vllm serve and move bench flattening tests to gpu tests
* add validation for flattening with liger
* isolate batch flattening test
* flaky test
* fix: handle get_open_port import across TRL versions
TRL 0.29+ removed get_open_port from exports; fall back to importing
directly from vllm.utils or vllm.utils.network_utils.
* support DP with vllm and make generation_batch_size confifurable
* nemo gym integration with grpo wip
* mostly working
* cleanup
* simplify
* update docs
* nemo gym support wip
* cleanup
* chore: lint
* address PR review and add more tests
* chore: lint
* post merge lora fixes for CI (#3536) [skip ci]
* post merge lora fixes for CI
* handle lora kernel auto-enable for moe without grouped_mm
* prefer not to import torch in schema validation
* address pr comments, add timeout, add tests
* roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions
* remove typo
* update qwen3.5 moe 35b-a3b yaml for 5090
* more bug fixes
* fix tests to match updated trainer
* don't use fa2 for hooks test
* reset plugins on the instance
* retry download
* fix references to renamed axolotl_cfg property on trainer
* Fix ref to trainer cfg
* fix: robust handling of race condition on patching check (#3543) [skip ci]
* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip
* fixes
* more fixeS
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address corereview feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
* fix for ebft sync and update docs
* make trainer loss patch check a solo test
---------
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* EBFT wip
* fixes
* more fixeS
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address corereview feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
* roundup_power2_divisions not needed with newer pytorch versions
* remove typo
* update qwen3.5 moe 35b-a3b yaml for 5090
* more bug fixes
* fix tests to match updated trainer
* don't use fa2 for hooks test
* reset plugins on the instance
* retry download
* fix references to renamed axolotl_cfg property on trainer
* Fix ref to trainer cfg
* feat: LoRA kernel support for bias, dropout, dora, embeddings
* chore: lint
* chore: lint
* address PR feedback, add regression tests, add fsdp2 tests for lora kernels
* update tests for new sigs
* update tests now that bias and dropout are supported
* fix token state json and mistral tokenizer issue
* centralize constants
* forgot to commit constants file
* Fix weakref in pickling relora state dict
* make curl a bit quieter so it doesn't log 2K lines
* fix path traversal for olmoe test
* more test fixes that weren't flagged previously
* chore: lint
* skip tests that fail b/c of OutOfResources
* scattermoe as slow tests
* update fbgemm-genai for torch 2.10
Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:
- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py
Closesaxolotl-ai-cloud/axolotl#3331
* optimize moe + lora
* more scattermoe optims
* selective dequant
* add correctness unit tests and benchmarks for scattermoe + lora
* handle base+lora split kernel for older moe models
* chore: lint
* fix casting for H200 and B200
* register pressure estimation and pruning for h200/b200
* use soft limit for pruning
* qkv patch for qwen3.5moe
* support text_model for qwen3.5 moe
* nesting of qwen3
* use udpated cce with zero3 support
* Fix decomposed backward for QKV and O projections
eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
* use custom triton kernels for entropy from logits and selective softmax
* PR comments fixes
* fix out of bounds, include tests, include benchmarks
* chore: lint
* async grpo support
* implement data producer
* use fast async
* handle call to create data producer
* fix liger kernel setup
* fix replay buffer
* chore: lint
* make gpus go brrr
* chore: lint
* inplace div_, unwrap model for logits in bf16
* fuse selective softmax and empty cuda cache on each scoring step
* remove waiting for synch time and fix race
* make fp8 work and allow lora kernels w rl
* grpo with lora vllm sync and fixes for sharded distributed
* update docs
* more patches so it works against trl main
* address PR feedback for corerabbit
* fix: replace shell=True subprocess with argument list in modal CLI
Using shell=True with a formatted string containing docker_image
(a user-controlled value) is a command injection risk (Bandit B602).
Replace with an argument list, which passes args directly to the
process without shell interpretation, removing the nosec annotation.
* fix: add nosec annotation to suppress bandit B603/B607 warnings
Removing shell=True (B602) surfaces B603 (subprocess without shell)
and B607 (partial executable path for 'docker'). Use bare # nosec
to suppress both, consistent with other nosec usages in the codebase.
* consolidate behavioud of routing in scattermoe kernels
* collect telemetry on best chosen autotuned kernel
* properly collect data
* Fix property name and get smem too
* handle issues raised by coderabbit
* add tests for parity before refactoring
* fix: explicit set workflow permission and move secrets to necessary
steps only
* fix: comment
* fix: more permission restrict
* chore: add read for pypi
* docs: fix codestyle placeholders in CONTRIBUTING.md
Replace unresolved {codestyle} and {URLofCodestyle} template
variables with Ruff, the project's actual linter/formatter
as configured in .pre-commit-config.yaml.
* fix: replace bare except clauses with specific exception types
- quantization.py: use except ImportError for optional torchao imports
(consistent with line 48 which already uses ImportError correctly)
- cli/config.py: use except (RuntimeError, AssertionError) for CUDA
device property query
Prevents masking unrelated errors like KeyboardInterrupt or SystemExit.
* test: add unit tests for convert.py JSON/JSONL utilities
Cover FileReader, FileWriter, StdoutWriter, JsonParser,
JsonlSerializer, and JsonToJsonlConverter with 8 test cases
including roundtrip and edge case (empty list) scenarios.
Previously this module had zero test coverage.
* fix: address CodeRabbit review feedback
- quantization.py: catch (ImportError, RuntimeError) for optional
torchao imports; CUDA wheel/GPU mismatches raise RuntimeError,
not ImportError
- convert.py: remove unused output_file_path parameter from
JsonToJsonlConverter.convert() — FileWriter already holds the
output path from construction
- tests/test_convert.py: update call site to match new signature
The non-root user approach had multiple issues with RunPod
compatibility, sudo PATH handling, and tmux in exec sessions.
Restoring root as the default user for now.
* update setuptools so trl can be installed from main for nightlies
* run the nightly in the PR CI on change
* use range request, don't use cu129 in CI since it's not supported with AO
* run multigpu ci if CCE install script changes
* install flash-linear-attention
* handle prequant weights for fsdp2 and ensure loss is not zero
* fix type for cu_seqlen, uninstall causal_conv1d
* chore: lint
* uv pip uninstall doesn't need confirmation
* upgrade transformers==5.3.0 trl==0.29.0 kernels
* use latest deepspeed fixes
* use corect image for cleanup
* fix test outputs for tokenizer fixes upstream
* fix import:
* keep trl at 0.28.0
* handle updated API
* use latest trl since 0.28.0 doesn't work with latest transformers
* use trl experimental for pad to length
* monkeypatch trl with ORPOTrainer so liger doesn't croak
* upgrade accelerate
* more fixes
* move patch for orpotrainer
* load the imports later
* remove use_logits_to_keep
* fix loss_type arg as a list
* fetch hf cache from s3
* just manually download the missing model for now
* lint for pre-commit update
* a few more missing models on disk
* fix: loss_type internally now list
* fix: remove deprecated code and raise deprecate
* fix: remove unneeded blocklist
* fix: remove reliance on transformers api to find package available
* chore: refactor shim for less sideeffect
* fix: silent trl experimental warning
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* extend pytest-sdist timeout to 30 min for slow/flaky tests
* Also preload the cdn cache so it doesn't get stampeded
* fix yaml syntax
* missing fields
* can't pipe to dev/null
* Fix nightlies and add 2.10.0 to multi-gpu suite
* mxfp4 axo
* import lint
* test for qat mxfp4
* config for mxfp4
* add qat:
* pass base config
* MXFakeQuantizeConfig
* lint
* tune config so it fits in 32GB VRAM
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* Fix fsdp2 sharding. Fix validation of ao version for lr groups
* remove validation since axolotl requires ao>0.13.0 already
* Move fully_shard of entire module for lora_embedding_A/B out of loop
* chore: lint
---------
Co-authored-by: bekk02 <ID+bekk02@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* chore: rename without period
* feat: add glm45 air
* feat: add doc on expert quantization
* feat: update base readme with new changes
* chore: cleanup
* chore: cleanup
* chore: cleanup
* fix: disable quantize_moe_expert on merge per comment
* chore: add kernel info to optimizations doc
* fix: run deduplication before saving dataset during preprocessing
Move deduplicate_and_log_datasets call before save_preprocessed_dataset
in both SFT and RL data loading pipelines. This ensures the saved
preprocessed dataset is already de-duplicated, so subsequent loads
from cache don't contain duplicates.
Fixes#2719
* fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup
- Add dataset_exact_deduplication to the hash string in
generate_dataset_hash_from_config so cached datasets are invalidated
when the dedup setting changes.
- Log a warning when skip_prepare_dataset=True and
dataset_exact_deduplication=True, since dedup will be silently
skipped in that configuration (both SFT and RL paths).
* fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting
- Add config validator (check_deduplication_with_skip_prepare) that raises
ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True
- Replace runtime warnings in sft.py/rl.py with the validator check
- Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of
axolotl.loaders.load_tokenizer to properly mock the imported reference
- Fix ruff lint (remove unused imports) and formatting issues
* refactor: inline deduplicate function per review feedback
* fix test fixture, lint
---------
Co-authored-by: ManasVardhan <manasvardhan@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler()
* nit: raise if self.optimizer is also unset
* optimizer properly optional in create_scheduler()
* Add test cases to verify that the problem exists in the underlying
* Update the handle_long_sequences function to correctly use Map instead of filter for the truncation strategy. Also remove the minimal length filtering from the truncate_long_samples function, and run it separately and before.
* fix: refactor and add test truncate for non-input id fields
* fix: refactor long seq handling fn
* fix: refactor duplicate fn and simplify route
* add additional tests and make them work on mac
* handle logging exception on empty datasets
---------
Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* scattermoe lora support
* fsdp, bf16, dim fixes
* expert weights aren't needed in save for bwd since they are frozen
* use sonicmoe optim options
* update save model from upstream
* fixes per code review feedback and add tests
* revert removal of CP fix
* misc fixes
* feat: support dot-notation CLI args for nested config options
Add support for overriding nested config fields (like TRL config) via
CLI using dot-notation, e.g.:
axolotl train grpo.yaml --trl.vllm-server-host=10.0.0.1 --trl.beta=0.1
Changes:
- args.py: Detect BaseModel subclass fields and generate dot-notation
CLI options (--parent.child) that map to double-underscore kwargs
(parent__child). Also fix _strip_optional_type for Python 3.10+
union syntax (X | None).
- config.py: Handle double-underscore kwargs in load_cfg by setting
nested dict values on the config.
- Add tests for nested option handling.
Fixes#2702
* Address CodeRabbit review: fix string parent bug, add type hints and docstring
Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
* Add type coercion for CLI kwargs and fix pre-commit issues
- Add _coerce_value() for YAML-style type inference on string CLI args
- When existing config value has a type (int/float/bool), cast to match
- When no existing value, infer type from string (true/false, ints, floats, null)
- Apply coercion to both flat and nested (dot-notation) kwargs
- Fix unused pytest import (pre-commit/ruff)
- Update tests to pass string values (matching real CLI behavior)
- Add dedicated TestCoerceValue test class
Addresses maintainer feedback on type casting for nested kwargs.
---------
Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
* upgrade transformers to 5.1.0 and torchao to 0.16.0
* upgrade trl for parity
* handle trl api changes
* orpo doesn't have max_prompt_len to check anymore
* cpoconfig doesn't take max_prompt_length and fix cpu offload
* slow fsdp1 test
* triton min 3.4.0 and liger to 0.7.0
* use transformers main for now for zero3 fix
* handle group_by_length change
* fix changes upstream
* mark skip flaky test
* use transformers latest release 5.2.0
* fix: redact trackio and data_files
* fix: add new orgs to whitelist
* feat: add run id to logs for users to easily share
* fix: update to add more metrics
* fix: add missed experiment tracker
* chore: formatting in main
* feat: add sageattention
* feat: call path on pre model load
* fix: patch to use register to correct var
* fix: add strict check import at start
* chore: fix comments
* chore: refactor
* feat: add capability check
* fix: missed underscore
* fix: let sageattention use FA backend in transformers
* feat: update sage attention for attention mask and position ids
* feat: allow sample packing but add warning without packing
* fix: loss hitting 0 with packing and attention mask note
* feat: downcast embeds if sage attention too
* feat: add config validation
* feat: add attention docs
* chore: docs
* Prepare for transformers v5 upgrade
* fix hf cli
* update for hf hub changes
* fix tokenizer apply_chat_template args
* remap include_tokens_per_second
* fix tps
* handle migration for warmup
* use latest hf hub
* Fix scan -> ls
* fix import
* fix for renaming of mistral common tokenizer -> backend
* update for fixed tokenziation for llama
* Skip phi35 tests for now
* remove mistral patch fixed upstream in huggingface/transformers#41439
* use namespacing for patch
* don't rely on sdist for e2e tests for now
* run modal ci without waiting too
* Fix dep for ci
* fix imports
* Fix fp8 check
* fsdp2 fixes
* fix version handling
* update fsdp version tests for new v5 behavior
* Fail multigpu tests after 3 failures
* skip known v5 broken tests for now and cleanup
* bump deps
* unmark skipped test
* re-enable test_fsdp_qlora_prequant_packed test
* increase multigpu ci timeout
* skip broken gemma3 test
* reduce timout back to original 120min now that the hanging test is skipped
* fix for un-necessary collator for pretraining with bsz=1
* fix: safe_serialization deprecated in transformers v5 rc01 (#3318)
* torch_dtype deprecated
* load model in float32 for consistency with tests
* revert some test fixtures back
* use hf cache ls instead of scan
* don't strip fsdp_version
more fdsp_Version fixes for v5
fix version in fsdp_config
fix aliasing
fix fsdp_version check
check fsdp_version is 2 in both places
* Transformers v5 rc2 (#3347)
* bump dep
* use latest fbgemm, grab model config as part of fixture, un-skip test
* import AutoConfig
* don't need more problematic autoconfig when specifying config.json manually
* add fixtures for argilla ultrafeedback datasets
* download phi4-reasoning
* fix arg
* update tests for phi fast tokenizer changes
* use explicit model types for gemma3
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix: AutoModelForVision2Seq -> AutoModelForImageTextToText
* chore: remove duplicate
* fix: attempt fix gemma3 text mode
* chore: lint
* ga release of v5
* need property setter for name_or_path for mistral tokenizer
* vllm not compatible with transformers v5
* setter for chat_template w mistral too
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
* upgrade transformers to 4.57.5
* explicitly set versions for fbgemm-gpu
* handle index url for cuda version
* explicitly set cu version for fbgemm deps, skip for 130
* cu suffix not needed on version if using whl subpath
* install xformers in the base docker image
* install numba and numpy first
* set CUDA_HOME for xformers install
* Set cuda home env
* don't install xformers by default on aarch64/arm64
* fix syntax for secrets in gha yaml
* setup env for uv too
* arm64 for base uv too
* don't build causal-conv1d or mamba for arm64 and use arm64 wheels
* fix dockerfile syntax
* fix shell syntax
* upgrade dependencies
* don't use reset sessions
* downgrade transformers, upgrade other deps
* upgrade bnb to 0.49.0
* restore s3 cache
* explicit use local files w hub
* decompress and strip top level dir
* use 2 levels for strip components
* try to preserve permissions for symlinks
* use updated tar
* fix#3293 for distributed
* downgrade bnb
* fast fail after 4
* fix total tokens device
* patch accelerate CP/SP (#3309)
---------
Co-authored-by: salman <salman.mohammadi@outlook.com>
* build examples readmes with quarto
* chore: formatting
* feat: dynamic build docs
* feat: add more model guides
* chore: format
* fix: collapse sidebar completely to have space for model guides
* fix: security protection for generated qmd
* fix: adjust collapse level, add new models, update links
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* compute loss only if training
* save total_tokens for checkpiont
* check if string
* refactor total_tokens/ num_tokens
* refactor 2
* rplc trainable_step/trian_per_sec_per_gpu
* lint + log trainable/tokens
* consolidate it in the callback.
* test for total_tokes aftr remuse
* check if tokenstate exist after ckpt
---------
Co-authored-by: Ved <ved.work2024@gmail.com>
* feature: raise on long sequence drop
It is sometimes not desired that sequences are silently dropped from the dataset, especially when the dataset has been carefully crafted and pre-fitted for the training context. This would then suggest that an error occurred somewhere in the process. This feature adds a third value for excess_length_strategy called 'raise', which will raise a ValueError if a sequence is encountered that is too long and would have normally been dropped/truncated.
* tests: add excess_length_strategy tests
* doc: updated return value description for drop_long_seq_in_dataset
* add @enable_hf_offline
* fixed cfg modified after validate_config called
* hf offline fix
* fix tqdm desc when raise is used
* test: added test for non-batched case
* accidental code change revert
* test: use pytest.raises
* test: simplified drop_seq_len tests
* test: moved excess_length_strat test to test_data.py
---------
Co-authored-by: salman <salman.mohammadi@outlook.com>
* METRIC_PRECISION-> 8
* use ndigits and move env getter to top of log function
---------
Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* When training of function calls, "tools" elements of a dataset can contain same parameter name but with different types. Datasets fails to load such training set. This fix allows "parameters" element of function call to be string( by running "json.dumps" in preparation of training data set). The _get_tools function will iterate over tool definitions, if "parameters" element is dict, it will keep that way, if it is a string, it will be converted to dict by invoking "json.loads" on string value.
* feat: add doc on tool parameters json loading
* feat: add tests for parameters json string
---------
Co-authored-by: ezlotnik <eduard_zlotnik@intuit.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* upgrade numpy to 2.3.4
* bump contribs for numpy
* fix vllm versions
* bump numba
* make sure psutil is installed
* add psutil to cicd dockerfile jinja
* lower dep versions of numba + numpy for vllm
* bump datasets version
* resolve pydantic conflict too
* build cuda 13.0.0 base image with 2.9.0
* upgrade causal-conv1d
* 1.5.4 not in pypi yet
* pin to 1.3.0
* use github release instead of pypi
* split the logic for incompatible packages
* fix bash in dockerfile
* fix: force train split for json,csv,txt for test_datasets
* feat(doc): add info on mixing datasets for VLM
* feat(doc): max memory
* fix(doc): clarify lr groups
* fix: add info on vision not being dropped
* feat: add qwen3-vl to multimodal docs
* fix: add moe blocks to arch list
* feat(doc): improve mistral docs
* chore: add helpful link [skip-e2e]
* fix: add vram usage for mistral small
* Update link in docs/faq.qmd
Co-authored-by: salman <salman.mohammadi@outlook.com>
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
* Fix trainer dataloader handling in src/axolotl/core/trainers/base.py
* update comment to reflect torch version
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* Add chat_template.argilla_chat support for DPO datasets
Creates a new chat_template.argilla_chat prompt strategy for handling
DPO datasets where chosen/rejected fields contain full conversations
(messages + final response), following the pattern of chatml.argilla_chat
and llama3.argilla_chat.
- Add argilla_chat() function to chat_template.py
- Add chat_template.argilla_chat to RLHF documentation
- Add test coverage for argilla_chat with multiple tokenizers
Dataset format:
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
* Fix chat_template.argilla_chat return value contract and add docstring
- Return (transform_fn, dataset_kwargs) tuple instead of bare transform_fn
- Add remove_columns specification for field_chosen and field_rejected
- Add comprehensive docstring with Args/Returns sections
- Update tests to unpack tuple return value
Addresses PR feedback to maintain consistency with chat_template.default()
and properly specify columns to remove after dataset transformation.
* Update tests/prompt_strategies/test_dpo_chat_templates.py
Co-authored-by: Wing Lian <wing.lian@gmail.com>
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* fix: transformers deprecate load_in_Xbit in model_kwargs
* fix: test to read from quantization_config kwarg
* fix: test
* fix: access
* fix: test weirdly entering incorrect config
- Fix _loss_function attribute not found on base model with PEFT
- Fix mismatched attribute name (loss_function vs _loss_function)
- Set _loss_function on unwrapped base model for PEFT
- Enable previously skipped test_llama_lora_kd test
- Add test config fixes for LoRA kernel compatibility
Fixes https://github.com/axolotl-ai-cloud/axolotl/issues/3206
* make sure to use ray prepare for dataloader fixes
* ray tests use 2.7.0+
* don't call init_distributed w ray and deepspeed
* handle dict deepspeed config
* better handling of dict deepspeed config
* use json.dumps
* guard to_dict
* wrap import for optional ray
* upgrade transformers to 4.57.0
* remove deprecated autoawq and use latest peft
* remove autoawq from setuptools script
* fix imports
* make sure torchvision is installed
* remove support for BetterTransformer
* skip fsdp_qlora_prequant test
* more robust error reporting
* pass max_prompt_len to training args as max_prompt_length
* Update rl.py
* refactor
* format
* fix: default for max_prompt_length
* fix: defaults for trainer
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* feat: add hunyuan cce support
* feat: update cce docs
* feat: add multipack support for granite and hunyuan
* feat: add hunyuan docs and example config
* feat: update readme instructions to include CCE installation
* fix: chat template log appearing despite tokenizer already having template
* feat: add vram usage
* fix: remove duplicate cce install
* fix: use latest commit of PR in case rebased/pushed
* Revert "fix: use latest commit of PR in case rebased/pushed"
This reverts commit 8b60aa00de.
* feat: update doc as upstream merged
* default true
* force e2e
* causal trainer only
* fix eval loggin [skip-ci]
* revert setup.py
* force tests
* guarding
* guarding
* fix test case
* use evaluate [skip-e2e]
* use evaluate [skip-e2e]
* kick off ci
* fixing
* reverting
* feat: upgrade transformers to v4.56
* fix handling of CP/SP now that position_ids are default even for unpacked sequences
* feat: monkeypatch list_repo_templates
* fix: apply patch for tests only
* see if updated main works at least
* fix: update to patch release and remove monkeypatch
* remove fsdp2 eval patch
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: add center_rewards_coefficient for reward modeling
- Add center_rewards_coefficient parameter to Pydantic schema with paper reference
- Pass parameter through base builder and causal builder to training args
- Add documentation section with usage examples and theoretical background
- Enable parameter in reward modeling example configs with recommended value
- Enables reward centering for improved training stability in RLHF workflows
Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244)
to incentivize mean-zero reward outputs without post-training normalization.
* Update description
* test: add unit tests for center_rewards_coefficient integration
* Update src/axolotl/core/builders/base.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* Update docs/reward_modelling.qmd
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* Update docs/reward_modelling.qmd
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* reference to TRL documentation.
* add new reward model configuration for qwen3 with comprehensive parameters
* Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments.
* Refactor reward modeling documentation to consolidate information on center_rewards_coefficient
* Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup.
* linting
* nit
* Apply suggestions from code review
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* lint
---------
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
* improve fsdp shard merging
* improve logging
* update information on merging and inferencing GPT-OSS
* cleanup readme
* automate cleanup of FSDP prefix
* import GRPO only if necessary
* only modify config.json on rank0
* merge final checkpoint at end of training
* prevent circular import
* Fix saving for sharded state dict
* devx, move merged to output dir
* move import back to top
* Fix stuck merge
* fix conditionals from pr feedback and add test
* fix to not use batch feature indexing
* more vlm fixes
* use AutoModelForImageTextToText
* add example yaml and need num2words for chat template
* improve handling of adding image tokens to conversation
* add lfm2-vl support
* update the lfm readme
* fix markdown and add rtol for loss checks
* feat: add smolvlm2 processing strat
* fix: check for causal-conv1d in lfm models
* feat: add docs for lfm2
* feat: add new models and tips to docs
* feat: add smolvlm2 docs and remove extra dep
* chore: update docs
* feat: add video instructions
* chore: cleanup
* chore: comments
* fix: typo
* feat: add usage stats
* chore: refactor
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* use exec instead of subprocess to make ctrl+c nicer for cli
* change var name to use_exec
* simplify to bool
* flush std*
* patch subprocess as mock in test
* fix tests
* more test fixes
* feat(doc): add links to new features on README
* fix merge error
* remove blurb about older FSDP2 integration
* update blog link
* chore: update cce commit
* feat: update model support into readme
* Update README.md
Co-authored-by: salman <salman.mohammadi@outlook.com>
* chore: lint num spaces
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
* use nanmena for loss aggregation (CP fix)
* use regular asserts
* small changes to make tests isolate
* combining evaluation_loop patches
* fix
* delete unused
* fix check
* slurm example and make preprocess play nicely
* start slurm if it init file exists
* remove incorrect comment
* feat: add slurm docs
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* fix for parallelism config from trainer
* fix handling of parallelism_config w accelerate
* add todo for removal
* update to latest axolotl-contribs-mit for optimizer fix too
* synchronize training after checkpoint save
* dir spelling
* use latest accelerate main
* fix to not use partial state parallelism_config
* more fixeS
* use most recent accelerate fix
* fix cpu_ram_efficient_loading to meta devices from rank 0 to prevent CPU RAM oom
* improve handling of broadcasting fsdp2 state dict
* support for openai chat template with thinking key as the reasoning trace
* address PR feedback
* refactor to remove dependency on PartialState for parallelism config
* bump accelerate, gptoss fixes
* limit meta fixes to fsdp2 for now
* fixes for gpt oss
* fixup examples, don't use cpu-ram-efficient-loading for now
* remove problematic barrier
* patch parallelism config
* reorder comparison
* device mesh fixes
* make pure CP work
* lint
* add kernels for gpt oss models
* add support for gpt-oss
* typo incorrect package
* fix: layout for configs and added wandb/epochs
* add gptoss example w offload and set moe leaf for z3
* add support for Mxfp4Config from yaml
* update yaml to use official model
* fix lora and don't allow triton to go above 3.3.1
* fix lr and tweak vram use
* fix range for triton since pinned wasn't compatible with toch 2.6.0
* update cce with gpt oss patches
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* Add support for Dion optimizer
* dion training kwargs
* fix var names
* no dion 8bit for now
* use updated axolotl-contribs-mit for dion optimizer
* add smoke test for dion optimizer
* add docs
* fix typo during edits
* fix test to not remove load in 8bit
* fix: deepcopy lr in RexLR scheduler.
This fixes a problem where when the lr is a scalar tensor, the base_lrs in the get_lr function end up being references to the current learning rate, rather than the correct initial learning rate.
See also related pytorch PR https://github.com/pytorch/pytorch/pull/127190/
* fix: add missing torch.Tensor import
* jagged lr restart scheudler
var name fix
make sure to create scheduler first
* wire things together
* more fixes
* fix for nesting scheduler and first anneal phase
* no need for relora trainer anymore since we've generalized the relora scheduler
* remove redundant relora scheduler and lint
* update relora e2e test for updated params
* need restart steps for relora test
* update quarto docs for dropped relora trainer
* update example yaml
* drop verbose arg
* min lr scale support for jagged lr
* don't let min_lr be nonetype
* cleanup args
* feat(doc): add vastai link
* feat: add cloud providers to readme for more visibility
* add prime intellect, remove Modal as sponsor
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* make TiledMLP work with FSDP
* cleanup/gc at start of train to prevent large VRAM spike
* chore: lint
* generic function for non-deepspeed training
* unify patch to fix imports
* update readme for ALST and add examples
* make deepspeed attribute on params check more robust
* update with new info from PR review
* we don't need to call check_dataset_labels when skip_prepare_dataset is set
* Fix actual bug and revert prior fix
* warn and early return instead of raising an error
* use error
* Revert "fix deprecate deepspeed stage3_gather_16bit_weights_on_model_save arg…"
This reverts commit e207762928.
* don't revert the values that don't use 'auto'
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* static autocomplete script for axolotl cli
* use list of commands that should autocomplete yaml files
* make sure to chmod the autocomplete script as executable
* shellcheck and fix autocompletion of directory/sub-dirs
* more shellcheck fixes
* feat: add gemma3n cce
* feat: add sample config
* feat: add gemma3n multimodal mode
* feat: add audio example
* feat: support audio and return pixel values in collator
* feat: support unmask only assistant region (gemma3n for now)
* feat(doc): add notes for audio loading
* feat: add audio support for gemma3n
* feat: update examples
* feat: add gemma3n to the docs
* fix: add link at top
* feat(doc): clarify additional requirements
* fix: mllama missing aspect ratio
* fix: mllama need attention fixes for fa2
* Partially Revert "fix: mllama need attention fixes for fa2"
This reverts commit a0bfdd1777.
* fix: disable FA2 for mllama in vision mode
* feat: update configs to use proper attention
* fix: support other vision features
* feat(doc): clarify requirements for gemma3n
* make pad_to_sequence_len default to the same value as sample_packing
* remove duplicate validation
* fix test
* update description meta
Co-authored-by: NanoCode012 <nano@axolotl.ai>
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* limit num_proc when saving datasets to disk
* enforce at least 1 in case it rounds down to 0, and sane divisor is at least 8 rows per worker to save
* update fixtures with dataset processes since that should never be NoneType
* improve reusability for tests
* make the initial call to tokenizer.pad not spam the console
* add guard from feedback
* make another common console output less verbose
* more logging fixes
* Added a feature to save prepared dataset in specified shards, removed limiter on multiprocessing during tokenization, and a bug fix of qwen tokenizer
* removed limiters and fixed config variable name
* black lint
* chore: lint
* feat: update handling of dataset_processes
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* Apply generic fused liger ce for unknown models
* fix deepseek liger modeling
* generic cce and config tiled mlp to use original mlp and auto detect compute params
* fix weight and lint
* update warnings
* address PR feedback
* use lookup for model class prefixes
* revert inadvertent change to flash attn verison
* remove un-needed pylint annotations
* fix import
* checkpoint model on first step callback
* remove debug
* add test cases; update existing tests not to save on first step
* move test out of solo
* delete
* default to False
* typo
* support for deepspeed autotup
* bump to latest deepspeed that supports deepcompile too
* add deepcompile support too
* fix total steps calculation for TP
* setup fixture for tp
* update ds config to ensure weights are gathered for checkpoint
* fix duplicate validation names
* chore: lint
* use cuda streams for activation offloading
* use torch native ops
* update cfg schema for streams
* fix literal constructor for set
* use context for training step so it doesn't affect evals
* disable streams
* auto gc on eval steps
* use activation_offloading config arg
* add docs for gradient checkpointing
* handle validation for gc/ao
* use cuda streams for act offloading
* add more validation for AC w/o GC
* fix docs
* move activation_offloading lower in definition so it doesn't break args/kwargs
* fix kd due to import order
* upgrade peft to 0.16.0
* upgrade datasets to 4.0.0
* refactor dupes from merge/rebase
* fix check for fsdp1 + sharded_state_dict
* use full state dict for ci
* upgrade trl==0.19.1
* add vllm for tests for grpo
* fixes to work with latest trl
* need data_parallel_size config too
* support for vllm_mode for server / colocate
* vllm settings for colocate
* relax vllm version
* bump min hf hub for latest vllm support
* add hints on string literal for vllm mode
* use latest transformers 4.53.2
* tweak acceptable loss on flaky test_ds_zero3_packed test
* don't run flaky vllm/grpo tests for now
* FSDP2 args migration implementation
This commit implements the migration to FSDP2 arguments including:
- FSDP2 support with LoRA training
- DPO integration with FSDP2
- Model loading fixes and refactoring
- CPU offloading and PEFT handling
- Test updates and CI improvements
- Bug fixes for dtype errors and various edge cases
* tiled_mlp supports single gpu
* use checkpoint offloading for arctic training
* patch torch checkpoint too
* support for single gpu zero3
* add linkback to where it was copied from
* fix: do not add training and training_detail block by default
* fixed: magistral docs
* fix: address pad adding new fields and use built-in from_openai
* feat: try enable multiprocessing
* fix: check for keys before deleting attn_mask
* feat: add mistral pad test
* feat: add tool calling test
* feat: add devstral tokenizer tests
* fix: comma format
* chore: remove unused support_preprocessing as tokenizer is pickable now
* chore: update magistral doc
* feat: add devstral readme and example
* chore: refactor error handling
* densemixer plugin integration
* update readme with usage docs
* automatically find new integrations that aren't explicitly defined
* make sure to import os
* update transformers to 4.53.0
* remove attention_mask from signature columns if using packing
* remove attention_mask column from dataloader
* update signature of flash attn forward for ring attn patch
* fix FSDP
* patch ring-flash-attn with upstream signature fix
* fix patch indentation level
* fix the patch
* add batch flattening smoke test with loss check that works in older transformers
* fix patch
* don't drop attention mask for flex
* more fixes
* patch create_causal_mask for packing w flex
* global torch manual_seed fixture
* tweak loss checks
* fix patch and use single batch for flex
* don't need to reload
* fix causal mask patch
* use transformers patch releasE
* make sure env var is string
* make sure to drop attention mask for flex w packing for latest transformers patch release
* tweak loss
* guard on signature columns before removing attention mask
* bump loss
* set remove isn't chainable
* skip slow mistral test in 2.5.1
* fix: let users know to not call preprocess for vision mode
* fix: improve ux for pretraining dataset and skip prepare ds
* feat: add info to doc
* Update src/axolotl/cli/preprocess.py following comment
Co-authored-by: salman <salman.mohammadi@outlook.com>
---------
Co-authored-by: salman <salman.mohammadi@outlook.com>
* respect shuffle_merged_datasets for single dataset too
* update inline comment for behavior
Co-authored-by: NanoCode012 <nano@axolotl.ai>
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* feat: update handling for mistraltokenizer decode
* fix: update mistral common package version
* fix: to use correct release
* fix triton path
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* upgrade to flash-attn 2.8.0.post2
* use cu126 with torch 2.6
* seems vllm 0.8.5.post1 not compatible with cuda12.6.3 and torch 2.6
* cu126 + torch 2.6 as the default
* use cu126 for multigpu w torch 2.6 too
* drop vllm for now from ci for now
* ignore generation/endgeneration tags
Axolotl handles calculating the mask for assistant turns on its own, and as such these tags are not needed, however currently the analyzer does not recognize them at all and throws an error.
* feat: add phi4 tokenizer test and unblock gemma2
* fix: improve template
* chore: refactor
* chore: lint
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* kd fixes
* fix collator setup
* fix input args
* better handling to drop string fields for kd with raw dataset
* kd trainer has kd temp as part of the init
* drop top_k before softmax
* simplfy and remove zscore
* WIP chunked KD loss with autograd wrapper
* more fixes and liger-type chunked loss
* collator cls for plugins
* remove debugging
* additional plugin collator kwargs, don't scale up kd loss by t^2
* don't need temp arg to distill method
* online kd wip
* add close to comment block
* suport sampling params/max new tokens
* handle when no custom collator is used in plugins
* logsumexp trick:
* fix check
* shift off the first empty token
* fix length of padding
* use max not min
* temp scale kd loss at end
* support for dynamic plugin training args mixins and symmetric kl
* chore: lint
* fix trainer callback base class
* Fix decay
* accept compressed responses for smaller wire payload
* post-rebase lint
* more KD updates
* increase hyperparams_count for gradients for added normalize_topk
* fix to remove attention_mask
* rename vars for consistency
* fix rebase issues
* default to dropping last batch in multipack batch sampler
* improve handling of train len
* init collator_cls_and_kwargs
* explicit drop_last=False when checking for multipack completeness
* use separate v2 loader for kd
* fix kd tests to use subprocess so it picks up kd training args
* default value for kd_beta arg
* use updated dataset for ci
* longer timeout for e2e
* fix: do not pre-patch self attention if lora dropout non-zero
* fix: add test to check patch not applied
* fix: test
* fix: test config check
* fix where we check so that tests don't break
* fix: test
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: add fsdp config for magistral
* fix: add mllama self attention handling for lora kernels
* fix: no eval if val_set_size 0 despite having test_datasets
* fix: add note for cce for vlm in newer model
* build base images for torch 2.7.1
* fix: update base docker to use torch 2.7.1
* fix: update doc for main base to use 2.7.1
* make sure to install fa2 in base uv too
* use no build isolation for uv+flashattn
* install psutil also for fa2
* longer timeout for flash attn build
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* Update batching.py: fix the bug of position ids padding
if position ids is padded with a long sequence of zeros, it will cause flash attention to crash
* use alternate calculation for padding position_ids with a range
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* add uv tooling for e2e gpu tests
* fixes from PR feedback
* simplify check
* fix env var
* make sure to use uv for other install
* use raw_dockerfile_image
* Fix import
* fix args to experimental dockerfile image call
* use updated modal versions
* remove unused field for chat_template.default
"messages" field present in final dataset causes issues with DPO
training otherwise
* lint and fix tests for new return value
* remove unused field for chat_template.default
"messages" field present in final dataset causes issues with DPO
training otherwise
lint and fix tests for new return value
fix for updated expected fields for dpo
remove unused field for chat_template.default
"messages" field present in final dataset causes issues with DPO
training otherwise
fix test still expecting "messages" field
* chore: lint
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* bump hf deps
* upgrade liger-kernel too
* install cce from fork for transformers fix
* fix reference to vocab size in gemma3 patch
* use padding_idx instead of pad_token_id
* remove fixed gemma3 patch
* use updated cce fork
* fix local mllama cce patches w docstring
* add test for multipack with trainer setup and fix trainer for trainer refactor upstream
* bump modal version
* guard for iterable datasetS
* mllama model arch layout changed in latest transformers
* fix batch sampler with drop_last
* fix: address upstream vlm changes for lora
* fix: update references to old lora target path
* fix: remove mllama fa2 patch due to upstream fix
* fix: lora kernel patch path for multimodal models
* fix: removed mllama from quarto
* run test for came optim on 2.6.0+
* fix fsdp2 patch and remove deprecated patch
* make sure to set sequence_parallel_degree for grpo
* Add SP test for GRPO
* add sp to grpo config for trainer
* use reward_funcs as kwarg to grpo trainer
* fix the comprehension for reward funcs
* reward funcs already passed in as args
* init sp_group right before training
* fix check for adding models to SP context
* make sure to pass args to super
* upgrade deepspeed
* use updated trl and add reasoning flags for vllm
* patch the worker
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* chore: update pre-commit hooks
* trigger linter when pre commit hooks are updated
* fix type checks from upgraded pre-commit
---------
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix: increase log level for root loggers and axolotl's
* fix: BasePlugin using wrong logger
* fix: update logger to take name from module
* feat: change logger class to AxolotlLogger to filter non-axolotl infos or below
* fix: change behavior to not disable existing loggers
* fix: update logging to respect correct env
* chore: fix comment
* fix: suppress accelerate log to LOG_LEVEL if not set
---------
Co-authored-by: salman <salman.mohammadi@outlook.com>
* feat: add num_proc and load from cache for rl mapping
* fix: refactor sft and rl trainer to set same base args
* feat: add report_to to set run name
* fix: consolidate handling of fp16, bf16, tf32 kwarg
* chore: consolidate eval_strat, loraplus, lr sched, max_length
* fix: deprecate old types
* fix: adding missing Any
* fix: max_steps incorrectly set
* fix: remove unnecessary datacollator kwarg insert and pop
* fix: update default max_steps
* fix: add missing weight_decay handling
* fix: ignore max_length for grpo
* feat: update CI on trainer_builder
* fix: comments
* improve handling of warmup/logging steps
* use transformers default for logging steps, not None
* fix: remove redundant override
* fix: lint
* feat: allow custom optim for rl methods
* fix: duplicate optim setting
* fix(test): set sequence_parallel_degree default in base cfg
* feat: add handling for seed and SP/ring-attn config
* chore: add back return typing from rebase
* fix(test): use RLType directly to skip needing to validate
* feat: split training builder into sub modules
* fix: remove deprecated clause
* chore: add missing config to doc
* fix: update quarto autodoc
* fix: import path for trainer builder and submodules
* fix: remove redundant configs from rebase mistake
* chore: simplify dynamo check
* fix: optimizer_cls_and_kwargs to be passed into trainer_kwargs
* fix: add missing rex from rebase
* fix: move pop optimizer_cls_and_kwargs
* fix: pop optimizer cls in rl too
* fix: leftover bug from rebase
* fix: update handling of trainer_cls in RL
* fix: address pr feedback
* feat: call hook_pre_create_trainer for rl
* chore: lint
* fix: return notimplemented for ppo
* feat: moved torch compile to base and refactor collator setting
* chore: remove unused importlib.util import
* fix: optimizer cls not being popped
* feat: move epoch setting to base
* fix: catch unhandled custom optimizer
* fix: remove duplicate lora plus setting
* chore: refactor if condition
* chore: refactor set_base_training_args into smaller modules
* fix: address TrainerBuilderBase class variables to instance var
* fix: add handling for beta3 and episilon2
* fix: change to pass dict via arg instead of updating dict
* chore: simplify if condition
* fix: force access to lr & weight decay in case not provided to early error
* fix: remove log sweep
* chore: refactor if condition
* fix: address renamed cfg
* fix: improve handling of cosine hyp
* fix: remove unused params
* chore: refactor
* chore: clarify doc safetensors
* fix: update import path to be unified following comments
* fix: duplicate kwargs passed
* feat: return separate trainer_kwargs
* chore: refactor
* chore: refactor based on comments
* chore: refactor based on comments
* fix: move gpustats callback to base
* chore: create trainer_cls_args first based on comments
* fix: ipo label smoothing passed incorrectly
* feat: add optimizer parity for RL methods with test
* feat: add parity for optimizer in RM/PRM and add test
* fix: remove redundant function override for orpo/cpo batch metrics
* fix: improve handling of dpo_label_smoothing and merge issue
* fix: test fixture returning wrong field
* fix: address avoid direct modify fixture
* chore: minor refactor
* Revert "chore: refactor"
This reverts commit 99c8859eb0.
* feat: rename trainer_builder to builders
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat(doc): add info on how to use dapo / dr grpo
* chore: add missing config to docs
* fix: missing comment
* fix: add missing scheduler from schema
* chore: refactor lr scheduler docs
* fix: remove log_sweep
* add two checks to handle legacy format interleaved ds
* fix: add warning about multiple image using legacy format
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* don't set peft_config on grpo to prevent double peft wrap
* remove overrides needed to support bug
* fix grpo tests
* require more CPU for multigpu to help with torch compile for vllm
* make setting `adam_beta3` and `adam_epsilon2` work correctly
* update config docs so users know args are specific to CAME optim
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* offload activations to disk instead of CPU RAM
* add prefetch
* Disco :dance:
* include offload_disk in e2e test for AC
* document and make sure to cleanup
* fix annotation to match docs
* fix docs build
* address PR feedback
* update doc and skip brittle grpo test
* fix the path to run the multigpu tests
* increase timeout, use LOC instead of NVL
* typo
* use hf cache from s3 backed cloudfront
* mark grpo as flaky test dues to vllm start
* lean mistral ft tests, remove e2e torch 2.4.1 test
* make sure to pass save_only_model for RL
* more tests to make ci leaner, add cleanup to modal ci
* fix module for import in e2e tests
* use mp spawn to prevent deadlocks with packing
* make sure cleanup shell script is executable when cloned out
* fsdp embeddings should be float32 per comment
* patch peft to not upcast everything
* add tabs back to code check
* fix import
* add configurable option and fix check
* add check for dtypes
* move embeddings test to patch dir
* fix test
* fix comment and logic
* improve readability of multipack sampler
* parallel bin packing
fix error with lambda and pickling
make sure things are in float instead of np.float
* annotations and comments update
* support for configurable group and bin size for sample packing
* fix missing map back to original indices
* feat(doc): add split_thinking docs
* fix: link config.qmd to conversation.qmd for split_thinking example
* update thinking => reasoning_content in messages format
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* repop cache
* pre-cache as a step
* fix the name
* add reason for pytest skipif
* restore pytorch matrix
* remove max-parallel now that we've optimized this a bit
* Adds example for training a TTS model on top of a LLM.
* Update examples/orpheus/finetune.yml
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* Update examples/orpheus/finetune.yml
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* Update README.md to clarify GPU requirements for finetuning Orpheus TTS model
* Update finetune.yml to use the new base model canopylabs/orpheus-3b-0.1-pretrained
* Update finetune.yml and README.md for consistency and clarity
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* only configure logging on cli to play nicely with colab
* allow reloading the config on the fly from a dict
* make sure to use dict for yaml
* reuse existing function for load
* make cli args optional
* mps fix and respect max_steps
* Add: SFTPlugin with llmcompressor
* Update: review comments!
* Add:llmcompressor instalable
* pre commit hooks
* Use: warning over warn
* Revert: TODO's
* Update llmcompressor version to latest
* Apply suggestions from @markurtz
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
* Address review comments from @markurtz
* Add: llcompressor installable
* Rename: sft.yaml to sparse-finetuning.yaml
* Use: absolute import
* Update model config
* Move: LLMCompressorPlugin into it's own submodule
* Add: `llm_compressor` integration documentation
* Rebase and updates!
* Tests, Style, Updates
* Add: .qmd file
* Address Review Comments:
* deleted redundant docs/llm_compressor.qmd
* incorporated feedback in integration README.md
* added llmcompressor integration to docs/custom_integrations.qmd
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
* Add: line about further optimizations using llmcompressor
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
* Apply patch from @winglian
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
* Fix: Test
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
* additional fixes for docker and saving compressed
* split llmcompressor from vllm checks
* Reset session between tests
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
* move decorator to test method instead of class
* make sure to reset the session after each test
* move import of llmcompressor to reset session inside test
---------
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: add eos_tokens and train_on_eot for chat_template EOT parsing
* fix: comments
* chore: add some examples of tokens
* feat: add new potential errors for chat_template to faq
* feat: add examples for EOT handling
* fix: change error to warning for missing EOS
* fix: warning typo
* feat: add tests for eot token handling
* fix: remove broken caplog capture in test
* fix: chattemplate strategy with kd missing eot changes
* Add runpod sls handler
* remove LICENSE and fix README
* chore: lint
* use axolotl cloud image as base and various fixes
* fix: trim allowed cuda versions
* restore dockerfile
* chore: update title
* use axolotl cloud image
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* update trl to 0.17.0
* grpo + vllm no longer supported with 2.5.1 due to vllm constraints
* disable VLLM_USE_V1 for ci
* imporve handle killing off of multiprocessing vllm service
* debug why this doesn't run in CI
* increase vllm wait time
* increase timeout to 5min
* upgrade to vllm 0.8.4
* dump out the vllm log for debugging
* use debug logging
* increase vllm start timeout
* use NVL instead
* disable torch compile cache
* revert some commented checks now that grpo tests are fixed
* increase vllm timeoout back to 5min
* add e2e smoke test for using activation/gradient checkpointing with offload
* disable duplicate code check for the test
* fix relative import
* seq len too small to test this dataset with packing
* Fix checkpoint ptaching for tests
* make sure to validate the config before normalizing so defaults get set
* validation not needed for particular test
* remove duplicate validations
* set qlora correctly
* fix: mention to install pytorch before axolotl
* feat(doc): include instruction to delinearize
* fix: update instruction for delinearize with adapter
* builds for torch==2.7.0
* use xformers==0.0.29.post3
* no vllm support with torch 2.7
* update default, fix conditional
* no xformers for 270
* no vllm on 2.7.0 for multigpu test too
* remove deprecated verbose arg from scheduler
* 2.7.0 tests on cpu
* batch api HF adapter for ring-flash-attn; cleanup and improvements
* update
* adding all batch ring-flash-attn methods via single adapter
* removing pad_to_sequence_len=False for now
* fix
* updating docs to include batch SP
* review comments
* fixes for batch API funcs, simplify
* fixes
* fix
* updates
* add batch_zigzag smoke test
* fixes for delinearization, and make qlora work with fsdp2
* Add back mistakenly removed lm_eval
* typo [skip ci]
* patch evals for torch.compile + fsdp2
* also check torch_compile w fsdp2
* lots of fixes for flex attn with llama4
* fix patch check and patch llama4 too
* attempt to make the patches stick
* use transformers 4.51.2
* update configs and README for llama4
* remove torch.compile for CI test
* cleanup any existing singletons
* set singleton cache to None instead of deleting
* use importlib reload with monkeypatch
* don't worry about transformers version, mark inputs with grads, fix regex
* make sure embeds aren't on cpu
* logging and mem improvements
* vllm version and add to docker, make sure to save processor on conversion
* fix ambiguous tensor bool check
* fix vllm to not use v1, upgrade hf transformers
* fix tests
* make flex_attn_compile_kwargs configurable, since this depends on model params
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
* [ci] make e2e tests a bit faster by reducing test split size
* use 10% split of alpaca dataset to speed up dataset loading/tokenization
* reduce gas 4->2 for most e2e tests
* increase val set size for packing
* feat: add llama4 multimodal
* feat: add torchvision to base docker
* just use latest torchvision
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* llama4 support
* add xet support [skip ci]
* be flexible on transformers version and skip test on version
* don't use deepspeed for the fix_untrained_tokens test
* reordering to trigger torch 2.6.0 tests first
* slightly smaller train set
* use 4.51.0 for now
* remove stray print, add llama4 chat template to schema, bump peft to 0.15.1
* patches to make llama4 performant
* add preliminary fp8 support
* fsdp2 support
* use accelerate release 1.6.0
* allow 8bit optims with fsdp2
* liger + torch compile fix
* add fsdp2 e2e tests
* use transformers commit with fsdp2 support
* skip zero3 tests for this PR for now
* fix fsdp2 config for ci
* make sure both flex and flash attn work with fsdp2, skip fix untrained tokens
* okay, actually use fdsp2...
* more fixes to flex for fsdp2
* make sure to patch all the loaded models
* additional validation for fsdp2, bump dep versions
* make torch 2.6.0 the default image
* fix tests against upstream main
* fix attribute access
* use fixture dataset
* fix dataset load
* correct the fixtures + tests
* more fixtures
* add accidentally removed shakespeare fixture
* fix conversion from unittest to pytest class
* nightly main ci caches
* build 12.6.3 cuda base image
* override for fix from huggingface/transformers#37162
* address PR feedback
* make gemma3 work with packing
* multi-gpu e2e for ci
* update gemma3 model namespace to use mirror
* add gradient checkpointing to multigpu e2e ci
* update gemma3 examples for use_reentrant and fix ddp find unused params
* fix tests for gemma3
* fix import for test utils
* set correct train loss for gemma3 e2e
* fix: clarify input type
* fix: handling of error message if data_files not available
* fix: clarify attention handling
* fix: add doc on missing pad token
* add grpo scale_rewards config for trl#3135
* options to connect to vllm server directly w grpo trl#3094
* temperature support trl#3029
* sampling/generation kwargs for grpo trl#2989
* make vllm_enable_prefix_caching a config param trl#2900
* grpo multi-step optimizeations trl#2899
* remove overrides for grpo trainer
* bump trl to 0.16.0
* add cli to start vllm-serve via trl
* call the python module directly
* update to use vllm with 2.6.0 too now and call trl vllm serve from module
* vllm 0.8.1
* use python3
* use sys.executable
* remove context and wait for start
* fixes to make it actually work
* fixes so the grpo tests pass with new vllm paradigm
* explicit host/port and check in start vllm
* make sure that vllm doesn't hang by setting quiet so outouts go to dev null
* also bump bnb to latest release
* add option for wait from cli and nccl debugging for ci
* grpo + vllm test on separate devices for now
* make sure grpo + vllm tests runs single worker since pynccl comms would conflict
* fix cli
* remove wait and add caching for argilla dataset
* refactoring configs
* chore: lint
* add vllm config
* fixup vllm grpo args
* fix one more incorrect schema/config path
* fix another vlllm reference and increase timeout
* make the tests run a bit faster
* change mbsz back so it is correct for grpo
* another change mbsz back so it is correct for grpo
* fixing cli args
* nits
* adding docs
* docs
* include tensor parallel size for vllm in pydantic schema
* moving start_vllm, more docs
* limit output len for grpo vllm
* vllm enable_prefix_caching isn't a bool cli arg
* fix env ordering in tests and also use pid check when looking for vllm
---------
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
* guard return if ring attn alrady registered
* add docs link, bits in multi-gpu docs, remove save model callback (subsumed by HF trainers)
* configurable heads_k_stride from ring-flash-attn hf adapter
* fix: update chat_template
* fix: handle gemma3 showing a lot of no content for turn 0
* fix: remove unknown config from examples
* fix: test
* fix: temporary disable gemma2 test
* fix: stop overwriting config.text_config unnecessarily
* fix: handling of set cache to the text_config section
* feat: add liger gemma support and bump liger to 0.5.5
* fix: add double use_cache setting
* fix: add support for final_logit_softcap in CCE for gemma2/3
* fix: set use_cache before model load
* feat: add missing layernorm override
* fix: handle gemma3 rmsnorm
* fix: use wrapper to pass dim as hidden_size
* fix: change dim to positional
* fix: patch with wrong mlp
* chore: refactor use_cache handling
* fix import issues
* fix tests.e2e.utils import
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* hf offline decorator for tests to workaround rate limits
* fail quicker so we can see logs
* try new cache name
* limit files downloaded
* phi mini predownload
* offline decorator for phi tokenizer
* handle meta llama 8b offline too
* make sure to return fixtures if they are wrapped too
* more fixes
* more things offline
* more offline things
* fix the env var
* fix the model name
* handle gemma also
* force reload of modules to recheck offline status
* prefetch mistral too
* use reset_sessions so hub picks up offline mode
* more fixes
* rename so it doesn't seem like a context manager
* fix backoff
* switch out tinyshakespeare dataset since it runs a py script to fetch data and doesn't work offline
* include additional dataset
* more fixes
* more fixes
* replace tiny shakespeaere dataset
* skip some tests for now
* use more robust check using snapshot download to determine if a dataset name is on the hub
* typo for skip reason
* use local_files_only
* more fixtures
* remove local only
* use tiny shakespeare as pretrain dataset and streaming can't be offline even if precached
* make sure fixtures aren't offline
improve the offline reset
try bumping version of datasets
reorder reloading and setting
prime a new cache
run the tests now with fresh cache
try with a static cache
* now run all the ci again with hopefully a correct cache
* skip wonky tests for now
* skip wonky tests for now
* handle offline mode for model card creation
* add 12.8.1 cuda to the base matrix
* use nightly
* bump deepspeed and set no binary
* deepspeed binary fixes hopefully
* install deepspeed by itself
* multiline fix
* make sure ninja is installed
* try with reversion of packaging/setuptools/wheel install
* use license instead of license-file
* try rolling back packaging and setuptools versions
* comment out license for validation for now
* make sure packaging version is consistent
* more parity across tests and docker images for packaging/setuptools
* use default torch fused adamw optimizer as default as adamw_hf is deprecated
* make sure to have latest packaging installed
* bump packagingin requirements.txt too
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
pre-commit install
# test
@@ -57,11 +61,23 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
#### Skipping CI Checks
You can skip certain CI checks by including specific keywords in your commit messages:
-`[skip ci]` or `skip ci` - Skips all CI checks for that commit
-`[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
## Style Guidelines
### Code Style
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
Use the pre-commit linter to ensure that your code is formatted consistently.
```bash
pre-commit run --all-files
```
### Commit Messages
@@ -71,6 +87,6 @@ Write clear and concise commit messages that briefly describe the changes made i
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!
github:[winglian, OpenAccess-AI-Collective]# Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
github:# Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon:# Replace with a single Patreon username
open_collective:# Replace with a single Open Collective username
ko_fi:axolotl_ai# Replace with a single Ko-fi username
ko_fi:# Replace with a single Ko-fi username
tidelift:# Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge:# Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay:# Replace with a single Liberapay username
issuehunt:# Replace with a single IssueHunt username
otechie:# Replace with a single Otechie username
lfx_crowdfunding:# Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom:['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480¢erImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
custom:# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.
Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.
## 🎉 Latest Updates
- 2026/04:
- New model support has been added in Axolotl for [Mistral Medium 3.5](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral-medium-3_5) and [Gemma 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma4).
- Axolotl is now [uv-first](https://github.com/axolotl-ai-cloud/axolotl/pull/3545) and has [SonicMoE fused LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3519) support.
- 2026/03:
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
- Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
- 2026/01:
- New integration for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), weights loss by entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), improves long context in attention.
- 2025/12:
- Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
<details>
<summary>Expand older updates</summary>
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
</details>
## ✨ Overview
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
Features:
-Train various Huggingface models such as llama, pythia, falcon, mpt
-Supports fullfinetune, lora, qlora, relora, and gptq
-Customize configurations using a simple yaml file or CLI overwrite
-Load different dataset formats, use custom formats, or bring your own tokenized datasets
-Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
-Works with single GPU or multiple GPUs via FSDP or Deepspeed
-Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
-**Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
-**Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
-**Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
-**Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
-**Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
-**Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
-**Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
## 🚀 Quick Start
## 🚀 Quick Start - LLM Fine-tuning in Minutes
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.4.1
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1
### Google Colab
[](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/getting-started.html) for a more detailed walkthrough.
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
## ✨ Key Features
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware
## 📚 Documentation
- [Installation Options](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html) - Full configuration options and examples
- [Dataset Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) - Supported formats and how to use them
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.
- Need dedicated support? Please contact [✉️wing@axolotl.ai](mailto:wing@axolotl.ai) for options
## 🌟 Contributing
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
1.58-bit finetuning allows you to finetune BitNet models when their prequantized weights are provided. In theory, it will be possible to fine-tune any LLM in 1.58bit format but the performance degradation will be dramatic.
Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready to use for training.
::: {.callout-note}
LoRA is not supported for BitNet models
:::
## Installation
Install the `onebitllms` package before using this feature:
Online RL with verifiable reward functions. For full config reference, async features, and scaling, see [grpo.qmd](../grpo.qmd). For vLLM setup, see [vllm_serving.qmd](../vllm_serving.qmd).
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```
## Plugins & Optimizations
### Cut Cross Entropy (CCE)
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels:true
use_scattermoe:true
experts_implementation:scattermoe
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
### Required settings
```yaml
# Always needed for Gemma4:
freeze_mm_modules:true# Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
use_reentrant:false# Shared per-layer norms cause "marked ready twice" with reentrant
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
```yaml
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
- ScatterMoE kernel acceleration:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
### VLM (Vision) Training
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
### E2B/E4B dense models
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
## Qwen 3.5 MoE
**Models**: `Qwen/Qwen3.5-35B-A3B`
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
```yaml
normalize_weight_scales:
- name_pattern: 'linear_attn\.conv1d\.weight'
threshold: 1.3
```
## General MoE Notes
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
## Quick Validation Checklist
When testing a new model, run through these checks in order:
1.**Does the model load?**`axolotl preprocess config.yaml` — catches config schema errors
2.**Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3.**Is the initial loss sane?** First-step loss for a pretrained model should be 0.5–2.0 for SFT
4.**Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5.**Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
## Loss Debugging
### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.5–2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:
```python
# Load model via axolotl's pipeline (applies all patches)
print(f"Direct loss: {out.loss.item()}")# Compare to trainer's reported loss
```
If direct loss is correct (~1.0) but trainer reports 3–4x higher, check `model_accepts_loss_kwargs` (see below).
### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
**Fix location**: `src/axolotl/core/trainers/base.py``__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
## Multimodal Models (ForConditionalGeneration)
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
## Cut Cross Entropy (CCE)
### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py``patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
## MoE Models
### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
-`gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
-`experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
Reference for DPO, IPO, KTO, ORPO, and SimPO. For config templates and dataset format examples, see [rlhf.qmd](../rlhf.qmd). For GRPO, see [grpo.qmd](../grpo.qmd). For EBFT, see [ebft.qmd](../ebft.qmd).
## Method Overview
| Method | Data Requirement | Key Idea | Best For |
Method-specific: DPO/IPO watch `rewards/margins`; KTO loss is noisier; ORPO monitor SFT + odds ratio components; SimPO check length-normalized reward separation.
## Known Issues
| Issue | Fix |
|-------|-----|
| Sample packing crash | Set `sample_packing: false` (required for all preference methods) |
| KTO `KeyError: 'label'` | Ensure dataset has boolean `label` column |
| ORPO/KTO `KeyError` during tokenization | Add `remove_unused_columns: false` |
Supervised fine-tuning pipeline reference. For config templates and dataset format examples, see [getting-started.qmd](../getting-started.qmd) and [dataset-formats/](../dataset-formats/).
## Architecture
```
YAML Config → axolotl train config.yaml
1. Load base model (+ quantization if QLoRA/8-bit)
2. Apply adapter layers (LoRA/QLoRA) if configured
| `learning_rate` | Follows scheduler curve | Flat or NaN — config issue |
Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss goes to 0 quickly (overfitting), eval_loss diverging (reduce epochs, add regularization). See [training_stability.qmd](../training_stability.qmd).
## Known Issues
| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
## File-Based Checkpoint Trigger
### Configuration
Enable in your config:
```yaml
dynamic_checkpoint:
enabled: true
check_interval: 100 # Optional: check every N steps (default: 100)
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
**Custom filename:**
```yaml
dynamic_checkpoint:
enabled: true
trigger_file_path: "my_trigger.save"
```
```bash
touch /path/to/output_dir/my_trigger.save
```
## Control+C (SIGINT) Checkpoint
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
## Best Practices
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
description: "A decision guide for choosing the right fine-tuning method, adapter, and hardware configuration in Axolotl."
format:
html:
toc: true
toc-depth: 3
number-sections: true
execute:
enabled: false
---
## Overview {#sec-overview}
Axolotl supports four broad categories of fine-tuning, each suited to different data types, objectives, and resource constraints.
| Method | What It Does | Data You Need |
|--------|-------------|---------------|
| **Supervised Fine-Tuning (SFT)** | Teaches the model to produce specific outputs given inputs | Input-output pairs (instructions, conversations, completions) |
| **Preference Learning (DPO/KTO/ORPO)** | Steers the model toward preferred outputs and away from dispreferred ones | Chosen/rejected response pairs (DPO, ORPO) or binary labels (KTO) |
| **Reinforcement Learning (GRPO)** | Optimizes the model against a reward signal through online generation | A reward function (code or model-based) and a prompt dataset |
| **Reward Modeling** | Trains a model to score responses, for use as a reward signal in RL | Preference pairs ranked by quality |
Each method is configured through a YAML file with `rl: <method>` (or omitted for SFT). All methods support LoRA, QLoRA, and full fine-tuning unless otherwise noted.
## Decision Tree {#sec-decision-tree}
Use the following flowchart to choose your method. Start at the top and follow the path that matches your situation.
```
Do you have a reward function (code-based or model-based)?
├── YES
│ └── Use GRPO (rl: grpo)
│ The model generates its own completions and learns from reward scores.
│ Best for: math, code, reasoning, tasks with verifiable answers.
│ See: rlhf.qmd#grpo
│
└── NO
│
Do you have preference pairs (chosen vs. rejected responses)?
├── YES
│ │
│ Are they paired (same prompt, one chosen, one rejected)?
│ ├── YES → Use DPO (rl: dpo)
│ │ Direct optimization without a separate reward model.
│ │ See: rlhf.qmd#dpo
│ │
│ └── NO (only binary good/bad labels)
│ └── Use KTO (rl: kto)
│ Works with unpaired preference data.
│ See: rlhf.qmd#kto
│
└── NO
│
Do you have input-output examples?
├── YES → Use SFT
│ The simplest and most common method.
│ See: getting-started.qmd
│
└── NO
└── You need to create training data first.
Consider generating preference pairs with an LLM judge,
or writing a reward function for GRPO.
```
::: {.callout-tip}
**When in doubt, start with SFT.** It is the most straightforward method and works well for most tasks. You can always move to preference learning or RL later to further refine behavior.
:::
### Method Comparison at a Glance
| Criterion | SFT | DPO | KTO | GRPO |
|-----------|-----|-----|-----|------|
| Data complexity | Low (input-output pairs) | Medium (preference pairs) | Medium (binary labels) | Low (prompts + reward code) |
| Compute cost | Low | Medium | Medium | High (requires vLLM server) |
| Reward model needed | No | No | No | No (uses reward functions) |
| Best for | Task adaptation, instruction following | Safety, style alignment | Unpaired preference data | Reasoning, math, code |
::: {.callout-note}
**ORPO** is an alternative to DPO that combines SFT and preference optimization in a single training stage, removing the need for a separate SFT step. Configure with `rl: orpo`. See [rlhf.qmd](rlhf.qmd) for details.
:::
## Adapter Selection {#sec-adapter-selection}
Once you have chosen a method, decide how to apply the parameter updates. The three main options trade off VRAM usage against model quality.
### QLoRA
- **How it works**: The base model is loaded in 4-bit (NF4) quantization. Small low-rank adapter matrices are trained in higher precision on top.
- **VRAM savings**: Roughly 4x reduction in model memory compared to full fine-tuning.
- **Quality**: Slight degradation due to quantization noise, but often negligible for task-specific fine-tuning.
- **When to use**: When your GPU cannot fit the model in full precision, or when you want fast experimentation.
```yaml
adapter: qlora
load_in_4bit: true
lora_r: 32
lora_alpha: 64
lora_target_linear: true
```
### LoRA
- **How it works**: The base model is loaded at full precision (or 8-bit). Low-rank adapter matrices are trained alongside.
- **VRAM savings**: Roughly 2-3x reduction compared to full fine-tuning (model weights are frozen, only adapters + optimizer states for adapters are stored).
- **Quality**: Very close to full fine-tuning for most tasks, especially with higher rank values.
- **When to use**: When you have enough VRAM for the base model but not for full optimizer states.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
```
::: {.callout-tip}
For GRPO training, LoRA is strongly recommended. The vLLM server needs to sync weights from the trainer, and LoRA sync (`trl.vllm_lora_sync: true`) is far more efficient than syncing full merged weights. See [vLLM Serving](vllm_serving.qmd) for details.
:::
### Full Fine-Tuning
- **How it works**: All model parameters are updated during training. No adapters.
- **VRAM savings**: None. Requires memory for model weights, gradients, and optimizer states (roughly 4x model size in bf16 with AdamW).
- **Quality**: Highest potential quality, especially for large distribution shifts.
- **When to use**: When you have ample GPU memory or multi-GPU setups, and need maximum performance. Also required for pre-training.
```yaml
# No adapter or load_in_* lines needed
micro_batch_size: 1
gradient_accumulation_steps: 16
```
### Quick Comparison
| | QLoRA | LoRA | Full |
|---|---|---|---|
| Trainable params | ~0.1-1% | ~0.1-1% | 100% |
| Model memory | ~25% of full | ~50-100% of full | 100% |
These estimates assume a short context length (512-2048 tokens) and micro_batch_size of 1-2. Longer sequences and larger batches increase memory significantly due to activations. Use [gradient checkpointing](gradient_checkpointing.qmd) to reduce activation memory at the cost of ~30% slower training.
:::
### GRPO (RL Training)
GRPO requires additional GPU(s) for the vLLM generation server. Plan for at least two GPUs: one for training, one for vLLM.
| Model Size | Training GPU (LoRA, bf16) | vLLM GPU | Total GPUs |
For single-GPU GRPO, use `vllm_mode: colocate` with `vllm_enable_sleep_mode: true`. The vLLM engine shares the GPU and offloads VRAM when not generating. This works for smaller models (up to ~3B on a 24 GB GPU) but is slower than the two-GPU server mode.
:::
### Multi-GPU Threshold
You need multi-GPU training when:
- **Full fine-tuning** of models 7B+ (use FSDP or DeepSpeed ZeRO)
- **LoRA** of models 30B+ (or 13B+ with long contexts)
- **GRPO** almost always (separate vLLM server), unless using colocate mode
See [Multi-GPU Training](multi-gpu.qmd) for FSDP and DeepSpeed configuration.
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -197,6 +220,29 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
### delinearize-llama4
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
# Use bitsandbytes 4 bit
load_in_4bit:
# Use CUDA bf16
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
# Use CUDA fp16
fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true
# A list of one or more datasets to finetune the model with
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] split dataset into N pieces (use with shards_idx)
shards_idx: # Optional[int] = 0 the index of sharded dataset to use
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
# Custom user instruction prompt
- path: repo
type:
# The below are defaults. only set what's needed if you use a different column name.
system_prompt: ""
system_format: "{system}"
field_system: system
field_instruction: instruction
field_input: input
field_output: output
# Customizable to be single line or multi-line
# Use {instruction}/{input} as key to be replaced
# 'format' can include {input}
format: |-
User: {instruction} {input}
Assistant:
# 'no_input_format' cannot include {input}
no_input_format: "{instruction} "
# For `completion` datsets only, uses the provided field instead of `text` column
field:
# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja chat template. Used only if `chat_template: jinja` or empty.
chat_template_jinja:
# Key containing the messages (default: "messages")
field_messages: messages
# Mapping of properties from the input dataset to the chat template.
# If a property exists in the template but not in this mapping, the system will attempt
# to load it directly from the message using the property name as the key.
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
# while 'value' is loaded and used as 'content' in the chat template.
message_property_mappings:
role: from
content: value
# ...
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant"]
system: ["system"]
tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["assistant"] # default
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
message_field_training_detail: train_detail
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true
# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
- path: /workspace/data/eval.jsonl
ds_type: json
# You need to specify a split. For "json" datasets the default split is called "train".
split: train
type: completion
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# reward modelling: `True` or `False`
reward_model:
# process reward modelling: `True` or `False`
process_reward_model:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
# Keep dataset in memory while preprocessing
# Only needed if cached dataset is taking too much storage
dataset_keep_in_memory:
# push checkpoints to hub
hub_model_id: # private repo path to push finetuned model
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Tensorboard
use_tensorboard: # Optional[bool]
# Where to save the full-finetuned model to
output_dir: ./completed-model
# Whether to use torch.compile and which backend to use
# setting to `auto` will enable torch compile when torch>=2.5.1
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
micro_batch_size: 2
eval_batch_size:
num_epochs: 4
warmup_steps: 100 # cannot use with warmup_ratio
warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps:
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
# snapshots can be visualized @ https://pytorch.org/memory_viz
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package)
save_safetensors:
# Whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# Group similarly sized data to minimize padding.
# May be slower to start, as it must download and sort the entire dataset.
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
# use_reentrant: true
# Stop training after this many evaluation losses have increased in a row
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim
lr_div_factor: # Learning rate div factor
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
To add a new integration, please follow these steps:
1. Create a new folder in the `src/axolotl/integrations` directory.
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
3. Add `__init__.py` and `args.py` files to the new folder.
- `__init__.py` should import the integration and hook into the appropriate functions.
- `args.py` should define the arguments for the integration.
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
::: {.callout-tip}
See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
:::
::: {.callout-warning}
If you could not load your integration, please ensure you are pip installing in editable mode.
```bash
pip install -e .
```
and correctly spelled the integration name in the config file.
It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
See [configs](../config.qmd) for full configs and supported templates.
See [configs](../config-reference.qmd) for full configs and supported templates.
### Migrating from sharegpt
@@ -64,7 +52,9 @@ We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
#### Training on last message
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
@@ -78,7 +68,9 @@ datasets:
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
#### Overriding default chat template
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
@@ -88,7 +80,13 @@ datasets:
roles_to_train: ["assistant"] # default value
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
::: {.callout-note}
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
:::
#### Using default chat template with fallback
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
@@ -97,7 +95,9 @@ datasets:
type: chat_template
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
#### Custom Jinja template
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
@@ -108,11 +108,150 @@ datasets:
type: chat_template
```
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
::: {.callout-tip}
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:
```yaml
chat_template_jinja: ./path/to/my_template.jinja2
```
:::
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
:::
#### Using template with different token for EOT and EOS
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
```yaml
eot_tokens:
- "[/INST]"
# - "[/SYSTEM_PROMPT]"
datasets:
- path: ...
type: chat_template
# optional
train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)
```
::: {.callout-tip}
See [config documentation](../config-reference.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
:::
::: {.callout-note}
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details.
:::
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
```yaml
eot_tokens:
- "[/INST]"
# ...
datasets:
- path: ...
type: chat_template
train_on_eos: last
train_on_eot: turn
```
::: {.callout-tip}
If EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, generally, you can leave them to their defaults and omit them.
:::
#### Using tool use
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
```json
{
"tools": [
{
"type": "...",
"function": {
"name": "...",
"description": "...",
"parameters": {
"type": "...",
"properties": {
// ...
},
"required": ["..."],
},
},
},
],
"messages": [
// ...
{
"role": "assistant", // call the function via assistant
"tool_calls": [
{
"id": "...", // required only for mistral
"type": "function",
"function": {
"name": "...",
"arguments": {
"...": "...",
}
}
}
]
},
{
"role": "tool",
"tool_call_id": "...", // required only for mistral
"name": "...",
"content": "..."
},
],
}
```
::: {.callout-note}
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
::: {.callout-warning}
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
```
"arguments": "{\"...\": \"...\"}"
```
The same is applicable for tool parameters.
```
"parameters": "{\"...\": \"...\"}"
```
:::
Example config for Llama4:
```yaml
chat_template: llama4
datasets:
- path: Nanobit/text-tools-2k-test
type: chat_template
# field_tools: tools # default is `tools`
```
::: {.callout-tip}
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
:::
#### Using fine-grained control over token masking
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -162,3 +301,152 @@ datasets:
::: {.callout-tip}
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
#### Content parts with per-part training control
Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think step by step...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
}
]
}
```
The configuration is the same as standard `chat_template` — no extra fields needed:
```yaml
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
Each content part supports:
- `type`: `"text"` (required)
- `text`: the text value (also accepts `content` or `value` as the key)
- `train`: `true`/`false` (optional) — whether to train on this part
- `weight`: `0`/`1` (optional) — alternative to `train`
If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
::: {.callout-warning title="Whitespace at part boundaries"}
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:
**Split BEFORE spaces** (space goes with the next part):
```json
[
{"type": "text", "text": "Let me think...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
```
**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):
```json
[
{"type": "text", "text": "Let me think... ", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
:::
::: {.callout-note}
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
:::
##### Per-part training on reasoning_content
For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{"type": "text", "text": " Wait no, 2+2=4.", "train": true}
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": true}
]
}
]
}
```
The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.
::: {.callout-tip}
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
:::
The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
#### Reasoning split
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
```yaml
datasets:
- path: ...
type: chat_template
chat_template: qwen3
split_thinking: true
```
For example, a content can look like:
```json
{
"content": "<think>Some thinking outputs</think>Output after thinking."
}
```
After split, it will look like:
```json
{
"reasoning_content": "Some thinking outputs",
"content": "Output after thinking..."
}
```
## sharegpt
::: {.callout-important}
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section.
@@ -13,92 +13,55 @@ As there are a lot of available options in Axolotl, this guide aims to provide a
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has their own dataset format which are described below.
::: {.callout-tip}
This guide will mainly use JSONL as an introduction. Please refer to the [dataset loading docs](../dataset_loading.qmd) to understand how to load datasets from other sources.
For `pretraining_dataset:` specifically, please refer to the [Pre-training section](#pre-training).
:::
## Pre-training
When aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire-datasets before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to only load batches into memory at a time.
A sample format for a pre-training dataset is as follows:
Pre-training trains on raw text corpora with no input masking. The dataset format is simple:
```json
{"text": "first row"}
{"text": "second row"}
...
```
It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.
Axolotl supports two approaches:
Axolotl supports loading from a Hugging Face hub repo or from local files.
### Streaming (large datasets)
::: {.callout-important}
For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
:::
### Pre-training from Hugging Face hub datasets
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
```yaml
pretraining_dataset: hf_org/name
```
### Pre-training from local dataset files
Given a few corpus files: `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:
For large corpora that don't fit in memory, use `pretraining_dataset` with [streaming](../streaming.qmd). Data is tokenized on-demand during training.
```yaml
pretraining_dataset:
- path: json
data_files:
- A.jsonl
- B.jsonl
- C.jsonl
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
```
While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `Webdataset`) that are supported by [`Dataset.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files)
::: {.callout-important}
Streaming requires `max_steps` in your config — Axolotl cannot infer the dataset size. One step = `sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus` tokens.
:::
### Pre-training without streaming
See [Streaming Datasets](../streaming.qmd) for full configuration details.
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
### Non-streaming (smaller datasets)
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
From Hugging Face:
For datasets that fit in memory, use `type: completion` under `datasets:`. The entire dataset is pre-tokenized before training, which can be done on a CPU-only machine.
```yaml
datasets:
- path: hf_org/name
- path: my_corpus
type: completion
```
From local files (either example works):
```yaml
datasets:
- path: A.jsonl
type: completion
- path: json
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
type: completion
```
### Pre-training dataset configuration tips
#### Setting max_steps
When using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.
Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.
One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
#### Group_by_length
It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.
### Reference
Please see docs [here](pretraining.qmd).
::: {.callout-note}
With `completion`, texts exceeding `sequence_len` are split into multiple samples automatically.
:::
## Supervised fine-tuning (SFT)
@@ -450,10 +413,7 @@ datasets:
type: alpaca
```
Axolotl supports many kinds of instruction dataset. All of them can be found here (https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective type and sample row format.
Axolotl supports many kinds of instruction dataset. All of them can be found in the [Instruction Dataset Documentation](inst_tune.qmd) with their respective type and sample row format.
description: Understanding how to load datasets from different sources
back-to-top-navigation: true
toc: true
toc-depth: 5
---
## Overview
Datasets can be loaded in a number of different ways depending on the how it is saved (the extension of the file) and where it is stored.
## Loading Datasets
We use the `datasets` library to load datasets and a mix of `load_dataset` and `load_from_disk` to load them.
You may recognize the similar named configs between `load_dataset` and the `datasets` section of the config file.
```yaml
datasets:
- path:
name:
data_files:
split:
revision:
trust_remote_code:
```
::: {.callout-tip}
Do not feel overwhelmed by the number of options here. A lot of them are optional. In fact, the most common config to use would be `path` and sometimes `data_files`.
:::
This matches the API of [`datasets.load_dataset`](https://github.com/huggingface/datasets/blob/0b5998ac62f08e358f8dcc17ec6e2f2a5e9450b6/src/datasets/load.py#L1838-L1858), so if you're familiar with that, you will feel right at home.
For HuggingFace's guide to load different dataset types, see [here](https://huggingface.co/docs/datasets/loading).
For full details on the config, see [config-reference.qmd](config-reference.qmd).
::: {.callout-note}
You can set multiple datasets in the config file by more than one entry under `datasets`.
```yaml
datasets:
- path: /path/to/your/dataset
- path: /path/to/your/other/dataset
```
:::
### Local dataset
#### Files
To load a JSON file, you would do something like this:
In the example above, it can be seen that we can just point the `path` to the file or directory along with the `ds_type` to load the dataset.
This works for CSV, JSON, Parquet, and Arrow files.
::: {.callout-tip}
If `path` points to a file and `ds_type` is not specified, we will automatically infer the dataset type from the file extension, so you could omit `ds_type` if you'd like.
:::
#### Directory
If you're loading a directory, you can point the `path` to the directory.
Then, you have two options:
##### Loading entire directory
You do not need any additional configs.
We will attempt to load in the following order:
- datasets saved with `datasets.save_to_disk`
- loading entire directory of files (such as with parquet/arrow files)
```yaml
datasets:
- path: /path/to/your/directory
```
##### Loading specific files in directory
Provide `data_files` with a list of files to load.
```yaml
datasets:
# single file
- path: /path/to/your/directory
ds_type: csv
data_files: file1.csv
# multiple files
- path: /path/to/your/directory
ds_type: json
data_files:
- file1.jsonl
- file2.jsonl
# multiple files for parquet
- path: /path/to/your/directory
ds_type: parquet
data_files:
- file1.parquet
- file2.parquet
```
### HuggingFace Hub
The method you use to load the dataset depends on how the dataset was created, whether a folder was uploaded directly or a HuggingFace Dataset was pushed.
::: {.callout-note}
If you're using a private dataset, you will need to enable the `hf_use_auth_token` flag in the root-level of the config file.
:::
#### Folder uploaded
This would mean that the dataset is a single file or file(s) uploaded to the Hub.
```yaml
datasets:
- path: org/dataset-name
data_files:
- file1.jsonl
- file2.jsonl
```
#### HuggingFace Dataset
This means that the dataset is created as a HuggingFace Dataset and pushed to the Hub via `datasets.push_to_hub`.
```yaml
datasets:
- path: org/dataset-name
```
::: {.callout-note}
There are some other configs which may be required like `name`, `split`, `revision`, `trust_remote_code`, etc depending on the dataset.
:::
### Remote Filesystems
Via the `storage_options` config under `load_dataset`, you can load datasets from remote filesystems like S3, GCS, Azure, and OCI.
::: {.callout-warning}
This is currently experimental. Please let us know if you run into any issues!
:::
The only difference between the providers is that you need to prepend the path with the respective protocols.
```yaml
datasets:
# Single file
- path: s3://bucket-name/path/to/your/file.jsonl
# Directory
- path: s3://bucket-name/path/to/your/directory
```
For directory, we load via `load_from_disk`.
#### S3
Prepend the path with `s3://`.
The credentials are pulled in the following order:
- `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN` environment variables
- from the `~/.aws/credentials` file
- for nodes on EC2, the IAM metadata provider
::: {.callout-note}
We assume you have credentials setup and not using anonymous access. If you want to use anonymous access, let us know! We may have to open a config option for this.
:::
Other environment variables that can be set can be found in [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables)
#### GCS
Prepend the path with `gs://` or `gcs://`.
The credentials are loaded in the following order:
- gcloud credentials
- for nodes on GCP, the google metadata service
- anonymous access
#### Azure
##### Gen 1
Prepend the path with `adl://`.
Ensure you have the following environment variables set:
- `AZURE_STORAGE_TENANT_ID`
- `AZURE_STORAGE_CLIENT_ID`
- `AZURE_STORAGE_CLIENT_SECRET`
##### Gen 2
Prepend the path with `abfs://` or `az://`.
Ensure you have the following environment variables set:
- `AZURE_STORAGE_ACCOUNT_NAME`
- `AZURE_STORAGE_ACCOUNT_KEY`
Other environment variables that can be set can be found in [adlfs docs](https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials)
#### OCI
Prepend the path with `oci://`.
It would attempt to read in the following order:
- `OCIFS_IAM_TYPE`, `OCIFS_CONFIG_LOCATION`, and `OCIFS_CONFIG_PROFILE` environment variables
- when on OCI resource, resource principal
Other environment variables:
- `OCI_REGION_METADATA`
Please see the [ocifs docs](https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables).
### HTTPS
The path should start with `https://`.
```yaml
datasets:
- path: https://path/to/your/dataset/file.jsonl
```
This must be publically accessible.
## Next steps
Now that you know how to load datasets, you can learn more on how to load your specific dataset format into your target output format [dataset formats docs](dataset-formats).
@@ -6,6 +6,10 @@ description: How to debug Axolotl
This document provides some tips and tricks for debugging Axolotl. It also provides an example configuration for debugging with VSCode. A good debugging setup is essential to understanding how Axolotl code works behind the scenes.
::: {.callout-tip}
For training-specific debugging (loss spikes, NaN gradients, OOM errors, RL training stability), see [Training Stability & Debugging](training_stability.qmd).
:::
## Table of Contents
- [General Tips](#general-tips)
@@ -29,7 +33,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
```yaml
@@ -72,8 +76,10 @@ datasets:
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
#### Remote Hosts
@@ -85,7 +91,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 axolotltrain dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```json
// .vscode/launch.json
@@ -101,7 +107,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
You will now be in the container. Next, perform an editable install of Axolotl:
You will now be in the container. Next, install Axolotl with dev dependencies:
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^1]: The VSCode config uses `accelerate.commands.launch` as the Python module entry point, which is what `axolotl train` invokes under the hood.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
This section describes the different Docker images that are released by AxolotlAI at
[Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
### Switch to the `-uv` images
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
same format as their non-uv counterparts.
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
the uv image.
:::
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
description: "Energy-Based Fine-Tuning uses feature-matching rewards from internal representations to train language models without external reward functions."
order: 9
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the **internal feature representations** of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.
Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)
### How EBFT Differs from Other RL Methods
| Method | Reward Signal | Requires | Best For |
|--------|--------------|----------|----------|
| **GRPO** | External reward function(s) | Custom reward code or reward model | Tasks with verifiable answers (math, code) |
| **DPO** | Preference pairs (chosen vs rejected) | Paired preference data | Alignment with human preferences |
| **EBFT** | Feature similarity to ground truth | Ground-truth completions | Any task with reference outputs |
EBFT's key advantage is that it needs only ground-truth completions -- no reward engineering, no preference annotation, and no reward model training. The model's own internal representations serve as the reward signal. This makes it particularly effective for:
- Code generation (match features of known-good solutions)
- Instruction following with reference outputs
- Continual pretraining on unstructured text (strided mode)
- Multi-turn dialogue with reference conversations
- **Alignment**: How closely the generated output's internal representations match the ground truth. Higher is better.
- **Diversity**: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.
- **CFM loss** (Cross-Feature Matching): Tracks `||mean(gen_features) - gt_features||^2` as a diagnostic. This is the quantity that EBFT ultimately minimizes.
## Modes
EBFT supports three operational modes, each suited to different use cases.
### Structured Mode (Sync)
Uses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.
```
GPU 0: vLLM Server (generates completions, receives weight syncs)
**When to use**: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.
### Structured Mode (Async)
Same architecture as sync, but overlaps generation of the next batch with training on the current batch. Faster throughput at the cost of slightly stale weights during generation.
**When to use**: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by `vllm_sync_interval`).
### Strided Mode
Runs entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.
**When to use**: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.
## Quick Start
### Structured Mode
This minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.
**Step 1**: Create a config file `ebft_quickstart.yaml`:
```yaml
base_model: Qwen/Qwen2-0.5B-Instruct
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
alignment_coef: 1.0
diversity_coef: 1.0
trl:
num_generations: 4
max_completion_length: 256
temperature: 0.7
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_lora_sync: true
vllm_sync_interval: 3
use_data_producer: true
async_prefetch: false
scale_rewards: true
loss_type: grpo
vllm:
gpu_memory_utilization: 0.5
max_model_len: 1024
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
# Standard training settings (see getting-started.qmd for details)
The `micro_batch_size` must be divisible by `num_generations`. For example, with `num_generations: 4`, valid values are 4, 8, 12, etc.
:::
### Dataset Format
Structured mode datasets must produce two fields after the transform:
- `prompt`: Either a string or a list of chat messages (`[{"role": "user", "content": "..."}]`)
- `ground_truth`: A string containing the reference completion
Example raw dataset row:
```json
{
"input": "Write a function to compute fibonacci numbers.",
"output": "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)"
}
```
The `ebft_opencode.transform` converts this to the required `{prompt, ground_truth}` format automatically.
## Feature Extraction
EBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.
### Feature Layers
The `feature_layers` parameter specifies which layers to extract, as fractions of total model depth:
For a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size `num_layers * hidden_dim`.
::: {.callout-tip}
Using multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default `[0.25, 0.5, 0.75]` works well across model sizes.
:::
### Embed Methods
The `embed_method` controls how per-token hidden states are pooled into a single vector per sequence:
| Method | Description | Output Shape | Notes |
|--------|-------------|-------------|-------|
| `last_token` | Hidden state at the last non-padding token | `(B, D)` | Default. Good for autoregressive models where the last token summarizes the sequence. |
| `mean_pooling` | Mean of all non-padding token states | `(B, D)` | Considers the entire sequence equally. |
| `completion_mean` | Mean over completion tokens only (excludes prompt) | `(B, D)` | Focuses reward signal on generated content. Requires prompt length information. |
| `concat` | Concatenation of states at 25%, 50%, 75% positions | `(B, 3*D)` | Captures positional structure. Higher dimensional. |
```yaml
ebft:
embed_method: completion_mean # Focus on completion features
```
### SVD Whitening
Whitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.
```yaml
ebft:
use_whitening: true
```
When whitening is enabled, the reward computation applies a whitening matrix `W = U @ diag(1/S) @ U^T` derived from the SVD of generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.
::: {.callout-note}
Singular values scale with `sqrt(batch_size)`, so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (`n_samples_per_prompt` or `num_generations`) is fixed during training.
:::
### Alignment and Diversity Coefficients
The two reward components are weighted by coefficients:
```yaml
ebft:
alignment_coef: 1.0 # Weight for cosine similarity with ground truth
diversity_coef: 1.0 # Weight for pairwise similarity penalty
```
Both values are scaled by 2 internally (per paper equation 7). The final reward per sample is:
Setting `diversity_coef: 0.0` disables the diversity penalty entirely, which may be appropriate when `num_generations` is small (e.g., 2).
## Strided Mode
Strided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places **anchor points** at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.
### How Block-Parallel Generation Works
Given a document of length `S` tokens:
1. **Anchor placement**: Starting at position `anchor_offset`, place anchors every `stride` tokens. Each anchor defines a block.
2. **Context window**: Each block sees `context_length` tokens of preceding context from the original document.
3. **Generation**: At each anchor, generate `generate_max_len` tokens autoregressively, conditioned only on the context window.
4. **Parallelism**: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.
```
Document: [tok0, tok1, ..., tok_S]
| | |
anchor_0 anchor_1 anchor_2
| | |
[ctx][gen] [ctx][gen] [ctx][gen]
```
The attention mask ensures:
- Prompt tokens use standard causal attention
- Each generated block attends to its own context window and its own preceding generated tokens
- Blocks do not attend to each other's generated tokens
When `flex_attention` is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. Otherwise, a dense 4D attention mask is used as a fallback.
### Strided Mode Configuration
```yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
mode: strided
stride: 8 # Tokens between anchor points
context_length: 8 # Context window per block
generate_max_len: 8 # Tokens to generate per block
n_samples_per_prompt: 4 # Independent rollouts per document
temperature: 0.6
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0 # RL policy gradient loss weight
ce_coef: 0.03 # Cross-entropy loss on GT tokens
advantage_estimator: rloo # rloo, group_norm, or reinforce
min_completion_prefix: 8 # Skip anchors in prompt region
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_strided_structured.transform
split: train[:1%]
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
bf16: auto
attn_implementation: flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required with flex_attention
```
Run with a single command (no vLLM needed):
```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```
### Advantage Estimators
Strided mode supports three advantage estimation methods:
| `reinforce` | Raw reward as advantage (no baseline) | Works with `n_samples_per_prompt = 1` |
::: {.callout-warning}
When `n_samples_per_prompt: 1`, the trainer automatically falls back to `reinforce` and disables the diversity penalty (which requires multiple samples).
:::
### Strided Mode Constraints
- **`flex_attention: true`** is strongly recommended. Without it, dense 4D masks consume significantly more memory.
- **`torch_compile: true`** must NOT be set. `flex_attention` compiles its own kernels internally; adding `torch_compile` causes conflicts and OOM.
- **Gradient checkpointing** must use `use_reentrant: true`. Non-reentrant checkpointing causes `CheckpointError` with `flex_attention` block masks.
- **`activation_offloading`** is incompatible with `flex_attention`.
### Cross-Entropy Loss
Strided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:
```yaml
ebft:
ce_coef: 0.03 # Small CE coefficient
rl_coef: 1.0 # RL loss coefficient
```
The total loss is `rl_coef * rl_loss + ce_coef * ce_loss`. For structured mode, `ce_coef` is typically `0.0` since vLLM generation provides sufficient learning signal.
## Dataset Formats
EBFT provides several built-in dataset transforms in `src/axolotl/prompt_strategies/ebft/`.
### Built-In Transforms
| Transform | Input Format | Output Fields | Use Case |
For structured (sync/async) mode, the transform must produce `prompt` and `ground_truth` fields:
```yaml
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
```
### Multi-Turn Datasets
Multi-turn transforms extract conversation data for sequential rollout. The `transform` variant targets the first assistant turn, while `transform_last_turn` targets the final turn:
```yaml
datasets:
- path: your/multiturn-dataset
type: ebft_chat_multiturn.transform
```
When `remaining_turns` is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.
### Strided Mode Datasets
Strided transforms tokenize the full document and produce `input_ids`, `labels`, and `prompt_length`:
```yaml
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_strided_structured.transform
split: train[:1%]
```
### Custom Transforms
To use your own dataset format, write a transform function:
For multi-turn conversations with Qwen3.5, disable thinking mode to prevent `<think>` tags in completions:
```yaml
trl:
chat_template_kwargs:
enable_thinking: false
```
## Monitoring
### Key Metrics
EBFT logs several custom metrics to wandb and the training console. Here is what to watch for:
| Metric | Healthy Range | Interpretation |
|--------|--------------|----------------|
| `ebft/alignment` | 0.3 -- 0.9, trending upward | Cosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference. |
| `ebft/diversity` | 0.01 -- 0.1 | Mean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse. |
| `ebft/cfm_loss` | Below 10, trending downward | Cross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability. |
| `ebft/reward` | Trending upward (may start negative) | Combined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment. |
| `IS ratio min` | Above 0.1 | Importance sampling ratio minimum. Near-zero values mean the policy is too far off-policy; increase `vllm_sync_interval`. |
### Console Log Example
During training, you will see periodic EBFT reward logs:
```
ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^
```
The arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.
### Troubleshooting
| Symptom | Likely Cause | Fix |
|---------|-------------|-----|
| `alignment` stays below 0.1 | Feature layers not capturing useful information | Try different `feature_layers` or `embed_method` |
| `diversity` exceeds 1.0 | Mode collapse -- generations are too similar | Increase `diversity_coef` or `temperature` |
| `reward` stuck at -1.0 | Diversity penalty dominates alignment | Reduce `diversity_coef` or increase `alignment_coef` |
| `grad_norm` consistently 0.0 | All micro-batches have zero advantage | Increase `num_generations` or check data quality |
| `CheckpointError` in strided mode | Incompatible gradient checkpointing settings | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| OOM during training | Logits tensor too large | Reduce `sequence_len` or `micro_batch_size`; strided mode uses chunked lm_head to mitigate this |
| vLLM 500 errors | `truncate_prompt_tokens` not supported | Ensure you are using `axolotl vllm-serve` (not `trl vllm-serve`) |
### Feature Network Memory
In PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the `disable_adapter()` context manager. This saves an entire model copy in VRAM (approximately 1--16 GB depending on model size). For non-PEFT training, a separate frozen deepcopy is created.
::: {.callout-note}
The `disable_adapter()` approach relies on an invariant: `merge_adapter()` is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.
:::
## Examples
Complete example configurations are available in `examples/ebft/`:
| Config | Model | Mode | Description |
|--------|-------|------|-------------|
| `llama-1b-ebft-strided-structured.yaml` | Llama 3.2 1B | Strided | Single-GPU strided training on code data |
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
**Q: How to call Axolotl via custom python scripts?**
> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**
> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:
> ```yaml
> special_tokens:
> # str. If you're not sure, set to same as `eos_token`.
> pad_token: "..."
> ```
**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand.
**Q: vLLM is not working with Axolotl**
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
**Q: Can we mix text and text+image datasets for VLM training?**
> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
> A: There can be two reasons:
> 1. This is because of the mismatch between `tokenizer.eos_token` and EOS token in template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in template.
> 2. The EOS token is not in the template. Please check if your template is correct. As an example, `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
> A: There can be two reasons:
> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in template.
> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.
**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
> A: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples with the EOT token & tokenizer causing the issue.
**Q: `EOT token __ is encoded as multiple tokens.`**
> A: This is because the EOT token is encoded as multiple tokens which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.
**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
> A: This is because the EOS token is in the `eot_tokens: ` while mismatch between `train_on_eos: ` and `train_on_eot: `. This will cause one to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same or remove the EOS token from `eot_tokens: `.
**Q: If `eot_tokens: ` is not provided, what happens?**
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
**Q: `Data processing error: CAS service error`**
> A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`
**Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. `**
> A: Depending on the version of torch, you may need to include this in your YAML:
> ```yaml
> flex_attn_compile_kwargs:
> dynamic: false
> mode: max-autotune-no-cudagraphs
> ```
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
**Q: `Error parsing tool_calls arguments as JSON.`
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.
Some files were not shown because too many files have changed in this diff
Show More
Reference in New Issue
Block a user
Blocking a user prevents them from interacting with repositories, such as opening or commenting on pull requests or issues. Learn more about blocking a user.