axolotl/docs/agents/sft.md
Wing Lian e4032fc90f Refactor separate attention flags with attn_implementation and capability/concerns feature flags (#3602)
2026-05-05 10:15:18 -04:00


SFT — Agent Reference

Supervised fine-tuning pipeline reference. For config templates and dataset format examples, see getting-started.qmd and dataset-formats/.

Architecture

YAML Config → axolotl train config.yaml

  1. Load base model (+ quantization if QLoRA/8-bit)
  2. Apply adapter layers (LoRA/QLoRA) if configured
  3. Load + tokenize dataset(s)
     - Apply prompt template (chat_template / alpaca / custom)
     - Mask inputs (train_on_inputs: false)
     - Pack samples into sequences (sample_packing: true)
  4. Training loop (HuggingFace Trainer)
     - forward → loss → backward → optimizer step → lr scheduler step
  5. Save model / adapter weights + tokenizer
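To make step 3 concrete, here is a minimal sketch of input masking and packing on already-tokenized IDs. It assumes the Hugging Face convention that a label of -100 is ignored by the loss. The helper names and the first-fit-decreasing packing strategy are illustrative and are not axolotl's actual implementation.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # Hugging Face convention: positions labeled -100 are excluded from the loss


def mask_prompt_labels(prompt_ids: List[int], response_ids: List[int]) -> Tuple[List[int], List[int]]:
    """Build input_ids/labels for one sample with train_on_inputs: false semantics.

    Prompt tokens are still fed to the model but contribute no loss.
    """
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


def pack_first_fit_decreasing(lengths: List[int], sequence_len: int) -> List[List[int]]:
    """Hypothetical packer: assign sample indices to bins of at most sequence_len tokens.

    Longest samples are placed first, each into the first bin with enough room.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []
    free: List[int] = []  # tokens still available in each bin
    for idx in order:
        n = lengths[idx]
        if n > sequence_len:
            raise ValueError(f"sample {idx} ({n} tokens) exceeds sequence_len={sequence_len}")
        for b, room in enumerate(free):
            if n <= room:
                bins[b].append(idx)
                free[b] -= n
                break
        else:  # no existing bin fits: open a new one
            bins.append([idx])
            free.append(sequence_len - n)
    return bins


ids, labels = mask_prompt_labels([1, 2, 3], [4, 5])
print(labels)  # [-100, -100, -100, 4, 5]
print(pack_first_fit_decreasing([600, 300, 500, 200], sequence_len=1024))  # [[0, 1], [2, 3]]
```

This only decides which samples share a sequence. Keeping packed samples from attending to each other is a separate concern handled at the attention level.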

Multi-GPU: when FSDP or DeepSpeed is configured, the model is sharded across GPUs automatically (see docs/multi-gpu.qmd).

Components Required

  1. A YAML config — model, dataset(s), adapter settings, hyperparameters
  2. A dataset — HuggingFace Hub, local JSONL/JSON/Parquet, or S3/GCS path
  3. (Optional) A custom prompt strategy — for non-standard dataset formats

No external server processes are needed (unlike GRPO, which requires vLLM).

Dataset Format Decision Tree

Is your data in chat/message format?
  ├─ YES: OpenAI message format (role/content)?
  │   ├─ YES ──────────────────────> type: chat_template  (recommended)
  │   └─ NO (custom field names) ──> type: chat_template + message_property_mappings
  └─ NO: Instruction/response pairs?
      ├─ YES ──> type: alpaca       (instruction, input, output)
      └─ NO: Raw text?
          ├─ YES with segments ─────> type: input_output  (template-free masking)
          └─ YES continuous ────────> type: completion     (pretraining-style)

Full format specs: dataset-formats/
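The tree can be applied mechanically to a single row of your data. The following is a minimal sketch. The field names it checks (messages/conversations, instruction/output, segments, text) are assumptions about common dataset layouts, not the full set of fields axolotl accepts. `suggest_dataset_type` is a hypothetical helper.

```python
from typing import Any, Dict


def suggest_dataset_type(sample: Dict[str, Any]) -> str:
    """Walk the dataset-format decision tree for one example row (illustrative only)."""
    # Chat/message format: a list of turn dicts under a common key
    for key in ("messages", "conversations"):
        turns = sample.get(key)
        if isinstance(turns, list) and turns and isinstance(turns[0], dict):
            if "role" in turns[0] and "content" in turns[0]:
                return "chat_template"
            # Custom turn fields, e.g. {"from": ..., "value": ...}
            return "chat_template + message_property_mappings"
    # Instruction/response pairs
    if "instruction" in sample and "output" in sample:
        return "alpaca"
    # Raw text split into segments, each with its own train/mask flag
    if isinstance(sample.get("segments"), list):
        return "input_output"
    # Continuous raw text
    if isinstance(sample.get("text"), str):
        return "completion"
    raise ValueError("unrecognized format; consider a custom prompt strategy")


print(suggest_dataset_type({"messages": [{"role": "user", "content": "hi"}]}))  # chat_template
```

If a row falls through every branch, that is the case the optional custom prompt strategy exists for.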

Model Size to Adapter Choice

| Model Size | LoRA | QLoRA (4-bit) | Full Fine-Tune | VRAM (approx.) |
|---|---|---|---|---|
| 1-3B | Preferred | Low-budget option | Single GPU OK | 8-16 GB (LoRA) |
| 7-8B | Preferred | Good balance | Needs multi-GPU | 16-24 GB (LoRA) |
| 13-14B | Preferred | Good balance | Multi-GPU required | 24-40 GB (LoRA) |
| 30-70B | LoRA or QLoRA | Preferred for single GPU | Multi-node | 40-80 GB (QLoRA) |
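The table can be read as a rough lookup. This is a minimal sketch under two assumptions. First, the size rows are treated as contiguous ranges, and the 14B cutoff separating the LoRA-preferred rows from the 30-70B row is an interpolation that does not come from the table. Second, fitting on a single GPU is taken as the deciding constraint for large models. `suggest_adapter` is a hypothetical helper, not an axolotl function.

```python
def suggest_adapter(params_billion: float, single_gpu: bool = True) -> str:
    """Map model size (billions of parameters) to the table's preferred adapter."""
    if params_billion <= 0:
        raise ValueError("params_billion must be positive")
    if params_billion <= 14:
        # 1-3B, 7-8B, 13-14B rows: LoRA is preferred
        return "lora"
    # 30-70B row: QLoRA is preferred when limited to a single GPU
    return "qlora" if single_gpu else "lora"


print(suggest_adapter(8))    # lora
print(suggest_adapter(70))   # qlora
```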

Hyperparameter Ranges

| Parameter | LoRA | QLoRA | Full FT |
|---|---|---|---|
| learning_rate | 1e-4 to 3e-4 | 1e-4 to 3e-4 | 1e-5 to 5e-5 |
| lora_r | 16-64 | 16-64 | N/A |
| lora_alpha | 1-2x lora_r | 1-2x lora_r | N/A |
| micro_batch_size | 2-8 | 2-4 | 1-2 |
| gradient_accumulation_steps | 2-8 | 4-16 | 4-16 |
| num_epochs | 1-3 | 1-3 | 1-3 |
| optimizer | adamw_8bit | adamw_bnb_8bit | adamw_torch_fused |

Effective batch size = micro_batch_size × gradient_accumulation_steps × num_gpus. Use a lower learning rate for larger models.
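Here is the effective-batch formula worked through, together with the number of optimizer steps it implies. The sketch makes three assumptions: plain data parallelism, where each GPU processes its own micro-batch; `num_samples` counting packed sequences when sample_packing is on; and the last partial batch being kept. The LoRA scale uses the standard alpha / r rule (rsLoRA uses alpha / sqrt(r) instead). Under that rule, a lora_alpha of 1-2x lora_r corresponds to a scale of 1-2. The function names are illustrative.

```python
import math


def effective_batch_size(micro_batch_size: int, gradient_accumulation_steps: int, num_gpus: int) -> int:
    """Samples (or packed sequences) consumed per optimizer step."""
    return micro_batch_size * gradient_accumulation_steps * num_gpus


def optimizer_steps_per_epoch(num_samples: int, micro_batch_size: int,
                              gradient_accumulation_steps: int, num_gpus: int) -> int:
    """Approximate optimizer steps per epoch; real counts can differ slightly."""
    batch = effective_batch_size(micro_batch_size, gradient_accumulation_steps, num_gpus)
    return math.ceil(num_samples / batch)


def lora_scale(lora_alpha: float, lora_r: int) -> float:
    """Standard LoRA scaling applied to the adapter update (alpha / r)."""
    return lora_alpha / lora_r


print(effective_batch_size(4, 4, 2))                # 32
print(optimizer_steps_per_epoch(10_000, 4, 4, 2))   # 313
print(lora_scale(32, 16))                           # 2.0
```

Because the effective batch grows with num_gpus, scaling out without adjusting gradient_accumulation_steps changes both the per-step batch and the number of steps per epoch.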

Healthy Training Indicators

| Metric | Healthy | Problem |
|---|---|---|
| train_loss | Decreasing; starts around 2-4 for chat models | Flat or increasing from step 1: data or LR issue |
| eval_loss | Decreasing, tracks train_loss | Increasing while train_loss decreases: overfitting |
| grad_norm | 0.1-10, relatively stable | Spikes >100: instability; 0.0: frozen weights |
| learning_rate | Follows the scheduler curve | Flat or NaN: config issue |

Watch for: loss that never decreases (check train_on_inputs, the dataset, and the LR), loss that drops to ~0 quickly (overfitting), and eval_loss diverging from train_loss (reduce epochs or add regularization). See training_stability.qmd.
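The table's thresholds can be turned into a quick check over logged metric histories, for example values exported from your experiment logger. This is a rough sketch: it compares only the first and last value in each window, and the >100 and 0.0 grad_norm thresholds are taken directly from the table. `check_training_health` is a hypothetical helper, not an axolotl API.

```python
import math
from typing import List, Optional


def check_training_health(train_loss: List[float], grad_norm: List[float],
                          eval_loss: Optional[List[float]] = None) -> List[str]:
    """Return messages for the 'Problem' patterns in the healthy-indicators table."""
    issues: List[str] = []
    # NaN makes every later comparison meaningless, so report it and stop
    if any(math.isnan(x) for x in train_loss):
        issues.append("train_loss is NaN: try bf16: auto, a lower LR, and check for empty samples")
        return issues
    if len(train_loss) >= 2 and train_loss[-1] >= train_loss[0]:
        issues.append("train_loss flat or increasing: check dataset, train_on_inputs, learning_rate")
    if any(g > 100 for g in grad_norm):
        issues.append("grad_norm spike > 100: training may be unstable")
    if grad_norm and all(g == 0.0 for g in grad_norm):
        issues.append("grad_norm is 0.0: weights appear frozen")
    if eval_loss and len(eval_loss) >= 2 and len(train_loss) >= 2:
        if eval_loss[-1] > eval_loss[0] and train_loss[-1] < train_loss[0]:
            issues.append("eval_loss rising while train_loss falls: likely overfitting")
    return issues


print(check_training_health([2.5, 1.8, 1.2], [1.0, 0.8, 0.9]))  # []
```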

Known Issues

| Issue | Fix |
|---|---|
| OOM during training | Reduce micro_batch_size, enable gradient_checkpointing, reduce sequence_len |
| sample_packing + SDPA + bf16 = 0.0 loss | Use attn_implementation: flash_attention_2 or disable sample_packing |
| Missing chat template error | Set chat_template: chatml explicitly |
| Label masking wrong | Run axolotl preprocess config.yaml --debug and inspect labels |
| Loss NaN | Use bf16: auto, lower LR, check data for empty samples |
| Tokenizer pad token / infinite loss | Set special_tokens: pad_token: "<\|end_of_text\|>" |
| FSDP save hangs | Use fsdp_state_dict_type: FULL_STATE_DICT |
| DeepSpeed CheckpointError | Set use_reentrant: true in gradient_checkpointing_kwargs |
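Several of these rows can be caught before launching a run. Below is a hypothetical pre-flight check, not part of axolotl, that runs over the parsed YAML config (for example, the dict returned by `yaml.safe_load`). It reads only keys named in the table. It assumes the SDPA backend is spelled `sdpa` in attn_implementation, as in Hugging Face Transformers.

```python
from typing import Any, Dict, List


def lint_known_issues(cfg: Dict[str, Any]) -> List[str]:
    """Hypothetical check for a few Known Issues rows in a parsed axolotl config dict."""
    warnings: List[str] = []
    # Packing + SDPA + bf16 can silently produce 0.0 loss
    if cfg.get("sample_packing") and cfg.get("attn_implementation") == "sdpa" and cfg.get("bf16"):
        warnings.append("sample_packing + sdpa + bf16 can yield 0.0 loss: use "
                        "attn_implementation: flash_attention_2 or disable sample_packing")
    # chat_template datasets rely on the tokenizer's template unless one is set explicitly
    datasets = cfg.get("datasets") or []
    uses_chat_template = any(isinstance(d, dict) and d.get("type") == "chat_template" for d in datasets)
    if uses_chat_template and not cfg.get("chat_template"):
        warnings.append("no chat_template set: if the tokenizer lacks one, set chat_template: chatml")
    # DeepSpeed + gradient checkpointing may raise CheckpointError without reentrant checkpointing
    if cfg.get("deepspeed") and cfg.get("gradient_checkpointing"):
        gc_kwargs = cfg.get("gradient_checkpointing_kwargs") or {}
        if gc_kwargs.get("use_reentrant") is not True:
            warnings.append("DeepSpeed + gradient checkpointing: set use_reentrant: true "
                            "in gradient_checkpointing_kwargs if you hit CheckpointError")
    return warnings


example = {"sample_packing": True, "attn_implementation": "sdpa", "bf16": "auto",
           "datasets": [{"path": "my/dataset", "type": "chat_template"}]}
for w in lint_known_issues(example):
    print(w)
```

The remaining rows (OOM, NaN loss, FSDP save hangs) depend on runtime behavior or hardware, so a static config check cannot reliably flag them.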

Profiling

To profile training and identify optimization opportunities:

# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5

This produces profiler_trace.json (Chrome trace) and snapshot.pickle (memory snapshot) in output_dir. View the Chrome trace at chrome://tracing.

To programmatically inspect the trace:

python scripts/analyze_profile.py output_dir/

The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:

  • Large matmul kernels: candidates for fusion or quantization
  • Memory copies (H2D/D2H): unnecessary data movement
  • Small frequent kernels: candidates for kernel fusion
  • Gaps between kernels: pipeline bubbles from CPU overhead
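For a quick look at the trace without a viewer, note that the Chrome trace is JSON in the Trace Event Format. The minimal sketch below ranks GPU kernels by total time and finds idle gaps between them. It assumes kernel events carry `"ph": "X"` and `"cat": "kernel"`, with `ts`/`dur` in microseconds, as in recent PyTorch profiler exports; check the category names in your own file. It is independent of scripts/analyze_profile.py.

```python
import json
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple


def _kernel_events(trace: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Complete events ("ph": "X") in the "kernel" category are GPU kernel launches
    return [ev for ev in trace.get("traceEvents", [])
            if ev.get("ph") == "X" and ev.get("cat") == "kernel"]


def summarize_kernels(trace: Dict[str, Any], top_k: int = 5) -> List[Tuple[str, float, int]]:
    """Return (name, total_us, count) for the top_k kernels by total duration."""
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for ev in _kernel_events(trace):
        totals[ev["name"]] += float(ev.get("dur", 0))
        counts[ev["name"]] += 1
    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    return [(name, total, counts[name]) for name, total in ranked[:top_k]]


def kernel_gaps(trace: Dict[str, Any], min_gap_us: float = 50.0) -> List[float]:
    """Idle gaps (us) between kernels; overlapping streams are merged into one busy timeline."""
    spans = sorted((float(ev["ts"]), float(ev["ts"]) + float(ev.get("dur", 0)))
                   for ev in _kernel_events(trace))
    gaps: List[float] = []
    busy_until: Optional[float] = None
    for start, end in spans:
        if busy_until is not None and start - busy_until >= min_gap_us:
            gaps.append(start - busy_until)
        busy_until = end if busy_until is None else max(busy_until, end)
    return gaps


def load_trace(path: str) -> Dict[str, Any]:
    with open(path) as f:
        return json.load(f)
```

Usage: `trace = load_trace("output_dir/profiler_trace.json")`, then inspect `summarize_kernels(trace)` for the heaviest kernels and `kernel_gaps(trace)` for CPU-bound bubbles.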

Full troubleshooting: training_stability.qmd, debugging.qmd

File Map

src/axolotl/
  cli/train.py                     # Entry point for `axolotl train`
  cli/preprocess.py                # Entry point for `axolotl preprocess`
  core/builders/causal.py          # HFCausalTrainerBuilder — wires config → SFT trainer
  core/trainers/base.py            # AxolotlTrainer — base trainer class
  core/trainers/mixins/            # Packing, optimizer, scheduler, checkpoints
  prompt_strategies/               # Format handlers: chat_template, alpaca, completion, input_output
  utils/schemas/config.py          # AxolotlInputConfig — main config schema
  utils/schemas/datasets.py        # SFTDataset, DatasetConfig
  utils/schemas/peft.py            # LoraConfig — LoRA parameters
  integrations/liger/              # Liger kernel plugin

examples/llama-3/                  # LoRA, QLoRA, full FT example configs
docs/getting-started.qmd           # Quickstart with config templates
docs/optimizations.qmd             # Flash attention, gradient checkpointing, sample packing
docs/multi-gpu.qmd                 # FSDP and DeepSpeed setup