# SFT — Agent Reference

Supervised fine-tuning pipeline reference. For config templates and dataset format examples, see [getting-started.qmd](../getting-started.qmd) and [dataset-formats/](../dataset-formats/).

## Architecture

```
YAML Config → axolotl train config.yaml

1. Load base model (+ quantization if QLoRA/8-bit)
2. Apply adapter layers (LoRA/QLoRA) if configured
3. Load + tokenize dataset(s)
   - Apply prompt template (chat_template / alpaca / custom)
   - Mask inputs (train_on_inputs: false)
   - Pack samples into sequences (sample_packing: true)
4. Training loop (HuggingFace Trainer)
   - forward → loss → backward → optimizer step → lr scheduler step
5. Save model / adapter weights + tokenizer

Multi-GPU: when configured, FSDP or DeepSpeed shards the model across GPUs.
```

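As a rough sketch of how the stages above map to config keys, here is a minimal LoRA config. The model ID, dataset path, and values are placeholders chosen for illustration, not recommended defaults; check `utils/schemas/config.py` for the keys your installed version accepts.

```yaml
# Stage 1: base model (placeholder model ID)
base_model: meta-llama/Llama-3.1-8B

# Stage 2: adapter layers
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true          # apply LoRA to all linear layers

# Stage 3: dataset loading, templating, masking, packing
datasets:
  - path: your-org/your-chat-dataset   # hypothetical Hub dataset
    type: chat_template
chat_template: chatml
train_on_inputs: false            # mask prompt tokens out of the loss
sequence_len: 4096
sample_packing: true
attn_implementation: flash_attention_2

# Stage 4: training loop
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
learning_rate: 0.0002
optimizer: adamw_8bit
lr_scheduler: cosine
bf16: auto

# Stage 5: where adapter weights and tokenizer are saved
output_dir: ./outputs/lora-sft
```

Saved as `config.yaml`, this would be run with `axolotl train config.yaml`, as in the diagram.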
## Components Required

1. A YAML config: model, dataset(s), adapter settings, hyperparameters
2. A dataset: Hugging Face Hub, local JSONL/JSON/Parquet, or S3/GCS path
3. (Optional) A custom prompt strategy for non-standard dataset formats

No external server processes are needed (unlike GRPO, which requires vLLM).

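The dataset sources listed above are all declared as entries under `datasets`. A minimal sketch with placeholder repo IDs, paths, and bucket names; it assumes `ds_type` is how the loader is told the file format for local and remote files, so confirm the accepted values in `utils/schemas/datasets.py` for your version.

```yaml
datasets:
  # Hugging Face Hub dataset (placeholder repo ID)
  - path: your-org/your-chat-dataset
    type: chat_template

  # Local JSONL file; ds_type names the file format
  - path: data/train.jsonl
    ds_type: json
    type: chat_template

  # Remote Parquet file (placeholder bucket); assumes the matching fsspec
  # backend (e.g. s3fs for s3://, gcsfs for gs://) and credentials are set up
  - path: s3://your-bucket/sft/train.parquet
    ds_type: parquet
    type: chat_template
```

Multiple entries are combined into one training set, so in practice keep only the sources you actually use.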
## Dataset Format Decision Tree

```
Is your data in chat/message format?
├─ YES: OpenAI message format (role/content)?
│ ├─ YES ──────────────────────> type: chat_template (recommended)
│ └─ NO (custom field names) ──> type: chat_template + message_property_mappings
└─ NO: Instruction/response pairs?
  ├─ YES ──> type: alpaca (instruction, input, output)
  └─ NO: Raw text?
    ├─ YES with segments ─────> type: input_output (template-free masking)
    └─ YES continuous ────────> type: completion (pretraining-style)
```

Full format specs: [dataset-formats/](../dataset-formats/)

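One dataset entry per leaf of the tree above, as a sketch. Repo IDs and file paths are placeholders, and the example records in the comments are illustrative shapes; in a real config you would keep only the entry that matches your data.

```yaml
datasets:
  # OpenAI-style messages: {"messages": [{"role": "user", "content": "..."}]}
  - path: your-org/openai-style-chats        # hypothetical
    type: chat_template
    field_messages: messages

  # Custom field names, e.g. ShareGPT-style:
  # {"conversations": [{"from": "human", "value": "..."}]}
  - path: your-org/sharegpt-style-chats      # hypothetical
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

  # Instruction/response pairs with instruction, input, output columns
  - path: your-org/alpaca-style-data         # hypothetical
    type: alpaca

  # Raw text with labeled segments (label: true means trained on):
  # {"segments": [{"label": false, "text": "..."}, {"label": true, "text": "..."}]}
  - path: data/segments.jsonl
    ds_type: json
    type: input_output

  # Continuous raw text, pretraining-style
  - path: data/corpus.jsonl
    ds_type: json
    type: completion
```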
## Model Size to Adapter Choice

| Model Size | LoRA | QLoRA (4-bit) | Full Fine-Tune | VRAM (approx) |
|-----------|------|---------------|----------------|---------------|
| 1-3B | Preferred | Low-budget option | Single GPU OK | 8-16 GB (LoRA) |
| 7-8B | Preferred | Good balance | Needs multi-GPU | 16-24 GB (LoRA) |
| 13-14B | Preferred | Good balance | Multi-GPU required | 24-40 GB (LoRA) |
| 30-70B | LoRA or QLoRA | Preferred for single GPU | Multi-node | 40-80 GB (QLoRA) |

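For the 30-70B row, the single-GPU QLoRA path might look like the following sketch. The model ID is a placeholder and the values are picked from the ranges in this document; whether a given run fits still depends on `sequence_len` and the GPU, so treat the VRAM column as approximate.

```yaml
base_model: meta-llama/Llama-3.1-70B   # placeholder model ID
load_in_4bit: true                     # 4-bit base weights for QLoRA
adapter: qlora
lora_r: 32
lora_alpha: 64                         # 2x lora_r
lora_dropout: 0.05
lora_target_linear: true
bf16: auto
gradient_checkpointing: true           # trade compute for activation memory
micro_batch_size: 2
gradient_accumulation_steps: 8
```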
## Hyperparameter Ranges

| Parameter | LoRA | QLoRA | Full FT |
|-----------|------|-------|---------|
| `learning_rate` | 1e-4 to 3e-4 | 1e-4 to 3e-4 | 1e-5 to 5e-5 |
| `lora_r` | 16-64 | 16-64 | N/A |
| `lora_alpha` | 1-2x `lora_r` | 1-2x `lora_r` | N/A |
| `micro_batch_size` | 2-8 | 2-4 | 1-2 |
| `gradient_accumulation_steps` | 2-8 | 4-16 | 4-16 |
| `num_epochs` | 1-3 | 1-3 | 1-3 |
| `optimizer` | `adamw_8bit` | `adamw_bnb_8bit` | `adamw_torch_fused` |

Effective batch size = `micro_batch_size` × `gradient_accumulation_steps` × number of GPUs. Use a lower learning rate for larger models.

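A worked example of the effective batch formula, assuming a 4-GPU data-parallel run (the GPU count is an assumption; it is not a config key here):

```yaml
# effective batch = micro_batch_size * gradient_accumulation_steps * num_gpus
#                 = 4 * 4 * 4 = 64 sequences per optimizer step
micro_batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 0.0002
optimizer: adamw_8bit
```

With `sample_packing: true`, each of those 64 sequences is a packed sequence of up to `sequence_len` tokens, so tokens per optimizer step is at most 64 × `sequence_len`.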
## Healthy Training Indicators

| Metric | Healthy | Problem |
|--------|---------|---------|
| `train_loss` | Decreasing, starting around 2-4 for chat models | Flat or increasing from step 1: data or LR issue |
| `eval_loss` | Decreasing, tracks `train_loss` | Increasing while `train_loss` decreases: overfitting |
| `grad_norm` | 0.1-10, relatively stable | Spikes >100: instability. Exactly 0.0: frozen weights |
| `learning_rate` | Follows the scheduler curve | Flat (with a decaying scheduler) or NaN: config issue |

Watch for: loss never decreasing (check `train_on_inputs`, the dataset, and LR), loss dropping to 0 quickly (overfitting), and `eval_loss` diverging (reduce epochs or add regularization). See [training_stability.qmd](../training_stability.qmd).

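To actually see the metrics in the table above, an eval split and frequent logging are needed. A sketch with illustrative values; the regularization knobs at the end are the kind of change the "reduce epochs, add regularization" advice refers to, not tuned settings.

```yaml
val_set_size: 0.05      # hold out 5% of the data so eval_loss is reported
evals_per_epoch: 4      # how often eval_loss is computed
logging_steps: 1        # log train_loss, grad_norm, learning_rate every step
max_grad_norm: 1.0      # clip gradients to limit the effect of grad_norm spikes
num_epochs: 1           # fewer epochs if eval_loss diverges
lora_dropout: 0.1       # regularization for LoRA adapters
weight_decay: 0.01      # regularization on the optimizer side
```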
## Known Issues

| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
| Missing chat template error | Set `chat_template: chatml` explicitly |
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
| Tokenizer pad token / infinite loss | Set `special_tokens: pad_token: "<\|end_of_text\|>"` (Llama 3 token; use your model's equivalent) |
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |

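Several of the fixes above written out as config, as a sketch. They are grouped by symptom and are not meant to be applied all at once; the values are illustrative, and the pad token shown is Llama 3's.

```yaml
# OOM: shrink per-step memory
micro_batch_size: 1
gradient_checkpointing: true
sequence_len: 2048

# DeepSpeed CheckpointError
gradient_checkpointing_kwargs:
  use_reentrant: true

# Packing with bf16: use FlashAttention-2 instead of SDPA
sample_packing: true
attn_implementation: flash_attention_2
bf16: auto

# Missing chat template / pad token problems
chat_template: chatml
special_tokens:
  pad_token: "<|end_of_text|>"   # Llama 3 token; use your model's equivalent
```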
## Profiling

To profile training and identify optimization opportunities:

```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```

This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`. View the Chrome trace at `chrome://tracing` (or the Perfetto UI at https://ui.perfetto.dev), and the memory snapshot at https://pytorch.org/memory_viz.

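For a profiling-only run, one option (a workflow assumption, not a documented recipe) is to cap training with `max_steps` so the job ends shortly after the profiled window instead of running to completion:

```yaml
profiler_steps_start: 3
profiler_steps: 5                   # profiles steps 3-7
max_steps: 10                       # stop soon after the profiled window
output_dir: ./outputs/profile-run   # trace and memory snapshot are written here
```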
To programmatically inspect the trace:

```bash
python scripts/analyze_profile.py output_dir/
```

The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:

- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead

Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)

## File Map

```
src/axolotl/
  cli/train.py               # Entry point for `axolotl train`
  cli/preprocess.py          # Entry point for `axolotl preprocess`
  core/builders/causal.py    # HFCausalTrainerBuilder: wires config → SFT trainer
  core/trainers/base.py      # AxolotlTrainer: base trainer class
  core/trainers/mixins/      # Packing, optimizer, scheduler, checkpoints
  prompt_strategies/         # Format handlers: chat_template, alpaca, completion, input_output
  utils/schemas/config.py    # AxolotlInputConfig: main config schema
  utils/schemas/datasets.py  # SFTDataset, DatasetConfig
  utils/schemas/peft.py      # LoraConfig: LoRA parameters
  integrations/liger/        # Liger kernel plugin

examples/llama-3/            # LoRA, QLoRA, full FT example configs
docs/getting-started.qmd     # Quickstart with config templates
docs/optimizations.qmd       # Flash attention, gradient checkpointing, sample packing
docs/multi-gpu.qmd           # FSDP and DeepSpeed setup
```