Energy-Based Fine-Tuning (EBFT)

EBFT is an integration of "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models" (Jelassi et al., 2026) into axolotl.

Overview

EBFT fine-tunes language models by optimizing a feature-matching loss rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.

Key advantages over SFT:

  • Operates on model rollouts (not teacher forcing), reducing distribution shift
  • Provides dense sequence-level supervision without a task-specific verifier
  • Improves both downstream accuracy and validation cross-entropy simultaneously

Key advantages over RLVR:

  • No reward model or verifier required — works on any (prompt, completion) data
  • Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
  • Maintains distributional calibration (low feature-matching loss)

Two Modes

EBFT supports two modes depending on your data format:

Structured Mode (mode: structured, default)

For QA/instruction data with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).

  • Extends GRPOTrainer — uses vLLM for fast rollout generation
  • RLOO advantages and clipped policy gradient from GRPO
  • Feature-matching rewards replace external reward functions

Strided Mode (mode: strided)

For unstructured text without prompt/completion splits (e.g., raw code, prose, SwallowCode).

  • Uses strided block-parallel generation — multiple short rollouts at different anchor points within a document
  • No vLLM needed — generation uses custom strided attention masks
  • Uses torch flex_attention with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
  • Compatible with gradient checkpointing via automatic dtype normalization
  • This is the core EBFT algorithm from the paper (Section F)

Common to both modes:

  • Frozen feature network — deep copy of the model at initialization (frozen, eval mode)
  • Feature extraction — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
  • Feature-matching rewards — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7)
  • SVD whitening — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
  • CFM loss tracking — conditional feature-matching loss (paper eq 2) logged as ebft/cfm_loss
  • FSDP2 compatible — feature network stays outside FSDP wrapping (frozen, inference-only)
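
The feature-extraction scheme above (fractional layer depths, last-token pooling, per-layer L2 normalization, concatenation) can be sketched in a few lines of numpy. This is a minimal illustration, not axolotl's implementation; `extract_features` and its argument names are made up for this sketch:

```python
import numpy as np

def extract_features(hidden_states, fractions=(0.25, 0.5, 0.75)):
    """Sketch: pick hidden states at fractional depths, pool with
    "last_token", L2-normalize each layer's embedding, concatenate.
    Illustrative names only, not axolotl's API."""
    n_layers = len(hidden_states)
    feats = []
    for f in fractions:
        layer = hidden_states[min(int(round(f * n_layers)), n_layers - 1)]
        emb = layer[:, -1, :]                                    # last-token pooling
        emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)  # per-layer L2 norm
        feats.append(emb)
    return np.concatenate(feats, axis=-1)  # (batch, len(fractions) * hidden_dim)

# 12 fake "layers" of shape (batch=2, seq=16, hidden=8)
hs = [np.random.randn(2, 16, 8) for _ in range(12)]
phi = extract_features(hs)  # shape (2, 24)
```

Because each layer's embedding is normalized before concatenation, every layer contributes on the same scale to the final feature vector.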

Quick Start

Structured Mode (QA data + vLLM)

# 1. Start vLLM server (LoRA serve module auto-selected when vllm_lora_sync: true)
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen3-4b-ebft-structured-async.yaml

# 2. Train on a separate GPU
CUDA_VISIBLE_DEVICES=1 axolotl train examples/ebft/qwen3-4b-ebft-structured-async.yaml

Strided Mode (unstructured text)

# No vLLM needed — strided generation is built-in
axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml

Configuration

Common EBFT Settings

rl: ebft

ebft:
  # Feature network: which layers to extract hidden states from
  # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
  feature_layers: [0.25, 0.5, 0.75]

  # How to pool per-token hidden states into sequence embeddings
  # Options: "last_token" (recommended), "mean_pooling", "concat"
  embed_method: last_token

  # SVD whitening — strongly recommended (paper shows largest degradation without it)
  use_whitening: true

  # Reward = alignment_coef * alignment - diversity_coef * diversity
  # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
  # diversity uses raw dot product — both are bounded after whitening.
  alignment_coef: 1.0
  diversity_coef: 1.0

  # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
  # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
  ce_coef: 0.0
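
The reward formula in the comments above (cosine alignment minus raw dot-product diversity, per Variant (i)) can be sketched with numpy. `feature_reward` and its argument names are illustrative, not axolotl identifiers:

```python
import numpy as np

def feature_reward(gen_embs, gt_emb, alignment_coef=1.0, diversity_coef=1.0):
    """Sketch of reward = alignment_coef * alignment - diversity_coef * diversity.
    Alignment: cosine similarity to the ground-truth embedding.
    Diversity: mean raw dot product with the *other* samples."""
    gen = np.asarray(gen_embs, dtype=float)   # (n_samples, dim), assumed whitened
    gt = np.asarray(gt_emb, dtype=float)
    gen_unit = gen / np.linalg.norm(gen, axis=-1, keepdims=True)
    alignment = gen_unit @ (gt / np.linalg.norm(gt))
    dots = gen @ gen.T
    n = gen.shape[0]
    diversity = (dots.sum(axis=-1) - np.diag(dots)) / (n - 1)  # exclude self
    return alignment_coef * alignment - diversity_coef * diversity

rewards = feature_reward([[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0])
```

A sample identical to the ground truth but duplicated across the group scores alignment 1 and diversity 1, netting to zero: matching the target is only rewarded when the group stays diverse.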

Strided Mode Settings

ebft:
  mode: strided
  stride: 8              # tokens between anchor points (paper default: 8)
  context_length: 8      # context window per block (paper default: 8)
  generate_max_len: 8    # tokens generated per block (paper default: 8)
  n_samples_per_prompt: 4  # independent rollouts per document (>= 2 for RLOO)
  temperature: 0.6
  rl_coef: 1.0           # RL loss weight
  advantage_estimator: rloo  # rloo (recommended), group_norm, or reinforce

Structured Mode Settings (via TRL)

trl:
  num_generations: 4           # samples per prompt
  max_completion_length: 256   # max tokens to generate
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

Dataset Format

Structured mode — QA data with prompt + ground-truth completion:

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform

Transform returns: {"prompt": ..., "ground_truth": ...}
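
A structured-mode transform is just a row-level mapping into that shape. A hypothetical sketch — the input field names "instruction" and "solution" are assumptions for illustration, not the actual ebft_opencode schema:

```python
def transform(example):
    """Hypothetical structured-mode transform: map one dataset row to
    the (prompt, ground_truth) pair EBFT expects. Input field names
    are illustrative assumptions."""
    return {
        "prompt": example["instruction"],
        "ground_truth": example["solution"],
    }

row = transform({"instruction": "Reverse a list.", "solution": "xs[::-1]"})
```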

Strided mode — raw text tokenized to fixed length:

datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform

Transform returns: {"input_ids": ..., "attention_mask": ..., "labels": ...}

How It Works

Structured Mode

  1. Generate: For each prompt, generate num_generations completions via vLLM
  2. Extract features: Forward both generated and ground-truth sequences through the frozen feature network
  3. Compute rewards: 2 * alignment - 2 * diversity (paper eq 7)
  4. RLOO advantages: subtract leave-one-out group mean
  5. Policy gradient: clipped PPO-style loss
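
The RLOO step (4) is a simple leave-one-out baseline; a numpy sketch (illustrative, not the trainer's code):

```python
import numpy as np

def rloo_advantages(rewards):
    """Leave-one-out advantages: each sample's reward minus the mean
    reward of the *other* samples in its group."""
    r = np.asarray(rewards, dtype=float)
    n = r.shape[-1]
    baseline = (r.sum(axis=-1, keepdims=True) - r) / (n - 1)
    return r - baseline

adv = rloo_advantages([1.0, 2.0, 3.0, 4.0])  # one prompt, 4 generations
```

The advantages always sum to zero within a group, which is why n_samples_per_prompt (or num_generations) must be at least 2.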

Strided Mode

  1. Anchor selection: Pick num_blocks = (seq_len - gen_len - ctx_len) / stride + 1 anchor points across the document
  2. Block-parallel generation: At each anchor, generate gen_len tokens using a custom strided attention mask via flex_attention compiled block masks
  3. Feature extraction: Forward the full sequence (prompt + generated) through the frozen feature network with the strided attention mask — this is critical for correct feature representations
  4. Per-block rewards:
    • Alignment = 2 * cosine_similarity(gen_block_emb, gt_block_emb) — normalized, bounded in [-2, 2]
    • Diversity = 2 * mean_pairwise_dot_product(gen_block_embs) — raw dot product on whitened vectors
    • Reward = alignment_coef * alignment - diversity_coef * diversity
  5. RLOO advantages: leave-one-out baseline across n_samples_per_prompt rollouts per block
  6. Policy gradient: REINFORCE loss on generated tokens, weighted by per-block advantages
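
The anchor-count formula from step 1 can be checked with a short sketch. The block count follows the formula above; placing the first anchor right after the initial context window is an illustrative assumption, not necessarily what axolotl does:

```python
def anchor_points(seq_len, stride=8, ctx_len=8, gen_len=8):
    """Anchors per step 1's formula:
    num_blocks = (seq_len - gen_len - ctx_len) // stride + 1.
    Anchor placement here is an assumption for illustration."""
    num_blocks = (seq_len - gen_len - ctx_len) // stride + 1
    return [ctx_len + i * stride for i in range(num_blocks)]

anchors = anchor_points(64)  # (64 - 8 - 8) // 8 + 1 = 7 blocks
```

With the paper defaults (stride, context, and generation length all 8), a 64-token document yields 7 anchors, and the final block's generated tokens still fit inside the sequence.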

Tracked Metrics

Metric               Description
ebft/alignment       Mean cosine similarity between generated and GT features (higher = better)
ebft/diversity       Mean pairwise similarity between samples (lower = more diverse)
ebft/mean_reward     alignment - diversity (should trend upward)
ebft/cfm_loss        Conditional feature-matching loss ‖E[φ(ŷ)] - φ(y)‖² (paper eq 2, lower = better)
ebft/rl_loss         REINFORCE policy gradient loss
ebft/ce_loss         Cross-entropy loss on ground-truth tokens (when ce_coef > 0)
ebft/advantages_std  RLOO advantage standard deviation (should be non-zero)
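
The ebft/cfm_loss metric is straightforward to express directly; a numpy sketch of paper eq 2, with the φ embeddings passed in as plain arrays:

```python
import numpy as np

def cfm_loss(gen_feats, gt_feat):
    """Conditional feature-matching loss (paper eq 2): squared distance
    between the mean generated feature and the ground-truth feature."""
    diff = np.asarray(gen_feats, dtype=float).mean(axis=0) - np.asarray(gt_feat, dtype=float)
    return float(diff @ diff)
```

The loss is zero exactly when the generated features average out to the ground-truth feature, which is why it trends down as training matches the target feature moments.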

Tips and Recommendations

Reward coefficients

  • use_whitening: true: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with diversity_coef > 0.
  • diversity_coef: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales.
  • n_samples_per_prompt: Must be >= 2 for diversity and RLOO. 4 is the paper's default.
  • ce_coef: The paper ablates γ ∈ {0, 0.03, 0.1}. 0.03 balances CE and RL signals; 0.1 causes CE to dominate the gradient. 0.0 gives pure feature matching.
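
SVD whitening itself is a standard transform; a generic numpy sketch, where the scaling and regularization details may differ from the paper's and axolotl's exact recipe:

```python
import numpy as np

def svd_whiten(feats, eps=1e-6):
    """Generic ZCA-style whitening via SVD: center, rotate into the
    principal basis, rescale each direction to unit variance, rotate
    back. eps guards against tiny singular values."""
    x = feats - feats.mean(axis=0, keepdims=True)
    u, s, vt = np.linalg.svd(x, full_matrices=False)
    return (x @ vt.T / (s + eps)) @ vt * np.sqrt(len(x))

w = svd_whiten(np.random.randn(200, 4))
# columns of w are decorrelated with ~unit variance
```

After this transform the feature covariance is approximately the identity, which is what puts the cosine-based alignment term and the raw dot-product diversity term on comparable scales.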

Feature extraction

  • feature_layers: [0.25, 0.5, 0.75]: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
  • embed_method: last_token: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7).

Performance

  • torch_compile: true: Recommended for strided mode. Provides additional speedup via graph compilation.
  • flex_attention: Strided mode automatically uses flex_attention with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if flex_attention is unavailable.

Memory

  • EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
  • LoRA is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
  • With 2 GPUs visible, the trainer automatically places the feature network on the second GPU.
  • FSDP2 is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With cpu_ram_efficient_loading, the feature network is loaded separately from pretrained weights.

Example Configs

Config                              Mode        Model         Description
llama-1b-ebft-opencode.yaml         Structured  Llama-3.2-1B  QA coding with vLLM
llama-1b-ebft-opencode-novllm.yaml  Structured  Llama-3.2-1B  QA coding without vLLM
llama-3b-ebft-strided-fft.yaml      Strided     Llama-3.2-3B  Unstructured code with LoRA
llama-1b-ebft-strided.yaml          Strided     Llama-3.2-1B  Quick validation

Citation

@article{jelassi2026matching,
  title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
  author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
  journal={arXiv preprint arXiv:2603.12248},
  year={2026}
}