# Energy-Based Fine-Tuning (EBFT)
EBFT is an integration of "Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models" (Jelassi et al., 2026) into axolotl.
## Overview
EBFT fine-tunes language models by optimizing a feature-matching loss rather than relying on external reward functions or verifiers. A frozen copy of the model (the "feature network") extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.
Key advantages over SFT:
- Operates on model rollouts (not teacher forcing), reducing distribution shift
- Provides dense sequence-level supervision without a task-specific verifier
- Improves both downstream accuracy and validation cross-entropy simultaneously
Key advantages over RLVR:
- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (e.g., raw code, translation, creative writing)
- Maintains distributional calibration (low feature-matching loss)
## Two Modes
EBFT supports two modes depending on your data format:
### Structured Mode (`mode: structured`, default)
For QA/instruction data with prompt + completion pairs (e.g., OpenCodeInstruct, ALMA translation).
- Extends `GRPOTrainer` — uses vLLM for fast rollout generation
- RLOO advantages and clipped policy gradient from GRPO
- Feature-matching rewards replace external reward functions
### Strided Mode (`mode: strided`)
For unstructured text without prompt/completion splits (e.g., raw code, prose, SwallowCode).
- Uses strided block-parallel generation — multiple short rollouts at different anchor points within a document
- No vLLM needed — generation uses custom strided attention masks
- Uses torch `flex_attention` with compiled block masks for efficient fused attention kernels (~2x faster than eager attention)
- Compatible with gradient checkpointing via automatic dtype normalization
- This is the core EBFT algorithm from the paper (Section F)
Common to both modes:
- Frozen feature network — deep copy of the model at initialization (frozen, eval mode)
- Feature extraction — hidden states at configurable layer depths (default: 25%, 50%, 75%), L2-normalized per layer before concatenation
- Feature-matching rewards — cosine similarity (alignment) minus pairwise dot-product (diversity), scaled by 2 per paper equation (7); sketched in code below
- SVD whitening — decorrelates feature dimensions; the paper shows removing it causes the largest degradation
- CFM loss tracking — conditional feature-matching loss (paper eq 2) logged as `ebft/cfm_loss`
- FSDP2 compatible — feature network stays outside FSDP wrapping (frozen, inference-only)
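To make the shared pipeline concrete, here is a minimal sketch of feature extraction and the feature-matching reward, assuming a Hugging Face-style model that returns `hidden_states`; `extract_features` and `feature_reward` are illustrative names, not the axolotl API:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def extract_features(feature_net, input_ids, layer_fracs=(0.25, 0.5, 0.75)):
    """Concatenate L2-normalized last-token hidden states from fractional depths."""
    out = feature_net(input_ids, output_hidden_states=True)
    n_layers = len(out.hidden_states) - 1  # hidden_states[0] is the embedding layer
    feats = []
    for frac in layer_fracs:
        h = out.hidden_states[int(frac * n_layers)][:, -1]  # last-token pooling
        feats.append(F.normalize(h, dim=-1))                # per-layer L2 norm
    return torch.cat(feats, dim=-1)

def feature_reward(gen_embs, gt_emb, alignment_coef=1.0, diversity_coef=1.0):
    """Per-sample reward: cosine alignment to the ground-truth embedding minus
    mean pairwise dot-product diversity across the group, scaled by 2 (paper eq 7)."""
    alignment = 2 * F.cosine_similarity(gen_embs, gt_emb.expand_as(gen_embs), dim=-1)
    sims = gen_embs @ gen_embs.T                       # (n_samples, n_samples)
    n = sims.size(0)
    diversity = 2 * (sims.sum(-1) - sims.diagonal()) / max(n - 1, 1)
    return alignment_coef * alignment - diversity_coef * diversity
```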
## Quick Start

### Structured Mode (QA data + vLLM)

```bash
# 1. Start vLLM server (LoRA serve module auto-selected when vllm_lora_sync: true)
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve examples/ebft/qwen3-4b-ebft-structured-async.yaml

# 2. Train on a separate GPU
CUDA_VISIBLE_DEVICES=1 axolotl train examples/ebft/qwen3-4b-ebft-structured-async.yaml
```
### Strided Mode (unstructured text)

```bash
# No vLLM needed — strided generation is built-in
axolotl train examples/ebft/llama-3b-ebft-strided-fft.yaml
```
## Configuration

### Common EBFT Settings

```yaml
rl: ebft

ebft:
  # Feature network: which layers to extract hidden states from
  # Values are fractions of total depth (0.0 = embedding, 1.0 = final layer)
  feature_layers: [0.25, 0.5, 0.75]

  # How to pool per-token hidden states into sequence embeddings
  # Options: "last_token" (recommended), "mean_pooling", "concat"
  embed_method: last_token

  # SVD whitening — strongly recommended (paper shows largest degradation without it)
  use_whitening: true

  # Reward = alignment_coef * alignment - diversity_coef * diversity
  # Per paper Variant (i) (eq 49): alignment uses cosine similarity (normalized),
  # diversity uses raw dot product — both are bounded after whitening.
  alignment_coef: 1.0
  diversity_coef: 1.0

  # Cross-entropy loss on ground-truth tokens (mixed objective, paper Section 2.1)
  # 0.0 = pure feature matching; 0.03 = recommended balance; 0.1 = CE-dominated
  ce_coef: 0.0
```
### Strided Mode Settings

```yaml
ebft:
  mode: strided
  stride: 8                  # tokens between anchor points (paper default: 8)
  context_length: 8          # context window per block (paper default: 8)
  generate_max_len: 8        # tokens generated per block (paper default: 8)
  n_samples_per_prompt: 4    # independent rollouts per document (>= 2 for RLOO)
  temperature: 0.6
  rl_coef: 1.0               # RL loss weight
  advantage_estimator: rloo  # rloo (recommended), group_norm, or reinforce
```
### Structured Mode Settings (via TRL)

```yaml
trl:
  num_generations: 4           # samples per prompt
  max_completion_length: 256   # max tokens to generate
  temperature: 1.0
  use_vllm: true
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2
```
## Dataset Format

Structured mode — QA data with prompt + ground-truth completion:

```yaml
datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
```

Transform returns: `{"prompt": ..., "ground_truth": ...}`
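For orientation, a structured-mode transform might look like the sketch below; the row field names (`input`, `output`) are assumptions about the dataset schema, not the shipped `ebft_opencode.transform`:

```python
def transform(row):
    # Map a raw dataset row to the (prompt, ground_truth) pair EBFT expects.
    # "input" / "output" are hypothetical column names for QA-style data.
    return {"prompt": row["input"], "ground_truth": row["output"]}
```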
Strided mode — raw text tokenized to fixed length:

```yaml
datasets:
  - path: sjelassi/swallow_code_20m
    type: ebft_pretrain.transform
```

Transform returns: `{"input_ids": ..., "attention_mask": ..., "labels": ...}`
## How It Works

### Structured Mode

- Generate: For each prompt, generate `num_generations` completions via vLLM
- Extract features: Forward both generated and ground-truth sequences through the frozen feature network
- Compute rewards: `2 * alignment - 2 * diversity` (paper eq 7)
- RLOO advantages: subtract leave-one-out group mean (sketched below)
- Policy gradient: clipped PPO-style loss
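The RLOO step subtracts, for each completion, the mean reward of the other completions in its group. A minimal sketch:

```python
import torch

def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """Leave-one-out baseline. rewards: (num_prompts, num_generations),
    with num_generations >= 2."""
    n = rewards.size(-1)
    loo_mean = (rewards.sum(-1, keepdim=True) - rewards) / (n - 1)
    return rewards - loo_mean
```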
### Strided Mode

- Anchor selection: Pick `num_blocks = (seq_len - gen_len - ctx_len) / stride + 1` anchor points across the document (worked example after this list)
- Block-parallel generation: At each anchor, generate `gen_len` tokens using a custom strided attention mask via `flex_attention` compiled block masks
- Feature extraction: Forward the full sequence (prompt + generated) through the frozen feature network with the strided attention mask — this is critical for correct feature representations
- Per-block rewards:
  - Alignment = `2 * cosine_similarity(gen_block_emb, gt_block_emb)` — normalized, bounded in [-2, 2]
  - Diversity = `2 * mean_pairwise_dot_product(gen_block_embs)` — raw dot product on whitened vectors
  - Reward = `alignment_coef * alignment - diversity_coef * diversity`
- RLOO advantages: leave-one-out baseline across `n_samples_per_prompt` rollouts per block
- Policy gradient: REINFORCE loss on generated tokens, weighted by per-block advantages
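To make the anchor arithmetic concrete, a small sketch under the paper defaults (`stride = context_length = generate_max_len = 8`); placing the first anchor at `ctx_len` is an assumption made for illustration:

```python
def anchor_positions(seq_len: int, stride: int = 8, ctx_len: int = 8, gen_len: int = 8):
    """num_blocks = (seq_len - gen_len - ctx_len) // stride + 1 anchor points."""
    num_blocks = (seq_len - gen_len - ctx_len) // stride + 1
    return [ctx_len + i * stride for i in range(num_blocks)]

# A 64-token document yields (64 - 8 - 8) // 8 + 1 = 7 anchors:
print(anchor_positions(64))  # [8, 16, 24, 32, 40, 48, 56]
```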
## Tracked Metrics

| Metric | Description |
|---|---|
| `ebft/alignment` | Mean cosine similarity between generated and GT features (higher = better) |
| `ebft/diversity` | Mean pairwise similarity between samples (lower = more diverse) |
| `ebft/mean_reward` | `alignment - diversity` (should trend upward) |
| `ebft/cfm_loss` | Conditional feature-matching loss ‖E[φ(ŷ)] − φ(y)‖² (paper eq 2, lower = better) |
| `ebft/rl_loss` | REINFORCE policy gradient loss |
| `ebft/ce_loss` | Cross-entropy loss on ground-truth tokens (when `ce_coef > 0`) |
| `ebft/advantages_std` | RLOO advantage standard deviation (should be non-zero) |
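The `ebft/cfm_loss` metric is the squared distance between the mean rollout feature and the ground-truth feature; as a sketch:

```python
import torch

def cfm_loss(gen_embs: torch.Tensor, gt_emb: torch.Tensor) -> torch.Tensor:
    """Conditional feature-matching loss ||E[phi(y_hat)] - phi(y)||^2 (paper eq 2).
    gen_embs: (n_samples, d) rollout features; gt_emb: (d,) ground-truth feature."""
    return (gen_embs.mean(dim=0) - gt_emb).pow(2).sum()
```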
## Tips and Recommendations

### Reward coefficients

- `use_whitening: true`: Strongly recommended. The paper's ablation (Figure 7) shows removing whitening causes the largest performance degradation. Safe to use with `diversity_coef > 0`.
- `diversity_coef`: Default 1.0. Per the paper's Variant (i) (eq 49), alignment uses cosine similarity while diversity uses raw dot product. After whitening, both are bounded and on compatible scales (a sketch of the whitening step follows this list).
- `n_samples_per_prompt`: Must be >= 2 for diversity and RLOO. 4 is the paper's default.
- `ce_coef`: The paper ablates `γ ∈ {0, 0.03, 0.1}`. `0.03` balances CE and RL signals; `0.1` causes CE to dominate the gradient; `0.0` gives pure feature matching.
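A minimal sketch of SVD whitening over a batch of features; the centering step and `eps` guard are assumptions about implementation details:

```python
import torch

def svd_whiten(feats: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Decorrelate (n, d) features: rotate into the singular basis and rescale
    each direction to unit variance, so dot products are well-scaled."""
    centered = feats - feats.mean(dim=0, keepdim=True)
    _, s, vt = torch.linalg.svd(centered, full_matrices=False)
    # centered @ vt.T equals U * S; dividing by S / sqrt(n) leaves sqrt(n) * U,
    # whose covariance is the identity.
    return centered @ vt.T / (s / feats.size(0) ** 0.5 + eps)
```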
### Feature extraction

- `feature_layers: [0.25, 0.5, 0.75]`: Extracts and concatenates hidden states from 25%, 50%, 75% depth. Each layer is L2-normalized independently before concatenation. The paper shows this works better than mean pooling or single-layer extraction.
- `embed_method: last_token`: Uses the last token's hidden state per block. The paper shows this outperforms mean pooling (Figure 7).
### Performance

- `torch_compile: true`: Recommended for strided mode. Provides additional speedup via graph compilation.
- flex_attention: Strided mode automatically uses `flex_attention` with compiled block masks when available (~2x faster than eager attention). Works with gradient checkpointing via automatic dtype normalization. Falls back to eager attention with dense 4D masks if `flex_attention` is unavailable. A sketch of a windowed mask follows.
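For orientation, the general shape of a windowed `mask_mod` under torch's flex_attention API (torch >= 2.5); this is a guess at a strided-style mask for illustration, not the mask axolotl actually builds:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

CTX_LEN, GEN_LEN = 8, 8  # assumed to match the strided-mode defaults

def strided_mask(b, h, q_idx, kv_idx):
    # Causal within a bounded window: each position attends to at most
    # ctx_len + gen_len preceding tokens (an assumed simplification).
    return (kv_idx <= q_idx) & (q_idx - kv_idx < CTX_LEN + GEN_LEN)

q = k = v = torch.randn(1, 4, 128, 64, device="cuda")
block_mask = create_block_mask(strided_mask, B=None, H=None, Q_LEN=128, KV_LEN=128)
out = flex_attention(q, k, v, block_mask=block_mask)
```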
### Memory

- EBFT requires a frozen copy of the model (the feature network), roughly doubling model memory.
- LoRA is recommended to reduce trainable parameter memory. The feature network is always a frozen copy of the base model (without LoRA adapters).
- With 2 GPUs visible, the trainer automatically places the feature network on the second GPU (see the sketch after this list).
- FSDP2 is supported — the feature network stays outside FSDP wrapping since it's frozen and inference-only. With `cpu_ram_efficient_loading`, the feature network is loaded separately from pretrained weights.
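Conceptually, the feature network setup above reduces to a frozen deep copy moved off the training device; a rough sketch, not the trainer's actual placement logic:

```python
import copy
import torch

def build_feature_network(model: torch.nn.Module) -> torch.nn.Module:
    """Frozen, eval-mode copy of the base model, used only for feature extraction."""
    feature_net = copy.deepcopy(model).eval()
    feature_net.requires_grad_(False)
    if torch.cuda.device_count() >= 2:
        feature_net.to("cuda:1")  # keep it off the training GPU when possible
    return feature_net
```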
## Example Configs

| Config | Mode | Model | Description |
|---|---|---|---|
| `llama-1b-ebft-opencode.yaml` | Structured | Llama-3.2-1B | QA coding with vLLM |
| `llama-1b-ebft-opencode-novllm.yaml` | Structured | Llama-3.2-1B | QA coding without vLLM |
| `llama-3b-ebft-strided-fft.yaml` | Strided | Llama-3.2-3B | Unstructured code with LoRA |
| `llama-1b-ebft-strided.yaml` | Strided | Llama-3.2-1B | Quick validation |
## Citation

```bibtex
@article{jelassi2026matching,
  title={Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models},
  author={Jelassi, Samy and Kwun, Mujin and Zhao, Rosie and Li, Yuanzhi and Fusi, Nicolo and Du, Yilun and Kakade, Sham M. and Domingo-Enrich, Carles},
  journal={arXiv preprint arXiv:2603.12248},
  year={2026}
}
```