* fix dpo collation/padding
* fix DPO collator encoder-decoder pixel_values dtype and is_encoder_decoder detection
- Use float32 instead of LongTensor for _pixel_values in encoder-decoder branch
- Add missing padding_value case for _pixel_values in encoder-decoder branch
- Derive is_encoder_decoder from model config instead of hardcoding False
* Support loss_type/loss_weights DPO
* Validate dpo loss type/weights only set for dpo
* Tests: Update ipo tests to use new path
* Docs: Update docs for new ipo path
* PR fixes - typo/validation
* PR nit - warning
* chore: fix warnings arg
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing
Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.
HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.
Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.
Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.
Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
* threading.local() store with _get/_set_shared_kv_states helpers
* _patch_decoder_layer_call(): monkeypatches
Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
and stash it in TLS before delegating to GradientCheckpointingLayer.
The partial formed downstream no longer references the dict.
* fused_forward reads TLS first, falls back to kwarg for callers that
bypass the patched __call__ (e.g. direct attention invocation).
* wired into patch_gemma4_fused_attn; idempotent via a sentinel.
TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.
Introduced by PR #3598 (commit b8358aa5).
* Coderabbit comment: gemma4: clear TLS unconditionally in decoder-layer patched __call__
Overwrite the thread-local shared_kv_states store on every invocation
(including with None) instead of only when the kwarg is present.
The previous conditional write left stale dicts in TLS on any path that
reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
kwarg — e.g. generation, eval hooks, or future HF refactors that make
the kwarg optional. fused_forward would then silently consume a prior
step's K/V dict instead of falling back to its own kwarg path.
Unconditional write makes the invariant in the surrounding comment
("TLS is overwritten on each new step's first decoder-layer call, so
the previous step's dict is released promptly") actually hold.
No behavior change for the training happy path, which always passes
the kwarg. Addresses CodeRabbit review on PR #3611
* fix: swap threading.local() for module-level store so autograd worker threads see shared_kv_states during backward recompute
Previous commits fixed memory leak on 31B but caused type error with MOE Gemma4 variants - this fixes that:
PR 3611's TLS variant only works when recompute runs on the same thread
that set TLS during forward. PyTorch's C++ autograd engine
(_engine_run_backward) spawns per-device worker threads to dispatch
backward, and HF-Trainer gradient_checkpointing (stock
torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
unpack_hook -> recompute_fn on those worker threads. TLS set on the main
thread during forward is invisible there, so _get_shared_kv_states()
returns None and the consumer-layer lookup crashes with
"'NoneType' object is not subscriptable" at
fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).
A plain module-level dict is visible to all threads in the process.
Lifecycle is identical: the slot is overwritten each forward, releasing
the previous step's dict and allowing its K/V tensors to be GC'd, so
the original VRAM-leak fix still holds under FSDP2 AC too.
* scope gemma4 shared_kv_states side channel to checkpointed training
Update PR #3611 with gate for checkpointed training to avoid regressions across async flows.
Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments
* add gemma4 cross-thread visibility test for shared_kv_states store
Additional regression test for MoE gemma4 variants - asserts the module-level store is readable from threads other than the one that set it in response to previously observed 'NoneType' error
* fix logger
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: move to uv first
* fix: update doc to uv first
* fix: merge dev/tests into uv pyproject
* fix: update docker docs to match current config
* fix: migrate examples to readme
* fix: add llmcompressor to conflict
* feat: rec uv sync with lockfile for dev/ci
* fix: update docker docs to clarify how to use uv images
* chore: docs
* fix: use system python, no venv
* fix: set backend cpu
* fix: only set for installing pytorch step
* fix: remove unsloth kernel and installs
* fix: remove U in tests
* fix: set backend in deps too
* chore: test
* chore: comments
* fix: attempt to lock torch
* fix: workaround torch cuda and not upgraded
* fix: forgot to push
* fix: missed source
* fix: nightly upstream loralinear config
* fix: nightly phi3 long rope not work
* fix: forgot commit
* fix: test phi3 template change
* fix: no more requirements
* fix: carry over changes from new requirements to pyproject
* chore: remove lockfile per discussion
* fix: set match-runtime
* fix: remove unneeded hf hub buildtime
* fix: duplicate cache delete on nightly
* fix: torchvision being overridden
* fix: migrate to uv images
* fix: leftover from merge
* fix: simplify base readme
* fix: update assertion message to be clearer
* chore: docs
* fix: change fallback for cicd script
* fix: match against main exactly
* fix: peft 0.19.1 change
* fix: e2e test
* fix: ci
* fix: e2e test
* bump transformers to 5.5.4 and trl to latest 1.1.0
* more upgrades
* update peft too
* adapt lora_merge to peft 0.19 layer config API
PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv
layer constructors and moved use_rslora, use_dora, fan_in_fan_out,
lora_dropout, and lora_bias into that config. Build the config
per branch in _build_peft_layer_and_get_delta so the merge utility
works with the upgraded peft.
* allow lora_dropout on mixed attention+MoE configs under peft 0.19
PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D
einsum can't factor dropout out of lora_B(lora_A(dropout(x))).
Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
* fix: rename model to adapter_model for fsdp sharded final model
* fix: follow upstream transformer shard size
* fix: handle multiple model files
* fix redundant condition, tighten to safetensors, keep shard size small
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: support excess_length_strategy for RL trainers
Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.
- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
within sequence_len while preserving the full prompt, then decodes
back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len
Closes#3547
* improve RL truncation strategy robustness and performance
---------
Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with
FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning.
Made-with: Cursor
* Skip redundant evaluation when resuming from checkpoint
* add condition check for adding callback
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* better handling of dora merge on Conv layers in Qwen 3.5
* address issues from code review
* stricter efficient merges for dora since we now have meta model to reference
* qwen3_5.jinja: handle list content on system messages
The system message branch used string concatenation on
messages[0].content, which breaks when the first system message uses
the OpenAI-style list-of-parts format that multimodal datasets require.
User and assistant branches already handle both string and list content,
but the system branch did not.
Check whether content is a string and fall back to iterating over parts
when it is a list, matching the pattern used for user messages.
Fixes#3590
* Address pr for other content types
---------
Co-authored-by: Joaquin Hui Gomez <joaquinhuigomez@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix: remove unneeded debug log
* fix: cleanup
* feat: add dense gemma config and cleanup
* feat: add cce support
* update notes and set torch compile
* fix patch for new number of return vals
* fixes for gemma4
* fix packing bug
* use updated cce for mm
* fix: pass in kv cache func when avail for transformers 5.5
* feat: update examples with flex variant and readme
* gemma4 lora attention kernels
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* upgrade to torchao 0.17.0
* upgrade mistral-common too
* chore: lint
* patch fix for torchao low bit optimizers
* fix up
* propagate dtype
* fix test for ao change
* address PR comments
* feat: add sonicmoe fused lora support
* fix: forgot to add file
* feat: add test
* feat: add lora support for other routes
* fix: add int8 lora support
* fix: add qwen35_moe interleave support
* fix: qwen3_5_moe loss
* chore: lint
* address some pr comments
* fix test imports
* add support matrix for moe kernels [skip ci]
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* docs: comprehensive documentation improvements for humans and agents
New human docs:
- grpo.qmd: GRPO deep dive (async, rewards, IS correction, scaling)
- ebft.qmd: EBFT guide (structured/strided modes, feature extraction)
- choosing_method.qmd: decision tree for SFT vs LoRA vs DPO vs GRPO
- vllm_serving.qmd: vLLM setup for GRPO (server/colocate, LoRA sync)
- training_stability.qmd: monitoring, NaN debugging, OOM, healthy metrics
New agent docs:
- AGENTS_SFT.md: agent reference for supervised fine-tuning
- AGENTS_DPO.md: agent reference for preference learning (DPO/KTO/ORPO)
Updated existing docs:
- rlhf.qmd: cross-references to new GRPO/EBFT/choosing-method guides
- getting-started.qmd: reorganized Next Steps with links to new guides
- debugging.qmd: link to training stability guide
- _quarto.yml: added new pages to sidebar navigation
Removed:
- bak.agents.md: stale backup that confused agents
* docs: trim duplicated generic config from AGENTS_DPO.md
Remove boilerplate training params (optimizer, gradient_checkpointing,
flash_attention, etc.) from each method template. These are not
preference-learning-specific and are already covered in AGENTS_SFT.md.
Config templates now show only method-specific fields with a reference
to AGENTS_SFT.md for the rest.
* docs: deduplicate across new doc pages
- grpo.qmd: collapse vLLM setup section to brief config + link to
vllm_serving.qmd; collapse IS correction to essentials + link;
replace full monitoring tables with summary + link to
training_stability.qmd
- vllm_serving.qmd: remove duplicated async/IS config reference tables
(already in grpo.qmd config reference); replace full example config
with link to grpo.qmd quick start
- ebft.qmd: trim generic training params in quick start config
* fix: train scripts
* feat: split files into cleaner parts
* fix: cleanup pretraining docs
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>